diff --git a/openelectricity/client.py b/openelectricity/client.py index 86c671c..192c03e 100644 --- a/openelectricity/client.py +++ b/openelectricity/client.py @@ -5,8 +5,10 @@ """ import asyncio +import concurrent.futures import os import ssl +from collections.abc import Coroutine from datetime import datetime from typing import Any, TypeVar, cast @@ -32,6 +34,22 @@ logger = get_logger("client") +def _run_sync(coro: Coroutine[Any, Any, T]) -> T: + """Run a coroutine to completion from synchronous code. + + Uses ``asyncio.run`` when no event loop is running. When a loop is already + running — as in a Jupyter/IPython notebook — the coroutine is run on its + own loop in a worker thread, so the synchronous client works there too. + """ + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + + class OpenElectricityError(Exception): """Base exception for OpenElectricity API errors.""" @@ -162,8 +180,9 @@ class OEClient(BaseOEClient): """ Synchronous client for the OpenElectricity API. - Note: This client uses aiohttp with asyncio.run() internally to maintain - API consistency while using the same underlying HTTP client as the async version. + It runs aiohttp under the hood and is safe to call from inside an existing + event loop (e.g. a Jupyter/IPython notebook) — when a loop is already + running, requests are dispatched to a worker thread. """ def __init__( @@ -387,7 +406,7 @@ async def _run(): self._session = session return await self._async_get_facilities(facility_code, status_id, fueltech_id, network_id, network_region) - return asyncio.run(_run()) + return _run_sync(_run()) def get_network_data( self, @@ -437,7 +456,7 @@ async def _run(): secondary_grouping, ) - return asyncio.run(_run()) + return _run_sync(_run()) def get_facility_data( self, @@ -458,7 +477,7 @@ async def _run(): network_code, facility_code, metrics, interval, date_start, date_end, unit_code ) - return asyncio.run(_run()) + return _run_sync(_run()) def get_market( self, @@ -479,7 +498,7 @@ async def _run(): network_code, metrics, interval, date_start, date_end, primary_grouping, network_region ) - return asyncio.run(_run()) + return _run_sync(_run()) def get_current_user(self) -> OpennemUserResponse: """Get current user information.""" @@ -489,7 +508,7 @@ async def _run(): self._session = session return await self._async_get_current_user() - return asyncio.run(_run()) + return _run_sync(_run()) def close(self) -> None: """Close the underlying HTTP client.""" @@ -498,7 +517,7 @@ def close(self) -> None: async def _close(): await cast(ClientSession, self._session).close() - asyncio.run(_close()) + _run_sync(_close()) def __enter__(self) -> "OEClient": return self diff --git a/tests/test_run_sync.py b/tests/test_run_sync.py new file mode 100644 index 0000000..44d4326 --- /dev/null +++ b/tests/test_run_sync.py @@ -0,0 +1,28 @@ +""" +Tests for _run_sync, the helper that lets the synchronous client work both +standalone and inside an already-running event loop (e.g. a notebook). +""" + +import asyncio + +from openelectricity.client import _run_sync + + +async def _echo(value: str) -> str: + await asyncio.sleep(0) + return value + + +def test_run_sync_without_running_loop() -> None: + """With no event loop running, the coroutine runs via asyncio.run.""" + assert _run_sync(_echo("standalone")) == "standalone" + + +def test_run_sync_inside_running_loop() -> None: + """With a loop already running (the notebook case), it still completes.""" + + async def _notebook_like() -> str: + # A loop is running here; a naive asyncio.run() would raise RuntimeError. + return _run_sync(_echo("notebook")) + + assert asyncio.run(_notebook_like()) == "notebook"