diff --git a/sdk_v2/python/pyproject.toml b/sdk_v2/python/pyproject.toml index 7b6e7478..3d078f8a 100644 --- a/sdk_v2/python/pyproject.toml +++ b/sdk_v2/python/pyproject.toml @@ -66,7 +66,7 @@ Source = "https://github.com/microsoft/Foundry-Local" [project.optional-dependencies] numpy = ["numpy>=1.23"] openai = ["openai>=1.0"] -dev = ["pytest", "pytest-cov", "mypy", "ruff"] +dev = ["pytest", "pytest-asyncio", "pytest-cov", "mypy", "ruff"] [project.scripts] foundry-local-install = "foundry_local_sdk._native.installer:main" @@ -97,6 +97,7 @@ pythonpath = ["src"] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] +asyncio_mode = "auto" addopts = "-m 'not manual'" markers = [ "manual: destructive tests requiring manual setup; skipped by default. Opt in with -m manual.", diff --git a/sdk_v2/python/src/foundry_local_sdk/foundry_local_manager.py b/sdk_v2/python/src/foundry_local_sdk/foundry_local_manager.py index 0a8d3375..da249340 100644 --- a/sdk_v2/python/src/foundry_local_sdk/foundry_local_manager.py +++ b/sdk_v2/python/src/foundry_local_sdk/foundry_local_manager.py @@ -4,8 +4,8 @@ # -------------------------------------------------------------------------- from __future__ import annotations -import threading -from typing import Callable +import asyncio +from typing import AsyncGenerator, Callable from foundry_local_sdk.catalog import Catalog from foundry_local_sdk.configuration import Configuration @@ -17,7 +17,7 @@ class FoundryLocalManager: """Singleton manager for Foundry Local SDK operations. - Call ``FoundryLocalManager.initialize(config)`` once at startup, then + Call ``await FoundryLocalManager.initialize(config)`` once at startup, then access the singleton via ``FoundryLocalManager.instance``. Attributes: @@ -26,11 +26,11 @@ class FoundryLocalManager: urls: Bound URL(s) after ``start_web_service()`` is called, or ``None``. """ - _lock: threading.Lock = threading.Lock() + _lock: asyncio.Lock = asyncio.Lock() instance: FoundryLocalManager | None = None @staticmethod - def initialize(config: Configuration) -> None: + async def initialize(config: Configuration) -> None: """Initialize the Foundry Local SDK with the given configuration. Must be called before using any other part of the SDK. @@ -38,22 +38,51 @@ def initialize(config: Configuration) -> None: Args: config: Configuration object for the SDK. """ - FoundryLocalManager(config) + manager = FoundryLocalManager(config) + await manager._initialize() def __init__(self, config: Configuration) -> None: - # Declared up front so close() / __del__ can run safely even if + # Declared up front so close() can run safely even if # _initialize() raises before the native handle is assigned. self._native_manager: object | None = None self.urls: list[str] | None = None + self.config = config - with FoundryLocalManager._lock: + async def _initialize(self) -> None: + """Async initialization logic.""" + async with FoundryLocalManager._lock: if FoundryLocalManager.instance is not None: raise FoundryLocalException( "FoundryLocalManager is a singleton and has already been initialized." ) - config.validate() - self.config = config - self._initialize() + self.config.validate() + + def _init(): + from foundry_local_sdk._native.api import api, ffi + + if self.config.log_level is not None: + set_default_logger_severity(self.config.log_level) + + native_config = self.config._build_native() + try: + mgr_out = ffi.new("flManager**") + api.check_status(api.root.Manager_Create(native_config, mgr_out)) + self._native_manager = mgr_out[0] + finally: + api.config.Configuration_Release(native_config) + + try: + cat_out = ffi.new("flCatalog**") + api.check_status(api.root.Manager_GetCatalog(self._native_manager, cat_out)) + self.catalog = Catalog(cat_out[0], parent=self) + except BaseException: + try: + api.root.Manager_Release(self._native_manager) + finally: + self._native_manager = None + raise + + await asyncio.to_thread(_init) FoundryLocalManager.instance = self # Register an interpreter-shutdown hook to release the native Manager @@ -62,71 +91,63 @@ def __init__(self, config: Configuration) -> None: # (Manager → SpdlogLogger → spdlog::async_logger::flush) can fire # after spdlog's global thread pool has already been destroyed, # raising std::system_error("mutex lock failed") and aborting the - # process. atexit guarantees we run before any of that. + # process. import atexit - atexit.register(self.close) - def _initialize(self) -> None: - from foundry_local_sdk._native.api import api, ffi - - # Only push a log level into the native side if the caller actually picked one. ``None`` means "use - # whatever default the native runtime decides" — forwarding ``None`` would only work by accident - # through ``dict.get(None, default)``. - if self.config.log_level is not None: - set_default_logger_severity(self.config.log_level) - - native_config = self.config._build_native() - try: - mgr_out = ffi.new("flManager**") - api.check_status(api.root.Manager_Create(native_config, mgr_out)) - self._native_manager = mgr_out[0] - finally: - # Manager_Create takes a const flConfiguration* — it copies what it needs. - # We own the config handle and release it now. - api.config.Configuration_Release(native_config) - - try: - cat_out = ffi.new("flCatalog**") - api.check_status(api.root.Manager_GetCatalog(self._native_manager, cat_out)) - self.catalog = Catalog(cat_out[0], parent=self) - except BaseException: - # Catalog fetch failed; release the manager handle to avoid leaking it. + def _atexit_cleanup(mgr_ref=self) -> None: + # Called during interpreter shutdown — asyncio.to_thread is unavailable at that point. + # Use the blocking native calls directly instead of going through close(). + if mgr_ref._native_manager is None: + return try: - api.root.Manager_Release(self._native_manager) + from foundry_local_sdk._native.api import api + try: + api.check_status(api.root.Manager_Shutdown(mgr_ref._native_manager)) + except Exception: + pass + api.root.Manager_Release(mgr_ref._native_manager) + except Exception: + pass finally: - self._native_manager = None - raise + mgr_ref._native_manager = None + if FoundryLocalManager.instance is mgr_ref: + FoundryLocalManager.instance = None + + atexit.register(_atexit_cleanup) # ------------------------------------------------------------------ # EP discovery and registration # ------------------------------------------------------------------ - def discover_eps(self) -> list[EpInfo]: + async def discover_eps(self) -> list[EpInfo]: """Discover available execution providers and their registration status. Returns: List of ``EpInfo`` entries for all discoverable EPs. """ - from foundry_local_sdk._native.api import api, ffi + def _discover(): + from foundry_local_sdk._native.api import api, ffi - eps_out = ffi.new("flEpInfo**") - count_out = ffi.new("size_t*") - api.check_status( - api.root.Manager_GetDiscoverableEps( - self._native_manager, eps_out, count_out + eps_out = ffi.new("flEpInfo**") + count_out = ffi.new("size_t*") + api.check_status( + api.root.Manager_GetDiscoverableEps( + self._native_manager, eps_out, count_out + ) ) - ) - - count = int(count_out[0]) - result: list[EpInfo] = [] - for i in range(count): - entry = eps_out[0][i] - name = ffi.string(entry.name).decode("utf-8") - is_reg = bool(entry.is_registered) - result.append(EpInfo(name=name, is_registered=is_reg)) - return result - - def download_and_register_eps( + + count = int(count_out[0]) + result: list[EpInfo] = [] + for i in range(count): + entry = eps_out[0][i] + name = ffi.string(entry.name).decode("utf-8") + is_reg = bool(entry.is_registered) + result.append(EpInfo(name=name, is_registered=is_reg)) + return result + + return await asyncio.to_thread(_discover) + + async def download_and_register_eps( self, names: list[str] | None = None, progress_callback: Callable[[str, float], None] | None = None, @@ -142,94 +163,104 @@ def download_and_register_eps( Returns: ``EpDownloadResult`` describing operation status and per-EP outcomes. """ - # An empty list is treated as "download all" (same as None) for consistency across language bindings. if names is not None and len(names) == 0: names = None - from foundry_local_sdk._native.api import api, ffi - - # Snapshot before-state to compute the delta of newly registered EPs. - before_eps: dict[str, bool] = {ep.name: ep.is_registered for ep in self.discover_eps()} - - # Build the native EP-names array (or NULL to download all). - if names is not None: - # Keep encoded byte strings alive for the duration of the call. - c_name_bufs = [ffi.new("char[]", n.encode("utf-8")) for n in names] - c_names_arr = ffi.new("const char*[]", c_name_bufs) - num_names = len(names) - else: - c_names_arr = ffi.NULL - num_names = 0 - - # Build the progress callback trampoline if requested. - cb = ffi.NULL - ud = ffi.NULL - if progress_callback is not None: - self._ep_cb_handle = ffi.new_handle(progress_callback) - - # Use the cdef typedef rather than an inline signature: the API-mode - # runtime parser in _cffi_backend does not accept `const`, but the - # cdef'd typedef preserves it. This keeps the const char* contract - # the C ABI exposes (and matches the v1 SDK). - @ffi.callback("flEpProgressCallback") - def _ep_cb(ep_name_ptr: object, value: float, user_data: object) -> int: - try: - fn = ffi.from_handle(user_data) - ep_name = ( - ffi.string(ep_name_ptr).decode("utf-8") - if ep_name_ptr != ffi.NULL - else "" + def _download_and_register(): + from foundry_local_sdk._native.api import api, ffi + + # Helper to discover EPs for before/after comparison + def _discover_eps_sync(): + eps_out = ffi.new("flEpInfo**") + count_out = ffi.new("size_t*") + api.check_status( + api.root.Manager_GetDiscoverableEps( + self._native_manager, eps_out, count_out ) - fn(ep_name, float(value)) - return 0 - except Exception: - return 1 + ) + count = int(count_out[0]) + result: list[EpInfo] = [] + for i in range(count): + entry = eps_out[0][i] + ep_name = ffi.string(entry.name).decode("utf-8") + is_reg = bool(entry.is_registered) + result.append(EpInfo(name=ep_name, is_registered=is_reg)) + return result + + before_eps: dict[str, bool] = {ep.name: ep.is_registered for ep in _discover_eps_sync()} + + if names is not None: + c_name_bufs = [ffi.new("char[]", n.encode("utf-8")) for n in names] + c_names_arr = ffi.new("const char*[]", c_name_bufs) + num_names = len(names) + else: + c_names_arr = ffi.NULL + num_names = 0 + + cb = ffi.NULL + ud = ffi.NULL + if progress_callback is not None: + self._ep_cb_handle = ffi.new_handle(progress_callback) + + @ffi.callback("flEpProgressCallback") + def _ep_cb(ep_name_ptr: object, value: float, user_data: object) -> int: + try: + fn = ffi.from_handle(user_data) + ep_name = ( + ffi.string(ep_name_ptr).decode("utf-8") + if ep_name_ptr != ffi.NULL + else "" + ) + fn(ep_name, float(value)) + return 0 + except Exception: + return 1 + + self._ep_cb = _ep_cb + cb = _ep_cb + ud = self._ep_cb_handle - self._ep_cb = _ep_cb # prevent GC - cb = _ep_cb - ud = self._ep_cb_handle + api.check_status( + api.root.Manager_DownloadAndRegisterEps( + self._native_manager, c_names_arr, num_names, cb, ud + ) + ) - api.check_status( - api.root.Manager_DownloadAndRegisterEps( - self._native_manager, c_names_arr, num_names, cb, ud + after_eps = _discover_eps_sync() + registered = [ + ep.name + for ep in after_eps + if ep.is_registered and not before_eps.get(ep.name, False) + ] + failed = [ + ep.name + for ep in after_eps + if not ep.is_registered and ep.name in (names or []) + ] + + return EpDownloadResult( + success=len(failed) == 0, + status="Completed", + registered_eps=registered, + failed_eps=failed, ) - ) - - # Determine which EPs were newly registered. - after_eps = self.discover_eps() - registered = [ - ep.name - for ep in after_eps - if ep.is_registered and not before_eps.get(ep.name, False) - ] - failed = [ - ep.name - for ep in after_eps - if not ep.is_registered and ep.name in (names or []) - ] - - # Native owns catalog cache invalidation after EP registration; no Python-side action needed. - return EpDownloadResult( - success=len(failed) == 0, - status="Completed", - registered_eps=registered, - failed_eps=failed, - ) + + return await asyncio.to_thread(_download_and_register) # ------------------------------------------------------------------ # Web service lifecycle # ------------------------------------------------------------------ - def start_web_service(self) -> None: + async def start_web_service(self) -> None: """Start the optional built-in web service. The service binds to the URL(s) specified in ``Configuration.web.urls``, or ``http://127.0.0.1:0`` (a random ephemeral port) if not specified. ``FoundryLocalManager.urls`` is updated with the actual bound URL(s). """ - from foundry_local_sdk._native.api import api, ffi + def _start(): + from foundry_local_sdk._native.api import api, ffi - with FoundryLocalManager._lock: api.check_status(api.root.Manager_WebServiceStart(self._native_manager)) urls_out = ffi.new("char***") @@ -241,33 +272,40 @@ def start_web_service(self) -> None: ffi.string(urls_out[0][i]).decode("utf-8") for i in range(int(count_out[0])) ] - def stop_web_service(self) -> None: + return await asyncio.to_thread(_start) + + async def stop_web_service(self) -> None: """Stop the optional built-in web service. Raises: FoundryLocalException: If the web service is not currently running. """ - from foundry_local_sdk._native.api import api + if self.urls is None: + raise FoundryLocalException("Web service is not running.") - with FoundryLocalManager._lock: - if self.urls is None: - raise FoundryLocalException("Web service is not running.") + def _stop(): + from foundry_local_sdk._native.api import api api.check_status(api.root.Manager_WebServiceStop(self._native_manager)) self.urls = None + return await asyncio.to_thread(_stop) + # ------------------------------------------------------------------ # Shutdown # ------------------------------------------------------------------ - def shutdown(self) -> None: + async def shutdown(self) -> None: """Initiate graceful shutdown of the native manager. Safe to call from any thread. Idempotent. """ - from foundry_local_sdk._native.api import api + def _shutdown(): + from foundry_local_sdk._native.api import api + + api.check_status(api.root.Manager_Shutdown(self._native_manager)) - api.check_status(api.root.Manager_Shutdown(self._native_manager)) + return await asyncio.to_thread(_shutdown) def is_shutdown_requested(self) -> bool: """Whether ``shutdown()`` has been called on the native manager.""" @@ -275,31 +313,24 @@ def is_shutdown_requested(self) -> bool: return bool(api.root.Manager_IsShutdownRequested(self._native_manager)) - def close(self) -> None: + async def close(self) -> None: """Tear down the native manager and clear the singleton. After ``close()`` returns, ``FoundryLocalManager.instance`` is ``None`` and - a fresh ``FoundryLocalManager(config)`` may be constructed. The native - side enforces single-instance semantics via its own singleton; this method - drives the native ``Manager_Shutdown`` (which drains sessions and unloads - models) before releasing the handle. + a fresh ``await FoundryLocalManager.initialize(config)`` may be called. Idempotent. Safe to call multiple times. """ - import logging + def _close(): + import logging - from foundry_local_sdk._native.api import api + from foundry_local_sdk._native.api import api - with FoundryLocalManager._lock: - # Idempotent — close() called twice or after a failed __init__. if self._native_manager is None: if FoundryLocalManager.instance is self: FoundryLocalManager.instance = None return - # Drive the orchestrated drain on the native side. Log shutdown errors - # rather than swallowing them silently — we still need to release the - # handle, but the failure must surface somewhere. try: api.check_status(api.root.Manager_Shutdown(self._native_manager)) except Exception as exc: @@ -314,15 +345,18 @@ def close(self) -> None: if FoundryLocalManager.instance is self: FoundryLocalManager.instance = None - def __enter__(self) -> "FoundryLocalManager": + return await asyncio.to_thread(_close) + + async def __aenter__(self) -> "FoundryLocalManager": return self - def __exit__(self, *_: object) -> None: - self.close() + async def __aexit__(self, *_: object) -> None: + await self.close() def __del__(self) -> None: # Best-effort safety net — production code should call close() explicitly. try: - self.close() + if self._native_manager is not None: + asyncio.run(self.close()) except Exception: pass diff --git a/sdk_v2/python/src/foundry_local_sdk/imodel.py b/sdk_v2/python/src/foundry_local_sdk/imodel.py index 1807515b..70d59ce4 100644 --- a/sdk_v2/python/src/foundry_local_sdk/imodel.py +++ b/sdk_v2/python/src/foundry_local_sdk/imodel.py @@ -4,8 +4,9 @@ # -------------------------------------------------------------------------- from __future__ import annotations +import asyncio from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, AsyncGenerator, Callable from foundry_local_sdk.exception import FoundryLocalException from foundry_local_sdk.model_info import DeviceType, ModelInfo, Runtime @@ -76,7 +77,7 @@ def supports_tool_calling(self) -> bool | None: """Whether the model supports tool/function calling, or ``None`` if unknown.""" @abstractmethod - def download(self, progress_callback: Callable[[float], None] | None = None) -> None: + async def download(self, progress_callback: Callable[[float], None] | None = None) -> None: """Download the model to the local cache if not already present. Args: @@ -93,7 +94,7 @@ def get_path(self) -> str: """ @abstractmethod - def load(self) -> None: + async def load(self) -> None: """Load the model into memory if not already loaded.""" @abstractmethod @@ -101,7 +102,7 @@ def remove_from_cache(self) -> None: """Remove the model from the local cache.""" @abstractmethod - def unload(self) -> None: + async def unload(self) -> None: """Unload the model if currently loaded.""" @abstractmethod @@ -311,29 +312,32 @@ def supports_tool_calling(self) -> bool | None: # Model lifecycle # ------------------------------------------------------------------ - def download(self, progress_callback: Callable[[float], None] | None = None) -> None: - from foundry_local_sdk._native.api import api, ffi + async def download(self, progress_callback: Callable[[float], None] | None = None) -> None: + def _download_blocking(): + from foundry_local_sdk._native.api import api, ffi - cb = ffi.NULL - user_data = ffi.NULL + cb = ffi.NULL + user_data = ffi.NULL - if progress_callback is not None: - self._progress_cb_handle = ffi.new_handle(progress_callback) + if progress_callback is not None: + self._progress_cb_handle = ffi.new_handle(progress_callback) - @ffi.callback("flProgressCallback") - def _cb(value: float, ud: object) -> int: - try: - fn = ffi.from_handle(ud) - fn(float(value)) - return 0 - except Exception: - return 1 + @ffi.callback("flProgressCallback") + def _cb(value: float, ud: object) -> int: + try: + fn = ffi.from_handle(ud) + fn(float(value)) + return 0 + except Exception: + return 1 - self._progress_cb = _cb # keep alive - cb = _cb - user_data = self._progress_cb_handle + self._progress_cb = _cb # keep alive + cb = _cb + user_data = self._progress_cb_handle - api.check_status(api.model.Download(self._ptr, cb, user_data)) + api.check_status(api.model.Download(self._ptr, cb, user_data)) + + await asyncio.to_thread(_download_blocking) def get_path(self) -> str: from foundry_local_sdk._native.api import api, ffi @@ -342,15 +346,19 @@ def get_path(self) -> str: api.check_status(api.model.GetPath(self._ptr, out)) return ffi.string(out[0]).decode("utf-8") if out[0] != ffi.NULL else "" - def load(self) -> None: - from foundry_local_sdk._native.api import api - - api.check_status(api.model.Load(self._ptr)) - - def unload(self) -> None: - from foundry_local_sdk._native.api import api - - api.check_status(api.model.Unload(self._ptr)) + async def load(self) -> None: + def _load_blocking(): + from foundry_local_sdk._native.api import api + api.check_status(api.model.Load(self._ptr)) + + await asyncio.to_thread(_load_blocking) + + async def unload(self) -> None: + def _unload_blocking(): + from foundry_local_sdk._native.api import api + api.check_status(api.model.Unload(self._ptr)) + + await asyncio.to_thread(_unload_blocking) def remove_from_cache(self) -> None: from foundry_local_sdk._native.api import api diff --git a/sdk_v2/python/src/foundry_local_sdk/openai/audio_client.py b/sdk_v2/python/src/foundry_local_sdk/openai/audio_client.py index be060c14..fc551330 100644 --- a/sdk_v2/python/src/foundry_local_sdk/openai/audio_client.py +++ b/sdk_v2/python/src/foundry_local_sdk/openai/audio_client.py @@ -5,9 +5,10 @@ """OpenAI-compatible audio transcription client backed by the Foundry Local native layer.""" from __future__ import annotations +import asyncio import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING, AsyncGenerator if TYPE_CHECKING: from foundry_local_sdk.imodel import IModel @@ -106,22 +107,25 @@ def _build_request_json(self, audio_file_path: str) -> str: return json.dumps(request) - def _run_native_request(self, request_json: str) -> str: + async def _run_native_request(self, request_json: str) -> str: """Create a fresh AudioSession, process the request, return the response JSON string.""" from foundry_local_sdk.items import TextItem, TextItemType from foundry_local_sdk.request import Request from foundry_local_sdk.session import AudioSession - with ( - AudioSession(self._model) as session, - Request() as request, - ): - request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) - with session.process_request(request) as response: - # Copy the text out of the (response-owned) item before the response is released. - return response.get_item(0).text - - def transcribe(self, audio_file_path: str) -> AudioTranscriptionResponse: + def _blocking(): + with AudioSession(self._model) as session: + with Request() as request: + request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) + response = session.process_request(request) + try: + return response.get_item(0).text + finally: + response._close() + + return await asyncio.to_thread(_blocking) + + async def transcribe(self, audio_file_path: str) -> AudioTranscriptionResponse: """Transcribe an audio file (non-streaming). Args: @@ -137,23 +141,23 @@ def transcribe(self, audio_file_path: str) -> AudioTranscriptionResponse: self._validate_audio_file_path(audio_file_path) request_json = self._build_request_json(audio_file_path) - response_json = self._run_native_request(request_json) + response_json = await self._run_native_request(request_json) data = json.loads(response_json) return AudioTranscriptionResponse(text=data.get("text", "")) - def transcribe_streaming(self, audio_file_path: str) -> Generator[AudioTranscriptionResponse, None, None]: + async def transcribe_streaming(self, audio_file_path: str) -> AsyncGenerator[AudioTranscriptionResponse, None]: """Transcribe an audio file with streaming chunks. - Consume with a standard ``for`` loop:: + Consume with an async ``for`` loop:: - for chunk in audio_client.transcribe_streaming("recording.mp3"): + async for chunk in audio_client.transcribe_streaming("recording.mp3"): print(chunk.text, end="", flush=True) Args: audio_file_path: Path to the audio file to transcribe. - Returns: - A generator of ``AudioTranscriptionResponse`` objects. + Yields: + ``AudioTranscriptionResponse`` objects as they arrive. Raises: ValueError: If *audio_file_path* is not a non-empty string. @@ -161,19 +165,28 @@ def transcribe_streaming(self, audio_file_path: str) -> Generator[AudioTranscrip """ self._validate_audio_file_path(audio_file_path) request_json = self._build_request_json(audio_file_path) - return self._transcribe_streaming_impl(request_json) + async for item in self._transcribe_streaming_impl(request_json): + yield item - def _transcribe_streaming_impl(self, request_json: str) -> Generator[AudioTranscriptionResponse, None, None]: + async def _transcribe_streaming_impl(self, request_json: str) -> AsyncGenerator[AudioTranscriptionResponse, None]: from foundry_local_sdk.items import TextItem, TextItemType from foundry_local_sdk.request import Request from foundry_local_sdk.session import AudioSession - with ( - AudioSession(self._model) as session, - Request() as request, - ): - session.set_streaming(True) - request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) - for item in session.process_streaming_request(request): - data = json.loads(item.text) - yield AudioTranscriptionResponse(text=data.get("text", "")) + def _blocking_stream(): + """Run blocking streaming in a separate thread.""" + items = [] + with AudioSession(self._model) as session: + session.set_streaming(True) + with Request() as request: + request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) + with session.process_streaming_request(request) as stream: + for item in stream: + items.append(item) + return items + + # Run blocking operation in thread and yield each result + items = await asyncio.to_thread(_blocking_stream) + for item in items: + data = json.loads(item.text) + yield AudioTranscriptionResponse(text=data.get("text", "")) diff --git a/sdk_v2/python/src/foundry_local_sdk/openai/chat_client.py b/sdk_v2/python/src/foundry_local_sdk/openai/chat_client.py index bcd7a977..025a0ceb 100644 --- a/sdk_v2/python/src/foundry_local_sdk/openai/chat_client.py +++ b/sdk_v2/python/src/foundry_local_sdk/openai/chat_client.py @@ -5,8 +5,9 @@ """OpenAI-compatible chat completion client backed by the Foundry Local native layer.""" from __future__ import annotations +import asyncio import json -from typing import TYPE_CHECKING, Any, Generator +from typing import TYPE_CHECKING, Any, AsyncGenerator from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam from openai.types.chat import ChatCompletion @@ -183,22 +184,25 @@ def _build_request_json( } return json.dumps(request_dict) - def _run_native_request(self, request_json: str) -> str: + async def _run_native_request(self, request_json: str) -> str: """Create a fresh ChatSession, process the request, return the response JSON string.""" from foundry_local_sdk.items import TextItem, TextItemType from foundry_local_sdk.request import Request from foundry_local_sdk.session import ChatSession - with ( - ChatSession(self._model) as session, - Request() as request, - ): - request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) - with session.process_request(request) as response: - # Copy the text out of the (response-owned) item before the response is released. - return response.get_item(0).text - - def complete_chat( + def _blocking(): + with ChatSession(self._model) as session: + with Request() as request: + request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) + response = session.process_request(request) + try: + return response.get_item(0).text + finally: + response._close() + + return await asyncio.to_thread(_blocking) + + async def complete( self, messages: list[ChatCompletionMessageParam], tools: list[dict[str, Any]] | None = None, @@ -220,19 +224,19 @@ def complete_chat( self._validate_tools(tools) request_json = self._build_request_json(messages, streaming=False, tools=tools) - response_json = self._run_native_request(request_json) + response_json = await self._run_native_request(request_json) return ChatCompletion.model_validate_json(response_json) - def complete_streaming_chat( + async def stream( self, messages: list[ChatCompletionMessageParam], tools: list[dict[str, Any]] | None = None, - ) -> Generator[ChatCompletionChunk, None, None]: + ) -> AsyncGenerator[ChatCompletionChunk, None]: """Perform a streaming chat completion, yielding chunks as they arrive. - Consume with a standard ``for`` loop:: + Consume with an async ``for`` loop:: - for chunk in client.complete_streaming_chat(messages): + async for chunk in client.stream(messages): delta = chunk.choices[0].delta.content if delta: print(delta, end="", flush=True) @@ -241,40 +245,42 @@ def complete_streaming_chat( messages: Conversation history as a list of OpenAI message dicts. tools: Optional list of tool definitions for function calling. - Returns: - A generator of ``ChatCompletionChunk`` objects. + Yields: + ``ChatCompletionChunk`` objects as they arrive. Raises: ValueError: If messages or tools are malformed. FoundryLocalException: If the native layer returns an error. """ + from foundry_local_sdk.items import TextItem, TextItemType + from foundry_local_sdk.request import Request + from foundry_local_sdk.session import ChatSession + self._validate_messages(messages) self._validate_tools(tools) request_json = self._build_request_json(messages, streaming=True, tools=tools) - from foundry_local_sdk.items import TextItem, TextItemType - from foundry_local_sdk.request import Request - from foundry_local_sdk.session import ChatSession - - with ( - ChatSession(self._model) as session, - Request() as request, - ): - session.set_streaming(True) - request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) - for item in session.process_streaming_request(request): - # Each item is a TextItem(OPENAI_JSON) — parse and normalize. - raw = json.loads(item.text) - - # Foundry Local streams tool calls under "message" instead of the - # standard "delta". Normalize to "delta" so ChatCompletionChunk parses. - for choice in raw.get("choices", []): - if "message" in choice and "delta" not in choice: - msg_obj = choice.pop("message") - # ChoiceDeltaToolCall requires "index"; add if absent. - for i, tc in enumerate(msg_obj.get("tool_calls", [])): - tc.setdefault("index", i) - choice["delta"] = msg_obj - - yield ChatCompletionChunk.model_validate(raw) + def _blocking_stream(): + """Run blocking streaming in a separate thread.""" + items = [] + with ChatSession(self._model) as session: + session.set_streaming(True) + with Request() as request: + request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) + with session.process_streaming_request(request) as stream: + for item in stream: + items.append(item) + return items + + # Run blocking operation in thread and yield each result + items = await asyncio.to_thread(_blocking_stream) + for item in items: + raw = json.loads(item.text) + for choice in raw.get("choices", []): + if "message" in choice and "delta" not in choice: + msg_obj = choice.pop("message") + for i, tc in enumerate(msg_obj.get("tool_calls", [])): + tc.setdefault("index", i) + choice["delta"] = msg_obj + yield ChatCompletionChunk.model_validate(raw) diff --git a/sdk_v2/python/src/foundry_local_sdk/openai/embedding_client.py b/sdk_v2/python/src/foundry_local_sdk/openai/embedding_client.py index 49d99205..694ed41a 100644 --- a/sdk_v2/python/src/foundry_local_sdk/openai/embedding_client.py +++ b/sdk_v2/python/src/foundry_local_sdk/openai/embedding_client.py @@ -5,6 +5,7 @@ """OpenAI-compatible embedding client backed by the Foundry Local native layer.""" from __future__ import annotations +import asyncio import json from typing import TYPE_CHECKING @@ -39,37 +40,25 @@ def _build_request_json(self, input_value: str | list[str]) -> str: """Build the JSON payload for an embeddings request.""" return json.dumps({"model": self.model_id, "input": input_value}) - def _run_native_request(self, request_json: str) -> str: + async def _run_native_request(self, request_json: str) -> str: """Create a fresh EmbeddingsSession, process the request, return the response JSON string.""" from foundry_local_sdk.items import TextItem, TextItemType from foundry_local_sdk.request import Request from foundry_local_sdk.session import EmbeddingsSession - with ( - EmbeddingsSession(self._model) as session, - Request() as request, - ): - request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) - with session.process_request(request) as response: - # Copy the text out of the (response-owned) item before the response is released. - return response.get_item(0).text - - def _parse_response(self, response_json: str) -> CreateEmbeddingResponse: - """Parse the response JSON and apply fields required by the OpenAI type.""" - data = json.loads(response_json) - - # The server may omit "object" on embedding items and "usage" on the response; - # add defaults so CreateEmbeddingResponse.model_validate doesn't reject them. - for item in data.get("data", []): - if "object" not in item: - item["object"] = "embedding" - - if "usage" not in data: - data["usage"] = {"prompt_tokens": 0, "total_tokens": 0} - - return CreateEmbeddingResponse.model_validate(data) - - def generate_embedding(self, input_text: str) -> CreateEmbeddingResponse: + def _blocking(): + with EmbeddingsSession(self._model) as session: + with Request() as request: + request.add_item(TextItem(request_json, TextItemType.OPENAI_JSON)) + response = session.process_request(request) + try: + return response.get_item(0).text + finally: + response._close() + + return await asyncio.to_thread(_blocking) + + async def generate_embedding(self, input_text: str) -> CreateEmbeddingResponse: """Generate embeddings for a single input text. Args: @@ -85,10 +74,10 @@ def generate_embedding(self, input_text: str) -> CreateEmbeddingResponse: self._validate_input(input_text) request_json = self._build_request_json(input_text) - response_json = self._run_native_request(request_json) + response_json = await self._run_native_request(request_json) return self._parse_response(response_json) - def generate_embeddings(self, inputs: list[str]) -> CreateEmbeddingResponse: + async def generate_embeddings(self, inputs: list[str]) -> CreateEmbeddingResponse: """Generate embeddings for multiple input texts in a single request. Args: @@ -107,5 +96,20 @@ def generate_embeddings(self, inputs: list[str]) -> CreateEmbeddingResponse: self._validate_input(text) request_json = self._build_request_json(inputs) - response_json = self._run_native_request(request_json) + response_json = await self._run_native_request(request_json) return self._parse_response(response_json) + + def _parse_response(self, response_json: str) -> CreateEmbeddingResponse: + """Parse the response JSON and apply fields required by the OpenAI type.""" + data = json.loads(response_json) + + # The server may omit "object" on embedding items and "usage" on the response; + # add defaults so CreateEmbeddingResponse.model_validate doesn't reject them. + for item in data.get("data", []): + if "object" not in item: + item["object"] = "embedding" + + if "usage" not in data: + data["usage"] = {"prompt_tokens": 0, "total_tokens": 0} + + return CreateEmbeddingResponse.model_validate(data) diff --git a/sdk_v2/python/src/foundry_local_sdk/request.py b/sdk_v2/python/src/foundry_local_sdk/request.py index fe6879e8..54ab0f4c 100644 --- a/sdk_v2/python/src/foundry_local_sdk/request.py +++ b/sdk_v2/python/src/foundry_local_sdk/request.py @@ -109,6 +109,10 @@ def cancel(self) -> None: api.check_status(api.inference.Request_Cancel(self._ptr)) + async def cancel_async(self) -> None: + """Signal cancellation for an in-flight async request.""" + self.cancel() + def _close(self) -> None: if self._closed: return diff --git a/sdk_v2/python/test/conftest.py b/sdk_v2/python/test/conftest.py index f2921ac0..15b34bc7 100644 --- a/sdk_v2/python/test/conftest.py +++ b/sdk_v2/python/test/conftest.py @@ -21,6 +21,7 @@ from __future__ import annotations +import asyncio import os import pytest @@ -44,12 +45,23 @@ def is_running_in_ci() -> bool: FOUNDRY_TEST_DATA_DIR: str | None = os.environ.get("FOUNDRY_TEST_DATA_DIR") or None +# --------------------------------------------------------------------------- +# Session-scoped event loop — required for session-scoped async fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + # --------------------------------------------------------------------------- # Session-scoped manager — singleton, created exactly once per test process # --------------------------------------------------------------------------- @pytest.fixture(scope="session") -def manager(): +async def manager(): """Initialize the FoundryLocalManager singleton for the test session. A working manager is the most basic prerequisite for the entire integration @@ -75,7 +87,7 @@ def manager(): config_kwargs["model_cache_dir"] = FOUNDRY_TEST_DATA_DIR config = Configuration(**config_kwargs) - FoundryLocalManager.initialize(config) + await FoundryLocalManager.initialize(config) created_here = True mgr = FoundryLocalManager.instance @@ -88,7 +100,7 @@ def manager(): # singleton field as None or pointing at a different instance). if created_here and FoundryLocalManager.instance is mgr: try: - mgr.close() + await mgr.close() except Exception as e: # Teardown must not break unrelated tests, but a silent swallow # hides real native shutdown bugs. Surface as a warning so it @@ -159,7 +171,7 @@ def _name_matches(m) -> bool: return best -def _model_fixture_or_skip(manager, task: str, role: str, *, load: bool, name_substr: str | None = None): +async def _model_fixture_or_skip(manager, task: str, role: str, *, load: bool, name_substr: str | None = None): from foundry_local_sdk.exception import FoundryLocalException model = _find_smallest_cached_model_for_task(manager, task, name_substr=name_substr) @@ -173,7 +185,7 @@ def _model_fixture_or_skip(manager, task: str, role: str, *, load: bool, name_su pytest.skip(reason) if load: try: - model.load() + await model.load() except FoundryLocalException as e: pytest.skip(f"Could not load {role} model {model.alias!r}: {e}") return model @@ -188,7 +200,7 @@ def _model_fixture_or_skip(manager, task: str, role: str, *, load: bool, name_su @pytest.fixture(scope="session") -def chat_model(manager): +async def chat_model(manager): """Pinned cached chat-completion model, loaded. Falls back to smallest cached chat model when the pinned variant is not present. Skips if none cached. """ @@ -197,22 +209,22 @@ def chat_model(manager): pinned = manager.catalog.get_model_variant(_PINNED_CHAT_MODEL_ID) if pinned is not None and pinned.is_cached: try: - pinned.load() + await pinned.load() return pinned except FoundryLocalException as e: pytest.skip(f"Could not load pinned chat model {_PINNED_CHAT_MODEL_ID!r}: {e}") - return _model_fixture_or_skip(manager, "chat-completion", "chat", load=True) + return await _model_fixture_or_skip(manager, "chat-completion", "chat", load=True) @pytest.fixture(scope="session") -def embedding_model(manager): +async def embedding_model(manager): """Smallest cached embeddings model, loaded. Skips if none cached.""" - return _model_fixture_or_skip(manager, "embeddings", "embedding", load=True) + return await _model_fixture_or_skip(manager, "embeddings", "embedding", load=True) @pytest.fixture(scope="session") -def audio_model(manager): +async def audio_model(manager): """Smallest cached ASR model, loaded. Skips if none cached. Used by tests that exercise the streaming live-audio path, which works with any @@ -220,11 +232,11 @@ def audio_model(manager): one-shot file transcription path should depend on :func:`whisper_audio_model` instead — today only the whisper decoder implements that path. """ - return _model_fixture_or_skip(manager, "automatic-speech-recognition", "audio", load=True) + return await _model_fixture_or_skip(manager, "automatic-speech-recognition", "audio", load=True) @pytest.fixture(scope="session") -def whisper_audio_model(manager): +async def whisper_audio_model(manager): """Smallest cached whisper-family ASR model, loaded. Skips if none cached. The one-shot ``AudioClient.transcribe`` path goes through ``onnx_audio_generator``, which only implements @@ -232,7 +244,7 @@ def whisper_audio_model(manager): ``model does not support audio processing``. Constrain selection to models whose id/alias contains ``whisper``. """ - return _model_fixture_or_skip( + return await _model_fixture_or_skip( manager, "automatic-speech-recognition", "whisper-audio", load=True, name_substr="whisper" ) diff --git a/sdk_v2/python/test/integration/test_audio_client.py b/sdk_v2/python/test/integration/test_audio_client.py index 0f257fd7..5deecfee 100644 --- a/sdk_v2/python/test/integration/test_audio_client.py +++ b/sdk_v2/python/test/integration/test_audio_client.py @@ -31,26 +31,26 @@ def audio_client(whisper_audio_model): return whisper_audio_model.get_audio_client() -def test_transcribe_returns_response(audio_client): +async def test_transcribe_returns_response(audio_client): if not _RECORDING_PATH.is_file(): pytest.skip(f"testdata/Recording.wav not found at {_RECORDING_PATH}") - result = audio_client.transcribe(str(_RECORDING_PATH)) + result = await audio_client.transcribe(str(_RECORDING_PATH)) assert isinstance(result, AudioTranscriptionResponse) assert isinstance(result.text, str) assert result.text.strip() != "" -def test_transcribe_empty_path_raises(audio_client): +async def test_transcribe_empty_path_raises(audio_client): with pytest.raises(ValueError): - audio_client.transcribe("") + await audio_client.transcribe("") -def test_transcribe_nonexistent_path_raises(audio_client): +async def test_transcribe_nonexistent_path_raises(audio_client): # The client may validate at the Python layer (ValueError) or surface the native missing-file error # (FoundryLocalException) — accept either. with pytest.raises((ValueError, FoundryLocalException)): - audio_client.transcribe("/no/such/file.wav") + await audio_client.transcribe("/no/such/file.wav") # Streaming-transcription coverage — mirrors the v1 ``test_should_transcribe_audio_streaming`` @@ -59,12 +59,12 @@ def test_transcribe_nonexistent_path_raises(audio_client): # equivalent to the non-streaming ``transcribe`` result for the same fixture. -def test_transcribe_streaming_yields_chunks(audio_client): +async def test_transcribe_streaming_yields_chunks(audio_client): if not _RECORDING_PATH.is_file(): pytest.skip(f"testdata/Recording.wav not found at {_RECORDING_PATH}") chunks: list[AudioTranscriptionResponse] = [] - for chunk in audio_client.transcribe_streaming(str(_RECORDING_PATH)): + async for chunk in audio_client.transcribe_streaming(str(_RECORDING_PATH)): assert isinstance(chunk, AudioTranscriptionResponse) assert isinstance(chunk.text, str) chunks.append(chunk) @@ -75,13 +75,14 @@ def test_transcribe_streaming_yields_chunks(audio_client): assert full_text != "", "concatenated streamed text must be non-empty" -def test_transcribe_streaming_matches_non_streaming(audio_client): +async def test_transcribe_streaming_matches_non_streaming(audio_client): """Streaming concatenation should match (or be a close superset of) the one-shot result.""" if not _RECORDING_PATH.is_file(): pytest.skip(f"testdata/Recording.wav not found at {_RECORDING_PATH}") - one_shot = audio_client.transcribe(str(_RECORDING_PATH)).text.strip() - streamed = "".join(c.text for c in audio_client.transcribe_streaming(str(_RECORDING_PATH))).strip() + one_shot = (await audio_client.transcribe(str(_RECORDING_PATH))).text.strip() + streamed_chunks = [c async for c in audio_client.transcribe_streaming(str(_RECORDING_PATH))] + streamed = "".join(c.text for c in streamed_chunks).strip() assert one_shot, "one-shot transcription must be non-empty for comparison" assert streamed, "streamed transcription must be non-empty" @@ -99,13 +100,15 @@ def test_transcribe_streaming_matches_non_streaming(audio_client): ) -def test_transcribe_streaming_empty_path_raises(audio_client): - # Validation runs in the outer non-generator function, so the error surfaces - # at call time without needing to iterate the returned generator. +async def test_transcribe_streaming_empty_path_raises(audio_client): + # Validation runs when the async generator is first iterated. with pytest.raises(ValueError): - audio_client.transcribe_streaming("") + async for _ in audio_client.transcribe_streaming(""): + pass -def test_transcribe_streaming_nonexistent_path_raises(audio_client): +async def test_transcribe_streaming_nonexistent_path_raises(audio_client): with pytest.raises((ValueError, FoundryLocalException)): - list(audio_client.transcribe_streaming("/no/such/file.wav")) + async_gen = audio_client.transcribe_streaming("/no/such/file.wav") + async for _ in async_gen: + pass diff --git a/sdk_v2/python/test/integration/test_chat_client.py b/sdk_v2/python/test/integration/test_chat_client.py index 6bc975b9..b12908f9 100644 --- a/sdk_v2/python/test/integration/test_chat_client.py +++ b/sdk_v2/python/test/integration/test_chat_client.py @@ -54,23 +54,23 @@ def chat_client(chat_model) -> ChatClient: class TestNonStreaming: - def test_returns_typed_completion(self, chat_client): - resp = chat_client.complete_chat(PROMPT) + async def test_returns_typed_completion(self, chat_client): + resp = await chat_client.complete(PROMPT) assert isinstance(resp, ChatCompletion) - def test_response_has_content(self, chat_client): - resp = chat_client.complete_chat(PROMPT) + async def test_response_has_content(self, chat_client): + resp = await chat_client.complete(PROMPT) assert resp.choices, "Response must contain at least one choice" msg = resp.choices[0].message assert msg.content is not None assert msg.content.strip(), "Assistant content must not be empty" - def test_response_has_finish_reason(self, chat_client): - resp = chat_client.complete_chat(PROMPT) + async def test_response_has_finish_reason(self, chat_client): + resp = await chat_client.complete(PROMPT) assert resp.choices[0].finish_reason in {"stop", "length"} - def test_response_has_usage_with_positive_counts(self, chat_client): - resp = chat_client.complete_chat(PROMPT) + async def test_response_has_usage_with_positive_counts(self, chat_client): + resp = await chat_client.complete(PROMPT) assert resp.usage is not None assert resp.usage.prompt_tokens > 0 assert resp.usage.completion_tokens > 0 @@ -78,25 +78,29 @@ def test_response_has_usage_with_positive_counts(self, chat_client): resp.usage.prompt_tokens + resp.usage.completion_tokens ) - def test_response_model_field_matches_loaded_model(self, chat_client): - resp = chat_client.complete_chat(PROMPT) + async def test_response_model_field_matches_loaded_model(self, chat_client): + resp = await chat_client.complete(PROMPT) # Some models echo a normalized id; just confirm it is present. assert resp.model - def test_invalid_messages_rejected_before_native_call(self, chat_client): + async def test_invalid_messages_rejected_before_native_call(self, chat_client): with pytest.raises(ValueError): - chat_client.complete_chat([]) + await chat_client.complete([]) class TestStreaming: - def test_yields_chunks(self, chat_client): - chunks = list(chat_client.complete_streaming_chat(PROMPT)) + async def test_yields_chunks(self, chat_client): + chunks = [] + async for chunk in chat_client.stream(PROMPT): + chunks.append(chunk) assert chunks, "Streaming should yield at least one chunk" for c in chunks: assert isinstance(c, ChatCompletionChunk) - def test_concatenated_content_is_non_empty(self, chat_client): - chunks = list(chat_client.complete_streaming_chat(PROMPT)) + async def test_concatenated_content_is_non_empty(self, chat_client): + chunks = [] + async for chunk in chat_client.stream(PROMPT): + chunks.append(chunk) # Concatenate visible content. Foundry Local may emit assistant text # under either delta.content (per-token) or as a single message-level # payload that the client normalises into a delta. @@ -115,22 +119,27 @@ def test_concatenated_content_is_non_empty(self, chat_client): + "]" ) - def test_final_chunk_has_finish_reason(self, chat_client): - chunks = list(chat_client.complete_streaming_chat(PROMPT)) + async def test_final_chunk_has_finish_reason(self, chat_client): + chunks = [] + async for chunk in chat_client.stream(PROMPT): + chunks.append(chunk) finish = None for c in chunks: if c.choices and c.choices[0].finish_reason: finish = c.choices[0].finish_reason assert finish in {"stop", "length"} - def test_break_mid_stream_does_not_crash(self, chat_client): - gen = chat_client.complete_streaming_chat(PROMPT) + async def test_break_mid_stream_does_not_crash(self, chat_client): + gen = chat_client.stream(PROMPT) # Pull one chunk, then abandon — the finally-block in the session # must cancel the request and join the background thread cleanly. - first = next(gen) + first = None + async for chunk in gen: + first = chunk + break assert isinstance(first, ChatCompletionChunk) - gen.close() + # Async generator cleanup happens automatically # Subsequent calls must still work. - resp = chat_client.complete_chat(PROMPT) + resp = await chat_client.complete(PROMPT) assert isinstance(resp, ChatCompletion) diff --git a/sdk_v2/python/test/integration/test_embedding_client.py b/sdk_v2/python/test/integration/test_embedding_client.py index 486aa2b6..84b2ec22 100644 --- a/sdk_v2/python/test/integration/test_embedding_client.py +++ b/sdk_v2/python/test/integration/test_embedding_client.py @@ -43,36 +43,36 @@ def embedding_client(embedding_model): class TestSingle: - def test_returns_typed_response(self, embedding_client): - resp = embedding_client.generate_embedding(SAMPLE_INPUTS[0]) + async def test_returns_typed_response(self, embedding_client): + resp = await embedding_client.generate_embedding(SAMPLE_INPUTS[0]) assert isinstance(resp, CreateEmbeddingResponse) - def test_one_vector_for_one_input(self, embedding_client): - resp = embedding_client.generate_embedding(SAMPLE_INPUTS[0]) + async def test_one_vector_for_one_input(self, embedding_client): + resp = await embedding_client.generate_embedding(SAMPLE_INPUTS[0]) assert len(resp.data) == 1 assert len(resp.data[0].embedding) > 0 - def test_empty_input_rejected_before_native_call(self, embedding_client): + async def test_empty_input_rejected_before_native_call(self, embedding_client): with pytest.raises(ValueError): - embedding_client.generate_embedding("") + await embedding_client.generate_embedding("") - def test_whitespace_only_rejected(self, embedding_client): + async def test_whitespace_only_rejected(self, embedding_client): with pytest.raises(ValueError): - embedding_client.generate_embedding(" ") + await embedding_client.generate_embedding(" ") class TestBatched: - def test_one_vector_per_input(self, embedding_client): - resp = embedding_client.generate_embeddings(SAMPLE_INPUTS) + async def test_one_vector_per_input(self, embedding_client): + resp = await embedding_client.generate_embeddings(SAMPLE_INPUTS) assert len(resp.data) == len(SAMPLE_INPUTS) - def test_consistent_dimensions(self, embedding_client): - resp = embedding_client.generate_embeddings(SAMPLE_INPUTS) + async def test_consistent_dimensions(self, embedding_client): + resp = await embedding_client.generate_embeddings(SAMPLE_INPUTS) dims = {len(d.embedding) for d in resp.data} assert len(dims) == 1, f"Expected uniform dim, got {dims}" - def test_distinct_inputs_produce_distinct_vectors(self, embedding_client): - resp = embedding_client.generate_embeddings(SAMPLE_INPUTS) + async def test_distinct_inputs_produce_distinct_vectors(self, embedding_client): + resp = await embedding_client.generate_embeddings(SAMPLE_INPUTS) v0 = list(resp.data[0].embedding) v1 = list(resp.data[1].embedding) # Cosine similarity well below 1 — distinct sentences shouldn't collide. @@ -82,10 +82,10 @@ def test_distinct_inputs_produce_distinct_vectors(self, embedding_client): cos = dot / (n0 * n1) if n0 and n1 else 0.0 assert cos < 0.999, f"Expected distinct vectors, got cosine sim {cos}" - def test_empty_list_rejected(self, embedding_client): + async def test_empty_list_rejected(self, embedding_client): with pytest.raises(ValueError): - embedding_client.generate_embeddings([]) + await embedding_client.generate_embeddings([]) - def test_empty_element_in_list_rejected(self, embedding_client): + async def test_empty_element_in_list_rejected(self, embedding_client): with pytest.raises(ValueError): - embedding_client.generate_embeddings(["ok", ""]) + await embedding_client.generate_embeddings(["ok", ""]) diff --git a/sdk_v2/python/test/integration/test_ep_lifecycle.py b/sdk_v2/python/test/integration/test_ep_lifecycle.py index 7cabeb7a..0364b35a 100644 --- a/sdk_v2/python/test/integration/test_ep_lifecycle.py +++ b/sdk_v2/python/test/integration/test_ep_lifecycle.py @@ -9,8 +9,8 @@ class TestEpLifecycle: - def test_discover_eps_returns_list(self, manager): - eps = manager.discover_eps() + async def test_discover_eps_returns_list(self, manager): + eps = await manager.discover_eps() assert isinstance(eps, list) for ep in eps: assert isinstance(ep, EpInfo) diff --git a/sdk_v2/python/test/integration/test_model_lifecycle.py b/sdk_v2/python/test/integration/test_model_lifecycle.py index d23145c5..1634af2e 100644 --- a/sdk_v2/python/test/integration/test_model_lifecycle.py +++ b/sdk_v2/python/test/integration/test_model_lifecycle.py @@ -19,23 +19,23 @@ # --------------------------------------------------------------------------- class TestLoadUnload: - def test_load_idempotent(self, chat_model): + async def test_load_idempotent(self, chat_model): # The fixture already loaded this model; load() again must be a no-op. - chat_model.load() + await chat_model.load() assert chat_model.is_loaded is True - def test_unload_then_load(self, chat_model): + async def test_unload_then_load(self, chat_model): try: - chat_model.unload() + await chat_model.unload() assert chat_model.is_loaded is False - chat_model.load() + await chat_model.load() assert chat_model.is_loaded is True finally: # Leave the model loaded so other session-scoped consumers of ``chat_model`` see the same state # they expect. if not chat_model.is_loaded: - chat_model.load() + await chat_model.load() def test_is_cached_true_for_fixture_model(self, chat_model): # The fixture selection requires is_cached — sanity check the invariant. @@ -47,7 +47,7 @@ def test_is_cached_true_for_fixture_model(self, chat_model): # --------------------------------------------------------------------------- @pytest.mark.manual -def test_download_progress_callback_fires(chat_model): +async def test_download_progress_callback_fires(chat_model): """Verify the download progress callback fires at least once. Removes the model from cache via the public ``remove_from_cache`` API (the SDK's supported way to force a @@ -60,7 +60,7 @@ def test_download_progress_callback_fires(chat_model): """ # Unload first — removing a loaded model from cache is not supported. if chat_model.is_loaded: - chat_model.unload() + await chat_model.unload() chat_model.remove_from_cache() assert chat_model.is_cached is False @@ -71,7 +71,7 @@ def on_progress(pct: float) -> None: received.append(pct) try: - chat_model.download(progress_callback=on_progress) + await chat_model.download(progress_callback=on_progress) assert chat_model.is_cached is True assert len(received) >= 1 for pct in received: @@ -79,4 +79,4 @@ def on_progress(pct: float) -> None: finally: # Restore the loaded state so session-scoped fixture consumers are not disturbed. if chat_model.is_cached and not chat_model.is_loaded: - chat_model.load() + await chat_model.load() diff --git a/sdk_v2/python/test/integration/test_web_service_and_eps.py b/sdk_v2/python/test/integration/test_web_service_and_eps.py index 51f9803f..050678f3 100644 --- a/sdk_v2/python/test/integration/test_web_service_and_eps.py +++ b/sdk_v2/python/test/integration/test_web_service_and_eps.py @@ -20,8 +20,8 @@ # --------------------------------------------------------------------------- class TestWebServiceLifecycle: - def test_start_web_service_populates_urls(self, manager): - manager.start_web_service() + async def test_start_web_service_populates_urls(self, manager): + await manager.start_web_service() try: assert isinstance(manager.urls, list) assert len(manager.urls) > 0 @@ -29,32 +29,32 @@ def test_start_web_service_populates_urls(self, manager): assert isinstance(url, str) assert url.startswith("http://") finally: - manager.stop_web_service() + await manager.stop_web_service() assert manager.urls is None - def test_stop_without_start_raises(self, manager): + async def test_stop_without_start_raises(self, manager): # Make sure no prior test left the service running. if manager.urls is not None: - manager.stop_web_service() + await manager.stop_web_service() assert manager.urls is None with pytest.raises(FoundryLocalException) as exc_info: - manager.stop_web_service() + await manager.stop_web_service() assert "not running" in str(exc_info.value).lower() - def test_start_stop_start_cycle(self, manager): + async def test_start_stop_start_cycle(self, manager): try: - manager.start_web_service() + await manager.start_web_service() assert manager.urls and len(manager.urls) > 0 - manager.stop_web_service() + await manager.stop_web_service() assert manager.urls is None - manager.start_web_service() + await manager.start_web_service() assert manager.urls and len(manager.urls) > 0 finally: if manager.urls is not None: - manager.stop_web_service() + await manager.stop_web_service() # --------------------------------------------------------------------------- @@ -62,8 +62,8 @@ def test_start_stop_start_cycle(self, manager): # --------------------------------------------------------------------------- class TestEpDiscovery: - def test_discover_eps_returns_list(self, manager): - eps = manager.discover_eps() + async def test_discover_eps_returns_list(self, manager): + eps = await manager.discover_eps() assert isinstance(eps, list) for ep in eps: assert isinstance(ep, EpInfo) diff --git a/sdk_v2/python/test/integration/test_zz_manager_shutdown.py b/sdk_v2/python/test/integration/test_zz_manager_shutdown.py index 0d7b68bb..6c81c961 100644 --- a/sdk_v2/python/test/integration/test_zz_manager_shutdown.py +++ b/sdk_v2/python/test/integration/test_zz_manager_shutdown.py @@ -31,14 +31,14 @@ class TestManagerShutdown: - def test_shutdown_sets_is_shutdown_requested(self, manager): + async def test_shutdown_sets_is_shutdown_requested(self, manager): assert manager.is_shutdown_requested() is False - manager.shutdown() + await manager.shutdown() assert manager.is_shutdown_requested() is True - def test_shutdown_is_idempotent(self, manager): + async def test_shutdown_is_idempotent(self, manager): # Previous test may already have called shutdown; calling again # must not raise and the flag must stay set. - manager.shutdown() - manager.shutdown() + await manager.shutdown() + await manager.shutdown() assert manager.is_shutdown_requested() is True diff --git a/sdk_v2/python/test/integration/test_zz_singleton_recreate.py b/sdk_v2/python/test/integration/test_zz_singleton_recreate.py index 9fd7bc38..97dbe6d7 100644 --- a/sdk_v2/python/test/integration/test_zz_singleton_recreate.py +++ b/sdk_v2/python/test/integration/test_zz_singleton_recreate.py @@ -54,40 +54,41 @@ def _make_config(manager) -> Configuration: @pytest.fixture -def restore_singleton(manager): +async def restore_singleton(manager): """Save the current config, run the test, then leave a working singleton in place.""" saved_config = _make_config(manager) yield saved_config if FoundryLocalManager.instance is None: - FoundryLocalManager(saved_config) + await FoundryLocalManager.initialize(saved_config) class TestSingletonRecreate: - def test_close_clears_singleton(self, restore_singleton): + async def test_close_clears_singleton(self, restore_singleton): config = restore_singleton # Close whatever singleton currently exists. assert FoundryLocalManager.instance is not None - FoundryLocalManager.instance.close() + await FoundryLocalManager.instance.close() assert FoundryLocalManager.instance is None # A fresh manager can now be constructed and registers as the singleton. - new_mgr = FoundryLocalManager(config) - assert FoundryLocalManager.instance is new_mgr + await FoundryLocalManager.initialize(config) + assert FoundryLocalManager.instance is not None - def test_close_is_idempotent(self, restore_singleton): + async def test_close_is_idempotent(self, restore_singleton): mgr = FoundryLocalManager.instance assert mgr is not None - mgr.close() + await mgr.close() # Second close must not raise. - mgr.close() + await mgr.close() assert FoundryLocalManager.instance is None - def test_context_manager_clears_singleton(self, restore_singleton): + async def test_context_manager_clears_singleton(self, restore_singleton): config = restore_singleton - # Tear down the existing singleton so the with-block can build a new one. + # Tear down the existing singleton so we can build a new one. if FoundryLocalManager.instance is not None: - FoundryLocalManager.instance.close() + await FoundryLocalManager.instance.close() - with FoundryLocalManager(config) as m: + await FoundryLocalManager.initialize(config) + async with FoundryLocalManager.instance as m: assert FoundryLocalManager.instance is m assert FoundryLocalManager.instance is None