diff --git a/pyproject.toml b/pyproject.toml index f2ad130..544821d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "aiproxyguard-python-sdk" -version = "0.1.0" +version = "0.2.0" description = "Official Python SDK for AIProxyGuard - LLM security proxy for prompt injection detection" readme = "README.md" license = "Apache-2.0" diff --git a/src/aiproxyguard/__init__.py b/src/aiproxyguard/__init__.py index 60c4e09..f154770 100644 --- a/src/aiproxyguard/__init__.py +++ b/src/aiproxyguard/__init__.py @@ -8,6 +8,7 @@ ... print(f"Blocked: {result.category}") """ +from ._version import __version__ from .client import AIProxyGuard, ApiMode from .decorators import GuardConfigurationError, guard, guard_output from .exceptions import ( @@ -23,15 +24,16 @@ Action, CheckResult, CloudCheckResult, + FeedbackResult, HealthStatus, ReadyStatus, ServiceInfo, ThreatDetail, ) -__version__ = "0.1.0" - __all__ = [ + # Version + "__version__", # Client "AIProxyGuard", "ApiMode", @@ -39,6 +41,7 @@ "Action", "CheckResult", "CloudCheckResult", + "FeedbackResult", "HealthStatus", "ReadyStatus", "ServiceInfo", diff --git a/src/aiproxyguard/_version.py b/src/aiproxyguard/_version.py new file mode 100644 index 0000000..90931a1 --- /dev/null +++ b/src/aiproxyguard/_version.py @@ -0,0 +1,3 @@ +"""Package version.""" + +__version__ = "0.2.0" diff --git a/src/aiproxyguard/client.py b/src/aiproxyguard/client.py index fa68a15..3cac035 100644 --- a/src/aiproxyguard/client.py +++ b/src/aiproxyguard/client.py @@ -11,6 +11,7 @@ import httpx +from ._version import __version__ from .exceptions import ( AIProxyGuardError, ConnectionError, @@ -22,6 +23,7 @@ from .models import ( CheckResult, CloudCheckResult, + FeedbackResult, HealthStatus, ReadyStatus, ServiceInfo, @@ -145,7 +147,11 @@ def set_api_key(self, api_key: str | None) -> None: def _get_headers(self) -> dict[str, str]: """Build request headers.""" - headers = {"Content-Type": "application/json"} + headers = { + "Content-Type": "application/json", + "X-SDK-Version": __version__, + "X-SDK-Type": "python-sdk", + } if self._api_key: headers["X-API-Key"] = self._api_key return headers @@ -176,13 +182,36 @@ def _truncate_error_text(self, text: str) -> str: return text return text[:_MAX_ERROR_TEXT_LENGTH] + "..." + def _parse_retry_after(self, value: str | None) -> int | None: + """Parse Retry-After header (integer seconds or HTTP-date).""" + if not value: + return None + # Try integer seconds first (most common) + try: + return int(value) + except ValueError: + pass + # Try HTTP-date format (e.g., "Wed, 21 Oct 2015 07:28:00 GMT") + try: + from email.utils import parsedate_to_datetime + + retry_dt = parsedate_to_datetime(value) + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + delta = (retry_dt - now).total_seconds() + return max(0, int(delta)) + except (ValueError, TypeError): + # Invalid format, ignore + return None + def _handle_error(self, response: httpx.Response) -> None: """Handle error responses from the API.""" if response.status_code == 429: - retry_after = response.headers.get("Retry-After") + retry_after = self._parse_retry_after(response.headers.get("Retry-After")) raise RateLimitError( "Rate limited", - retry_after=int(retry_after) if retry_after else None, + retry_after=retry_after, ) # 5xx errors are server errors (retryable) @@ -363,6 +392,47 @@ def do_check() -> CloudCheckResult: return self._retry_sync(do_check) + def feedback( + self, + check_id: str, + feedback: str, + comment: str | None = None, + ) -> FeedbackResult: + """Submit feedback for a check result (cloud mode only). + + Use this to report false positives or confirm correct detections, + which helps improve detection accuracy over time. + + Args: + check_id: The check ID from CloudCheckResult.id. + feedback: Either "confirmed" (correct detection) or "false_positive". + comment: Optional comment explaining the feedback. + + Returns: + FeedbackResult confirming the feedback was recorded. + + Raises: + AIProxyGuardError: If not in cloud mode or request fails. + ValidationError: If check_id not found or invalid feedback value. + """ + if self._api_mode != ApiMode.CLOUD: + raise AIProxyGuardError("feedback() requires cloud API mode") + + if feedback not in ("confirmed", "false_positive"): + raise ValidationError("feedback must be 'confirmed' or 'false_positive'") + + client = self._get_client() + payload: dict[str, Any] = {"check_id": check_id, "feedback": feedback} + if comment: + payload["comment"] = comment + + def do_feedback() -> FeedbackResult: + response = client.post("/api/v1/feedback", json=payload) + self._handle_error(response) + return FeedbackResult.from_dict(response.json()) + + return self._retry_sync(do_feedback) + def check_batch(self, texts: list[str]) -> list[CheckResult]: """Check multiple texts for prompt injection. @@ -479,6 +549,31 @@ async def do_check() -> CloudCheckResult: return await self._retry_async(do_check) + async def feedback_async( + self, + check_id: str, + feedback: str, + comment: str | None = None, + ) -> FeedbackResult: + """Async version of feedback().""" + if self._api_mode != ApiMode.CLOUD: + raise AIProxyGuardError("feedback_async() requires cloud API mode") + + if feedback not in ("confirmed", "false_positive"): + raise ValidationError("feedback must be 'confirmed' or 'false_positive'") + + client = self._get_async_client() + payload: dict[str, Any] = {"check_id": check_id, "feedback": feedback} + if comment: + payload["comment"] = comment + + async def do_feedback() -> FeedbackResult: + response = await client.post("/api/v1/feedback", json=payload) + self._handle_error(response) + return FeedbackResult.from_dict(response.json()) + + return await self._retry_async(do_feedback) + async def check_batch_async( self, texts: list[str], max_concurrency: int | None = None ) -> list[CheckResult]: @@ -561,6 +656,22 @@ async def __aenter__(self) -> AIProxyGuard: async def __aexit__(self, *args: Any) -> None: await self.aclose() + def _close_async_client_sync(self, client: httpx.AsyncClient) -> None: + """Close an async client from a sync context.""" + if client.is_closed: + return + try: + # If there's a running loop, schedule the close + loop = asyncio.get_running_loop() + loop.create_task(client.aclose()) + except RuntimeError: + # No running loop - create a temporary one to close properly + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(client.aclose()) + finally: + loop.close() + def close(self) -> None: """Close all clients and release resources. @@ -573,30 +684,11 @@ def close(self) -> None: self._client = None # Close pending async client from set_api_key() if self._pending_async_close: - try: - # Best effort - httpx AsyncClient has _closed flag we can check - if not getattr(self._pending_async_close, "_closed", True): - # Create event loop if needed for cleanup - try: - loop = asyncio.get_running_loop() - loop.create_task(self._pending_async_close.aclose()) - except RuntimeError: - # No event loop - just let it be garbage collected - pass - except Exception: - pass + self._close_async_client_sync(self._pending_async_close) self._pending_async_close = None # Close current async client if it exists if self._async_client: - try: - if not getattr(self._async_client, "_closed", True): - try: - loop = asyncio.get_running_loop() - loop.create_task(self._async_client.aclose()) - except RuntimeError: - pass - except Exception: - pass + self._close_async_client_sync(self._async_client) self._async_client = None async def aclose(self) -> None: diff --git a/src/aiproxyguard/models.py b/src/aiproxyguard/models.py index 0b8dfee..bb66cd8 100644 --- a/src/aiproxyguard/models.py +++ b/src/aiproxyguard/models.py @@ -160,6 +160,30 @@ def from_dict(cls, data: dict[str, Any]) -> CloudCheckResult: ) +@dataclass(frozen=True) +class FeedbackResult: + """Result from submitting feedback for a check. + + Attributes: + success: Whether the feedback was submitted successfully. + check_id: The check ID that was updated. + feedback: The feedback value that was recorded. + """ + + success: bool + check_id: str + feedback: str + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> FeedbackResult: + """Create FeedbackResult from API response dictionary.""" + return cls( + success=data["success"], + check_id=data["check_id"], + feedback=data["feedback"], + ) + + @dataclass(frozen=True) class ServiceInfo: """Service information from the AIProxyGuard API. diff --git a/tests/test_client.py b/tests/test_client.py index aa18a74..342a475 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -802,6 +802,94 @@ def test_check_with_context(self): assert '"context":' in request_body assert '"provider":' in request_body + def test_feedback_requires_cloud_mode(self, base_url): + """Test feedback raises in proxy mode.""" + from aiproxyguard import AIProxyGuardError + + client = AIProxyGuard(base_url) + with pytest.raises(AIProxyGuardError) as exc: + client.feedback("chk_123", "confirmed") + assert "requires cloud API mode" in str(exc.value) + + def test_feedback_validates_feedback_value(self): + """Test feedback validates feedback parameter.""" + client = AIProxyGuard( + "https://aiproxyguard.com", api_key="test", api_mode="cloud" + ) + with pytest.raises(ValidationError) as exc: + client.feedback("chk_123", "invalid") + assert "confirmed" in str(exc.value) or "false_positive" in str(exc.value) + + def test_feedback_sync(self): + """Test feedback submits correctly.""" + feedback_response = { + "success": True, + "check_id": "chk_123", + "feedback": "confirmed", + } + with respx.mock(base_url="https://aiproxyguard.com") as mock: + route = mock.post("/api/v1/feedback").respond(200, json=feedback_response) + + client = AIProxyGuard( + "https://aiproxyguard.com", api_key="test", api_mode="cloud" + ) + result = client.feedback("chk_123", "confirmed") + + assert result.success is True + assert result.check_id == "chk_123" + assert result.feedback == "confirmed" + assert route.called + + def test_feedback_with_comment(self): + """Test feedback with optional comment.""" + feedback_response = { + "success": True, + "check_id": "chk_456", + "feedback": "false_positive", + } + with respx.mock(base_url="https://aiproxyguard.com") as mock: + route = mock.post("/api/v1/feedback").respond(200, json=feedback_response) + + client = AIProxyGuard( + "https://aiproxyguard.com", api_key="test", api_mode="cloud" + ) + result = client.feedback( + "chk_456", "false_positive", comment="This was a normal question" + ) + + assert result.success is True + request_body = route.calls[0].request.content.decode() + assert '"comment":' in request_body + + @pytest.mark.asyncio + async def test_feedback_async(self): + """Test feedback_async submits correctly.""" + feedback_response = { + "success": True, + "check_id": "chk_789", + "feedback": "confirmed", + } + with respx.mock(base_url="https://aiproxyguard.com") as mock: + mock.post("/api/v1/feedback").respond(200, json=feedback_response) + + async with AIProxyGuard( + "https://aiproxyguard.com", api_key="test", api_mode="cloud" + ) as client: + result = await client.feedback_async("chk_789", "confirmed") + + assert result.success is True + assert result.check_id == "chk_789" + + @pytest.mark.asyncio + async def test_feedback_async_requires_cloud_mode(self, base_url): + """Test feedback_async raises in proxy mode.""" + from aiproxyguard import AIProxyGuardError + + client = AIProxyGuard(base_url) + with pytest.raises(AIProxyGuardError) as exc: + await client.feedback_async("chk_123", "confirmed") + assert "requires cloud API mode" in str(exc.value) + class TestAsyncRetryPaths: """Tests for async retry error handling paths.""" @@ -935,6 +1023,45 @@ def test_rate_limit_with_retry_after(self, base_url, allow_response): assert result.is_safe + def test_rate_limit_with_invalid_retry_after(self, base_url, allow_response): + """Test rate limit with invalid Retry-After is handled gracefully.""" + with respx.mock(base_url=base_url) as mock: + mock.post("/check").side_effect = [ + httpx.Response(429, headers={"Retry-After": "invalid-value"}), + httpx.Response(200, json=allow_response), + ] + + client = AIProxyGuard(base_url, retry_delay=0.01) + result = client.check("test") + + assert result.is_safe + + def test_parse_retry_after_integer(self, base_url): + """Test _parse_retry_after with integer seconds.""" + client = AIProxyGuard(base_url) + assert client._parse_retry_after("60") == 60 + assert client._parse_retry_after("0") == 0 + assert client._parse_retry_after(None) is None + + def test_parse_retry_after_http_date(self, base_url): + """Test _parse_retry_after with HTTP-date format.""" + from datetime import datetime, timedelta, timezone + + client = AIProxyGuard(base_url) + # Create a date 30 seconds in the future + future = datetime.now(timezone.utc) + timedelta(seconds=30) + date_str = future.strftime("%a, %d %b %Y %H:%M:%S GMT") + result = client._parse_retry_after(date_str) + # Should be approximately 30 seconds (allow some tolerance) + assert result is not None + assert 25 <= result <= 35 + + def test_parse_retry_after_invalid(self, base_url): + """Test _parse_retry_after with invalid format returns None.""" + client = AIProxyGuard(base_url) + assert client._parse_retry_after("not-a-number-or-date") is None + assert client._parse_retry_after("") is None + class TestSetApiKeyAsyncCleanup: """Tests for set_api_key async client cleanup.""" @@ -956,18 +1083,23 @@ async def test_set_api_key_tracks_async_client(self, base_url, allow_response): def test_close_handles_pending_async(self, base_url, allow_response): """Test close() handles pending async client.""" + import asyncio + with respx.mock(base_url=base_url) as mock: mock.post("/check").respond(200, json=allow_response) client = AIProxyGuard(base_url, api_key="key") - # Manually set a pending client to test cleanup - import asyncio + # Create async client and pending close using a fresh event loop async def create_client(): await client.check_async("test") client.set_api_key("new") - asyncio.get_event_loop().run_until_complete(create_client()) + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(create_client()) + finally: + loop.close() # close() should handle pending async client client.close() diff --git a/tests/test_models.py b/tests/test_models.py index c7bfd3d..3667e95 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,6 +6,7 @@ Action, CheckResult, CloudCheckResult, + FeedbackResult, HealthStatus, ReadyStatus, ServiceInfo, @@ -305,3 +306,30 @@ def test_from_cloud_dict_no_threats(self): assert result.category is None assert result.signature_name is None assert result.confidence == 0.0 + + +class TestFeedbackResult: + """Tests for FeedbackResult model.""" + + def test_from_dict(self): + """Test creating FeedbackResult from API response.""" + data = { + "success": True, + "check_id": "chk_123", + "feedback": "confirmed", + } + result = FeedbackResult.from_dict(data) + assert result.success is True + assert result.check_id == "chk_123" + assert result.feedback == "confirmed" + + def test_from_dict_false_positive(self): + """Test creating FeedbackResult for false positive.""" + data = { + "success": True, + "check_id": "chk_456", + "feedback": "false_positive", + } + result = FeedbackResult.from_dict(data) + assert result.success is True + assert result.feedback == "false_positive"