Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "aiproxyguard-python-sdk"
version = "0.1.0"
version = "0.2.0"
description = "Official Python SDK for AIProxyGuard - LLM security proxy for prompt injection detection"
readme = "README.md"
license = "Apache-2.0"
Expand Down
7 changes: 5 additions & 2 deletions src/aiproxyguard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
... print(f"Blocked: {result.category}")
"""

from ._version import __version__
from .client import AIProxyGuard, ApiMode
from .decorators import GuardConfigurationError, guard, guard_output
from .exceptions import (
Expand All @@ -23,22 +24,24 @@
Action,
CheckResult,
CloudCheckResult,
FeedbackResult,
HealthStatus,
ReadyStatus,
ServiceInfo,
ThreatDetail,
)

__version__ = "0.1.0"

__all__ = [
# Version
"__version__",
# Client
"AIProxyGuard",
"ApiMode",
# Models
"Action",
"CheckResult",
"CloudCheckResult",
"FeedbackResult",
"HealthStatus",
"ReadyStatus",
"ServiceInfo",
Expand Down
3 changes: 3 additions & 0 deletions src/aiproxyguard/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Package version."""

__version__ = "0.2.0"
140 changes: 116 additions & 24 deletions src/aiproxyguard/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import httpx

from ._version import __version__
from .exceptions import (
AIProxyGuardError,
ConnectionError,
Expand All @@ -22,6 +23,7 @@
from .models import (
CheckResult,
CloudCheckResult,
FeedbackResult,
HealthStatus,
ReadyStatus,
ServiceInfo,
Expand Down Expand Up @@ -145,7 +147,11 @@ def set_api_key(self, api_key: str | None) -> None:

def _get_headers(self) -> dict[str, str]:
"""Build request headers."""
headers = {"Content-Type": "application/json"}
headers = {
"Content-Type": "application/json",
"X-SDK-Version": __version__,
"X-SDK-Type": "python-sdk",
}
if self._api_key:
headers["X-API-Key"] = self._api_key
return headers
Expand Down Expand Up @@ -176,13 +182,36 @@ def _truncate_error_text(self, text: str) -> str:
return text
return text[:_MAX_ERROR_TEXT_LENGTH] + "..."

def _parse_retry_after(self, value: str | None) -> int | None:
"""Parse Retry-After header (integer seconds or HTTP-date)."""
if not value:
return None
# Try integer seconds first (most common)
try:
return int(value)
except ValueError:
pass
# Try HTTP-date format (e.g., "Wed, 21 Oct 2015 07:28:00 GMT")
try:
from email.utils import parsedate_to_datetime

retry_dt = parsedate_to_datetime(value)
from datetime import datetime, timezone

now = datetime.now(timezone.utc)
delta = (retry_dt - now).total_seconds()
return max(0, int(delta))
except (ValueError, TypeError):
# Invalid format, ignore
return None

def _handle_error(self, response: httpx.Response) -> None:
"""Handle error responses from the API."""
if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
retry_after = self._parse_retry_after(response.headers.get("Retry-After"))
raise RateLimitError(
"Rate limited",
retry_after=int(retry_after) if retry_after else None,
retry_after=retry_after,
)

# 5xx errors are server errors (retryable)
Expand Down Expand Up @@ -363,6 +392,47 @@ def do_check() -> CloudCheckResult:

return self._retry_sync(do_check)

def feedback(
self,
check_id: str,
feedback: str,
comment: str | None = None,
) -> FeedbackResult:
"""Submit feedback for a check result (cloud mode only).

Use this to report false positives or confirm correct detections,
which helps improve detection accuracy over time.

Args:
check_id: The check ID from CloudCheckResult.id.
feedback: Either "confirmed" (correct detection) or "false_positive".
comment: Optional comment explaining the feedback.

Returns:
FeedbackResult confirming the feedback was recorded.

Raises:
AIProxyGuardError: If not in cloud mode or request fails.
ValidationError: If check_id not found or invalid feedback value.
"""
if self._api_mode != ApiMode.CLOUD:
raise AIProxyGuardError("feedback() requires cloud API mode")

if feedback not in ("confirmed", "false_positive"):
raise ValidationError("feedback must be 'confirmed' or 'false_positive'")

client = self._get_client()
payload: dict[str, Any] = {"check_id": check_id, "feedback": feedback}
if comment:
payload["comment"] = comment

def do_feedback() -> FeedbackResult:
response = client.post("/api/v1/feedback", json=payload)
self._handle_error(response)
return FeedbackResult.from_dict(response.json())

return self._retry_sync(do_feedback)

def check_batch(self, texts: list[str]) -> list[CheckResult]:
"""Check multiple texts for prompt injection.

Expand Down Expand Up @@ -479,6 +549,31 @@ async def do_check() -> CloudCheckResult:

return await self._retry_async(do_check)

async def feedback_async(
self,
check_id: str,
feedback: str,
comment: str | None = None,
) -> FeedbackResult:
"""Async version of feedback()."""
if self._api_mode != ApiMode.CLOUD:
raise AIProxyGuardError("feedback_async() requires cloud API mode")

if feedback not in ("confirmed", "false_positive"):
raise ValidationError("feedback must be 'confirmed' or 'false_positive'")

client = self._get_async_client()
payload: dict[str, Any] = {"check_id": check_id, "feedback": feedback}
if comment:
payload["comment"] = comment

async def do_feedback() -> FeedbackResult:
response = await client.post("/api/v1/feedback", json=payload)
self._handle_error(response)
return FeedbackResult.from_dict(response.json())

return await self._retry_async(do_feedback)

async def check_batch_async(
self, texts: list[str], max_concurrency: int | None = None
) -> list[CheckResult]:
Expand Down Expand Up @@ -561,6 +656,22 @@ async def __aenter__(self) -> AIProxyGuard:
async def __aexit__(self, *args: Any) -> None:
await self.aclose()

def _close_async_client_sync(self, client: httpx.AsyncClient) -> None:
"""Close an async client from a sync context."""
if client.is_closed:
return
try:
# If there's a running loop, schedule the close
loop = asyncio.get_running_loop()
loop.create_task(client.aclose())
except RuntimeError:
# No running loop - create a temporary one to close properly
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(client.aclose())
finally:
loop.close()

def close(self) -> None:
"""Close all clients and release resources.

Expand All @@ -573,30 +684,11 @@ def close(self) -> None:
self._client = None
# Close pending async client from set_api_key()
if self._pending_async_close:
try:
# Best effort - httpx AsyncClient has _closed flag we can check
if not getattr(self._pending_async_close, "_closed", True):
# Create event loop if needed for cleanup
try:
loop = asyncio.get_running_loop()
loop.create_task(self._pending_async_close.aclose())
except RuntimeError:
# No event loop - just let it be garbage collected
pass
except Exception:
pass
self._close_async_client_sync(self._pending_async_close)
self._pending_async_close = None
# Close current async client if it exists
if self._async_client:
try:
if not getattr(self._async_client, "_closed", True):
try:
loop = asyncio.get_running_loop()
loop.create_task(self._async_client.aclose())
except RuntimeError:
pass
except Exception:
pass
self._close_async_client_sync(self._async_client)
self._async_client = None

async def aclose(self) -> None:
Expand Down
24 changes: 24 additions & 0 deletions src/aiproxyguard/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,30 @@ def from_dict(cls, data: dict[str, Any]) -> CloudCheckResult:
)


@dataclass(frozen=True)
class FeedbackResult:
"""Result from submitting feedback for a check.

Attributes:
success: Whether the feedback was submitted successfully.
check_id: The check ID that was updated.
feedback: The feedback value that was recorded.
"""

success: bool
check_id: str
feedback: str

@classmethod
def from_dict(cls, data: dict[str, Any]) -> FeedbackResult:
"""Create FeedbackResult from API response dictionary."""
return cls(
success=data["success"],
check_id=data["check_id"],
feedback=data["feedback"],
)


@dataclass(frozen=True)
class ServiceInfo:
"""Service information from the AIProxyGuard API.
Expand Down
Loading
Loading