From d2bf71f3c81d5f4d789c230eff00d08493fb4669 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Mon, 25 May 2026 10:38:26 -0500 Subject: [PATCH 01/16] feat(waterdata): Add async parallel chunker over httpx.AsyncClient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a parallel fan-out path to `multi_value_chunked`. When `API_USGS_CONCURRENT` resolves to >1 (default: 16), the decorator runs the sub-requests of an over-budget plan concurrently under one shared `httpx.AsyncClient`, instead of issuing them serially. Falls back to the serial sync path (with a one-time UserWarning) when no async fetch sibling is wired or when an asyncio event loop is already running (Jupyter, IPython, async apps — `asyncio.run` would otherwise raise). Architecture (`dataretrieval/waterdata/chunking.py`): * `_fan_out_async(plan, fetch_once, fetch_async, *, max_concurrent)` is the orchestrator: it dispatches every sub-request concurrently via `asyncio.gather(return_exceptions=True)`. Completed pairs survive a sibling's transient failure, so partial state stays recoverable through `ChunkedCall.resume()` on the sync path. * Failure precedence in the gather: 1. Cancellation/interrupt signals (CancelledError, KeyboardInterrupt, SystemExit) propagate unmodified — never wrapped as transients. Cancellation is asyncio's abort signal; rewriting it as ChunkInterrupted would silently consume the user's stop request. 2. Recognized transients (RateLimited, ServiceUnavailable, bare httpx.HTTPError) wrap as ChunkInterrupted so the user gets a resumable handle even when a non-transient bug landed earlier in submission order. 3. Otherwise raise the first failure in submission order, preserving its type. * `_execute_in_parallel` owns the sync→async bridge: `asyncio.run` dispatch with the `fetch_async is None` and running-event-loop fallbacks (each a one-time UserWarning, then serial). * `_publish_async_client` / `get_active_async_client` / `_chunked_async_client` ContextVar let async paginated-loop helpers (`_walk_pages_async`, `_paginate_async`) reuse one `AsyncClient` connection pool across every concurrent sub-request. Wiring (`dataretrieval/waterdata/utils.py`): * `_walk_pages_async`, `_paginate_async`, `_client_for_async`, `_fetch_once_async` — async siblings of the sync paginate path, sharing the same `parse_response` / `follow_up` callbacks and the `_ogc_parse_response` parser. * The `@chunking.multi_value_chunked(fetch_async=_fetch_once_async)` decorator on `_fetch_once` wires the async sibling so the parallel path is available to every Water Data OGC getter. * `ChunkedCall.record()` encapsulates the completion write so the serial loop and the parallel fan-out share it; `_chunks` is a sparse index map so a parallel partial-failure resumes correctly via the sync path. Concurrency cap (`API_USGS_CONCURRENT`): * Integer N >= 1: bounded fan-out (semaphore-gated, N=1 forces serial sync). Default 16 — the server-friendly sweet spot. * `unbounded`: no per-call cap (`Semaphore(sys.maxsize)`). * Unset: default 16. Retries (`API_USGS_RETRIES`, default 4; `0` disables): each sub-request is retried on a transient failure with exponential backoff + full jitter, so a large fan-out completes through the AWS API Gateway's burst throttling and the occasional backend straggler instead of aborting on the first 429/5xx/timeout. * `RetryPolicy` — a frozen value object owning the timing decisions (`from_env`, `should_retry`, `backoff`). Full jitter (`random.uniform(0, ceiling)`) de-correlates the concurrent retries so they don't re-burst in lockstep. A server `Retry-After` overrides the computed backoff, unless it exceeds `retry_after_cap` (60s) — a multi-minute quota-window reset escalates to the resumable interruption instead of blocking the call inline. * `_retryable` — chain-walking predicate, narrower than `_classify_chunk_error`: retries `RateLimited` / `ServiceUnavailable` / `httpx.TransportError` but NOT `httpx.InvalidURL`. * `_retry_sync` / `_retry_async` drivers wrap the per-sub-request fetch at both seams (`ChunkedCall._issue`, `_fan_out_async.track`); the async retry runs inside the semaphore, so a backing-off chunk holds its slot and effective concurrency shrinks under throttling. On exhaustion the last exception re-raises into the existing `wrap_failure` path, so `.resume()` stays the escape hatch. * `ProgressReporter.note_retry` surfaces the backoff on the status line ("retrying (attempt N, waiting Ns)"), cleared by the next page. Test scaffolding: `tests/conftest.py` extends the `_serial_chunker` autouse fixture to pin `API_USGS_CONCURRENT=1` and `API_USGS_RETRIES=0` so the existing mocked suite stays on the deterministic serial path with transients surfacing immediately; async and retry tests opt back in by re-setting the env vars inside their body. Tests: async-path coverage in `tests/waterdata_chunking_test.py` (one-call-per-sub-request, mid-fan-out transient yields resumable ChunkInterrupted, fallback-to-serial parametrized over running-loop and missing-fetch_async, cancellation-wins-over- transient-sibling regression), plus retry coverage (policy math/jitter bounds, `_retryable` taxonomy, sync+async transient-then-success, exhausted-still-resumable, long-`Retry-After` escalation). `tests/waterdata_progress_test.py` adds progress integration for `_fan_out_async` / `_paginate_async` and the `note_retry` render/clear. `tests/waterdata_utils_test.py` adds a `_walk_pages_async` initial-parse-error test. Test suite: 435 passing, 2 skipped (mocked); ruff clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/_progress.py | 16 + dataretrieval/waterdata/chunking.py | 606 +++++++++++++++++++++++++-- dataretrieval/waterdata/utils.py | 147 ++++++- tests/conftest.py | 31 +- tests/waterdata_chunking_test.py | 420 +++++++++++++++++++ tests/waterdata_progress_test.py | 141 +++++++ tests/waterdata_utils_test.py | 31 ++ 7 files changed, 1356 insertions(+), 36 deletions(-) diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py index 7263d555..7104f3af 100644 --- a/dataretrieval/waterdata/_progress.py +++ b/dataretrieval/waterdata/_progress.py @@ -121,6 +121,9 @@ def __init__( # The hourly request quota (``x-ratelimit-limit``), shown as the # denominator when the server reports it. self.rate_limit: str | None = None + # Transient note shown while a sub-request backs off before a + # retry; cleared by the next page/chunk so it doesn't linger. + self.retry_note: str | None = None self._last_len = 0 # Whether anything was actually written to the stream — drives whether # close() needs a terminating newline. (``current_chunk`` is a poor @@ -140,6 +143,7 @@ def start_chunk(self, index: int) -> None: avoids a premature "0 pages" frame before the first page arrives. """ self.current_chunk = index + self.retry_note = None if self.total_chunks > 1: self._render() @@ -147,6 +151,16 @@ def add_page(self, rows: int = 0) -> None: """Record one fetched page carrying ``rows`` rows and redraw.""" self.pages += 1 self.rows += int(rows) + self.retry_note = None + self._render() + + def note_retry(self, *, attempt: int, wait: float) -> None: + """Show that a sub-request is backing off before retry ``attempt``. + + Cleared by the next :meth:`add_page` / :meth:`start_chunk` so the + line returns to normal progress once the retry succeeds. + """ + self.retry_note = f"retrying (attempt {attempt}, waiting {wait:.0f}s)" self._render() def set_rate_remaining( @@ -179,6 +193,8 @@ def _format(self) -> str: else: segment = f"{remaining} requests remaining" parts.append(segment) + if self.retry_note is not None: + parts.append(self.retry_note) if self.service: return f"Retrieving: {self.service} · " + " · ".join(parts) return "Progress: " + " · ".join(parts) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 36ee24fd..1e3b429d 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -9,14 +9,34 @@ sub-request URL fits. Requests that already fit get a trivial single-step plan — ``ChunkedCall`` has one code path either way. +Concurrency: when ``API_USGS_CONCURRENT`` is set to an integer N > 1 +(or the literal ``unbounded``), ``multi_value_chunked`` fans the plan +out across ``N`` async coroutines sharing one ``httpx.AsyncClient`` +instead of issuing sub-requests serially. ``N=1`` forces the +synchronous path. The default (16) is the server-friendly sweet +spot; higher values can trip USGS burst-protection 5xx in practice. +The wrapper falls back to the serial path (with a ``UserWarning``) +when an asyncio event loop is already running (Jupyter / IPython / +async apps) or when no async fetch sibling is wired into the +decorator. + +Retries: each sub-request is retried on a transient failure (429, +5xx, connect/read timeout) with exponential backoff + full jitter, +honoring a server ``Retry-After`` when present. ``API_USGS_RETRIES`` +sets the cap (default 4; ``0`` disables). A ``Retry-After`` longer +than the per-call ceiling isn't slept off inline — it escalates to +the resumable interruption below so a multi-minute quota-window +reset doesn't block the call. + Interruption: any mid-stream transient failure (429, 5xx) surfaces as a ``ChunkInterrupted`` subclass — ``QuotaExhausted`` for 429, ``ServiceInterrupted`` for 5xx. The exception carries ``.call``, a ``ChunkedCall`` handle that owns the already-completed sub-request -state. Call ``.call.resume()`` once the underlying condition -clears; only the still-pending sub-requests are re-issued. -``Retry-After`` (when the server sets it) is surfaced on the -exception as ``.retry_after``. +state (sparse-indexed on the parallel path, contiguous-prefix on +the serial path). Call ``.call.resume()`` once the underlying +condition clears; only the still-pending sub-requests are +re-issued, via the serial sync path. ``Retry-After`` (when the +server sets it) is surfaced on the exception as ``.retry_after``. Dedup: list-axis chunks don't overlap; filter-axis chunks can, so ``_combine_chunk_frames`` dedupes by feature ``id``. ``properties``, @@ -27,11 +47,17 @@ from __future__ import annotations +import asyncio import copy import functools import itertools import math -from collections.abc import Callable, Iterator +import os +import random +import sys +import time +import warnings +from collections.abc import Awaitable, Callable, Iterator from contextlib import contextmanager, suppress from contextvars import ContextVar from dataclasses import dataclass @@ -93,15 +119,172 @@ # Response header USGS uses to advertise remaining hourly quota. _QUOTA_HEADER = "x-ratelimit-remaining" +# Environment variable that controls async fan-out concurrency. Read +# at call time (not import) so test patches via ``monkeypatch.setenv`` +# take effect. The default (16) is the server-friendly sweet spot: +# higher values trip the upstream into 5xx burst-protection in +# practice. Set to ``1`` to force the serial sync path, set to +# ``unbounded`` for no per-call cap (use sparingly — you own the +# upstream-burst risk). +_CONCURRENCY_ENV = "API_USGS_CONCURRENT" +_CONCURRENCY_DEFAULT = 16 +_CONCURRENCY_UNBOUNDED = "unbounded" + + +def _read_concurrency_env() -> int | None: + """ + Resolve the ``API_USGS_CONCURRENT`` env var to a parallelism cap. + + Returns + ------- + int or None + ``1`` for the serial sync path; an integer >1 for bounded + parallelism; ``None`` to disable the per-call cap entirely + (``unbounded`` keyword). Unset → default of + ``_CONCURRENCY_DEFAULT``. + """ + raw = os.environ.get(_CONCURRENCY_ENV) + if raw is None: + return _CONCURRENCY_DEFAULT + raw = raw.strip() + if raw == "": + return _CONCURRENCY_DEFAULT + if raw.lower() == _CONCURRENCY_UNBOUNDED: + return None + try: + value = int(raw) + except ValueError as exc: + raise ValueError( + f"{_CONCURRENCY_ENV} must be a positive integer or " + f"'{_CONCURRENCY_UNBOUNDED}'; got {raw!r}." + ) from exc + if value < 1: + raise ValueError( + f"{_CONCURRENCY_ENV} must be >= 1 (got {value}); use " + f"'{_CONCURRENCY_UNBOUNDED}' to disable the cap." + ) + return value + + +# Retry-with-backoff for transient sub-request failures (429 / 5xx / +# connect-read timeouts). The env var is read at call time so test +# ``monkeypatch.setenv`` takes effect; the timing constants are +# module-level so power users (and tests) can ``monkeypatch.setattr`` +# them. Defaults: 4 retries, 0.5s base doubling under full jitter up to +# a 30s per-attempt ceiling, and honor a server ``Retry-After`` up to +# 60s before escalating to a resumable interruption instead. +_RETRIES_ENV = "API_USGS_RETRIES" +_RETRIES_DEFAULT = 4 +_RETRY_BASE_BACKOFF = 0.5 +_RETRY_MAX_BACKOFF = 30.0 +_RETRY_AFTER_CAP = 60.0 + + +def _read_retries_env() -> int: + """ + Resolve the ``API_USGS_RETRIES`` env var to a max-retry count. + + Returns + ------- + int + Number of retries after the first attempt; ``0`` disables + retrying. Unset/blank → ``_RETRIES_DEFAULT``. + """ + raw = os.environ.get(_RETRIES_ENV) + if raw is None or raw.strip() == "": + return _RETRIES_DEFAULT + try: + value = int(raw.strip()) + except ValueError as exc: + raise ValueError( + f"{_RETRIES_ENV} must be a non-negative integer (got {raw!r})." + ) from exc + if value < 0: + raise ValueError(f"{_RETRIES_ENV} must be >= 0 (got {value}).") + return value + + +@dataclass(frozen=True) +class RetryPolicy: + """Bounded retry-with-backoff config for transient sub-request failures. + + An immutable value object that owns the *timing* decisions; the + exception taxonomy ("is this worth retrying at all?") lives in + :func:`_retryable`. Backoff is exponential with **full jitter** + (:func:`random.uniform` over ``[0, ceiling]``) so the concurrent + fan-out's retries don't re-burst in lockstep. A server ``Retry-After`` + hint, when present, overrides the computed backoff — unless it exceeds + :attr:`retry_after_cap`, in which case retrying stops and the failure + surfaces as a resumable :class:`ChunkInterrupted` (a multi-minute + quota-window reset shouldn't block the call inline). + + Attributes + ---------- + max_retries : int + Retries attempted after the first try; ``0`` disables retrying. + base_backoff : float + Seconds; the jitter ceiling for the first retry, doubled each + subsequent attempt. + max_backoff : float + Upper bound on any single attempt's backoff ceiling. + retry_after_cap : float + Largest ``Retry-After`` (seconds) honored inline; longer hints + escalate to a resumable interruption. + """ + + max_retries: int = _RETRIES_DEFAULT + base_backoff: float = _RETRY_BASE_BACKOFF + max_backoff: float = _RETRY_MAX_BACKOFF + retry_after_cap: float = _RETRY_AFTER_CAP + + @classmethod + def from_env(cls) -> RetryPolicy: + """Build a policy, resolving ``max_retries`` from ``API_USGS_RETRIES``.""" + return cls(max_retries=_read_retries_env()) + + def should_retry(self, attempt: int, retry_after: float | None) -> bool: + """Whether a just-failed ``attempt`` (1-based) warrants another try. + + A ``Retry-After`` longer than ``retry_after_cap`` is *not* slept + off inline — it returns ``False`` so the failure escalates to a + resumable interruption instead of blocking the call for minutes. + """ + if attempt > self.max_retries: + return False + return retry_after is None or retry_after <= self.retry_after_cap + + def backoff(self, attempt: int, retry_after: float | None) -> float: + """Seconds to wait before retry ``attempt`` (1-based).""" + if retry_after is not None: + return retry_after + ceiling = min(self.max_backoff, self.base_backoff * 2 ** (attempt - 1)) + return random.uniform(0.0, ceiling) + + +# Default for direct ``ChunkedCall`` / ``ChunkPlan.execute`` construction +# (and tests): no retrying. The production decorator path explicitly passes +# ``RetryPolicy.from_env()`` so retries are on by default there. +_NO_RETRY = RetryPolicy(max_retries=0) + + # Client shared across all sub-requests of a single chunked call so # paginated-loop helpers downstream (``_walk_pages``) reuse one -# connection pool across the whole call. ``None`` when not inside a +# connection pool across the whole fan-out. ``None`` when not inside a # chunked call — paginated helpers fall back to their own short-lived # client in that case. _chunked_client: ContextVar[httpx.Client | None] = ContextVar( "_chunked_client", default=None ) +# Async sibling of ``_chunked_client``. Published by +# ``_publish_async_client`` during ``_fan_out_async`` so async +# paginated-loop helpers reuse one ``httpx.AsyncClient`` (and its +# connection pool) across every concurrent sub-request of a single +# chunked call. +_chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar( + "_chunked_async_client", default=None +) + @contextmanager def _publish_client(client: httpx.Client) -> Iterator[None]: @@ -117,6 +300,20 @@ def _publish_client(client: httpx.Client) -> Iterator[None]: _chunked_client.reset(token) +@contextmanager +def _publish_async_client(client: httpx.AsyncClient) -> Iterator[None]: + """ + Make ``client`` visible to :func:`get_active_async_client` for the + duration of the ``with`` block. Async sibling of + :func:`_publish_client`. + """ + token = _chunked_async_client.set(client) + try: + yield + finally: + _chunked_async_client.reset(token) + + def get_active_client() -> httpx.Client | None: """ Return the chunker's currently-published sync client, or ``None``. @@ -134,6 +331,16 @@ def get_active_client() -> httpx.Client | None: return _chunked_client.get() +def get_active_async_client() -> httpx.AsyncClient | None: + """ + Return the chunker's currently-published async client, or ``None``. + + Async sibling of :func:`get_active_client`. Used by async + paginated-loop helpers to reuse the per-call AsyncClient pool. + """ + return _chunked_async_client.get() + + # Separators the two axis kinds use to join their atoms back into # URL text. List axes comma-join values (``site=USGS-A,USGS-B``); the # filter axis OR-joins clauses (``filter=a='1' OR a='2'``). @@ -141,6 +348,9 @@ def get_active_client() -> httpx.Client | None: _OR_SEP = " OR " _FetchOnce = Callable[[dict[str, Any]], tuple[pd.DataFrame, httpx.Response]] +_FetchOnceAsync = Callable[ + [dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]] +] class _RetryableTransportError(RuntimeError): @@ -767,7 +977,9 @@ def iter_sub_args(self) -> Iterator[dict[str, Any]]: sub_args[axis.arg_key] = axis.render(chunk) yield sub_args - def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response]: + def execute( + self, fetch_once: _FetchOnce, retry_policy: RetryPolicy = _NO_RETRY + ) -> tuple[pd.DataFrame, httpx.Response]: """ Run the plan and return the combined ``(frame, response)``. @@ -779,6 +991,9 @@ def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response] fetch_once : Callable Function that issues a single sub-request, given the substituted args dict, and returns ``(frame, response)``. + retry_policy : RetryPolicy, optional + Per-sub-request retry-with-backoff policy. Defaults to + :data:`_NO_RETRY`; the decorator passes ``RetryPolicy.from_env()``. Returns ------- @@ -796,7 +1011,7 @@ def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response] :class:`ServiceInterrupted` for 5xx). The resumable handle is on ``exc.call``. """ - return ChunkedCall(self, fetch_once).resume() + return ChunkedCall(self, fetch_once, retry_policy).resume() def _classify_chunk_error( @@ -850,6 +1065,93 @@ def _classify_chunk_error( return None +def _retryable(exc: BaseException) -> tuple[bool, float | None]: + """ + Decide whether ``exc`` is a transient worth an automatic retry. + + Narrower than :func:`_classify_chunk_error`: it retries rate limits + (429), service errors (5xx), and genuine transport transients + (:class:`httpx.TransportError` — ``ConnectError``, ``ReadTimeout``, …) + but NOT :class:`httpx.InvalidURL` (a too-long server cursor URL won't + fix on retry, though it stays *resumable*). Walks the ``__cause__`` + chain because ``_walk_pages`` re-wraps mid-pagination failures as + ``RuntimeError``. + + Returns + ------- + tuple[bool, float or None] + ``(retryable, retry_after)`` — the server ``Retry-After`` hint + (seconds) when the transient carried one, else ``None``. + """ + cur: BaseException | None = exc + while cur is not None: + if isinstance(cur, (RateLimited, ServiceUnavailable)): + return True, cur.retry_after + if isinstance(cur, httpx.TransportError): + return True, None + cur = cur.__cause__ + return False, None + + +# Sleep hooks, indirected through module globals so tests can +# ``monkeypatch.setattr`` them to no-ops instead of waiting for real +# backoff. Production uses the stdlib calls. +_SLEEP = time.sleep +_ASLEEP = asyncio.sleep + + +def _note_retry(attempt: int, wait: float) -> None: + """Surface an imminent retry on the active progress reporter, if any.""" + reporter = _progress.current() + if reporter is not None: + reporter.note_retry(attempt=attempt, wait=wait) + + +def _retry_sync( + fn: Callable[[], tuple[pd.DataFrame, httpx.Response]], + policy: RetryPolicy, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Call ``fn`` with bounded retry-with-backoff on transient failures. + + On a non-retryable error, or once ``policy`` is exhausted (or the + server's ``Retry-After`` is too long to absorb inline), the last + exception propagates unchanged so the caller's existing handling wraps + it as a resumable :class:`ChunkInterrupted`. + """ + attempt = 0 + while True: + try: + return fn() + except Exception as exc: # noqa: BLE001 — re-raised unless retryable + retryable, retry_after = _retryable(exc) + attempt += 1 + if not retryable or not policy.should_retry(attempt, retry_after): + raise + delay = policy.backoff(attempt, retry_after) + _note_retry(attempt, delay) + _SLEEP(delay) + + +async def _retry_async( + afn: Callable[[], Awaitable[tuple[pd.DataFrame, httpx.Response]]], + policy: RetryPolicy, +) -> tuple[pd.DataFrame, httpx.Response]: + """Async sibling of :func:`_retry_sync` (awaits :func:`asyncio.sleep`).""" + attempt = 0 + while True: + try: + return await afn() + except Exception as exc: # noqa: BLE001 — re-raised unless retryable + retryable, retry_after = _retryable(exc) + attempt += 1 + if not retryable or not policy.should_retry(attempt, retry_after): + raise + delay = policy.backoff(attempt, retry_after) + _note_retry(attempt, delay) + await _ASLEEP(delay) + + def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: """ Concatenate per-chunk frames, dropping empties and deduping by ``id``. @@ -989,9 +1291,11 @@ class ChunkedCall: :meth:`resume` is idempotent: it iterates :meth:`ChunkPlan.iter_sub_args` (deterministic order) and skips any index whose result is already in ``self._chunks``. The - completion set is a ``dict[int, (df, response)]`` keyed by - sub-args index; a subsequent ``resume`` only re-issues - sub-requests whose index isn't already present. + completion set is a sparse ``dict[int, (df, response)]`` so the + parallel path can record scattered completions (e.g. indices + [0, 2, 5] after siblings [1, 3, 4] failed) and a subsequent + ``resume`` only re-issues the missing indices — via the serial + sync ``fetch_once`` path. Parameters ---------- @@ -1015,13 +1319,53 @@ class ChunkedCall: when nothing has completed yet (live; recomputed per access). """ - def __init__(self, plan: ChunkPlan, fetch_once: _FetchOnce) -> None: + def __init__( + self, + plan: ChunkPlan, + fetch_once: _FetchOnce, + retry_policy: RetryPolicy = _NO_RETRY, + ) -> None: self.plan = plan self.fetch_once = fetch_once - # Completed (frame, response) pairs keyed by sub-args index; - # ``resume()`` skips indices already present. + self.retry_policy = retry_policy + # Completed (frame, response) pairs keyed by sub-args index. + # Sparse so the parallel fan-out path can record scattered + # completions (e.g. indices [0, 2, 5] when 1/3/4 failed) and a + # subsequent ``resume()`` only re-issues the missing indices. + # On the serial path this fills contiguously from 0. self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} + def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: + """Record a completed sub-request's ``(frame, response)`` pair + under its sub-args index. Used by both the serial loop in + :meth:`resume` and the parallel fan-out in + :func:`_fan_out_async` so the completion set stays + encapsulated.""" + self._chunks[index] = pair + + def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: + """Build the matching :class:`ChunkInterrupted` carrying this + call when ``exc`` is a recognized transient transport failure; + return ``None`` for unrecognized failures so the caller can + re-raise. Encapsulates the + ``classify → instantiate-with-call-state`` recipe so + :class:`ChunkedCall`'s private fields stay private.""" + classification = _classify_chunk_error(exc) + if classification is None: + return None + interrupted_class, retry_after = classification + return interrupted_class( + completed_chunks=len(self._chunks), + total_chunks=self.plan.total, + call=self, + retry_after=retry_after, + cause=exc, + ) + + @property + def completed_chunks(self) -> int: + return len(self._chunks) + def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]: return [self._chunks[i] for i in sorted(self._chunks)] @@ -1078,7 +1422,9 @@ def resume(self) -> tuple[pd.DataFrame, httpx.Response]: Idempotent: only sub-requests whose index isn't already in ``self._chunks`` are re-issued. Sub-args order matches - :meth:`ChunkPlan.iter_sub_args` and is deterministic. + :meth:`ChunkPlan.iter_sub_args` and is deterministic, so a + parallel-mode partial completion (sparse indices) resumes + correctly via the sync path. Returns ------- @@ -1132,24 +1478,148 @@ def _issue(self, index: int, sub_args: dict[str, Any]) -> None: three feed :func:`_classify_chunk_error`. """ try: - self._chunks[index] = self.fetch_once(sub_args) + chunk = _retry_sync(lambda: self.fetch_once(sub_args), self.retry_policy) except (RuntimeError, httpx.HTTPError, httpx.InvalidURL) as exc: - classification = _classify_chunk_error(exc) - if classification is None: + interrupted = self.wrap_failure(exc) + if interrupted is None: raise - interrupted_class, retry_after = classification - raise interrupted_class( - completed_chunks=len(self._chunks), - total_chunks=self.plan.total, - call=self, - retry_after=retry_after, - cause=exc, - ) from exc + raise interrupted from exc + self.record(index, chunk) + + +async def _fan_out_async( + plan: ChunkPlan, + fetch_once: _FetchOnce, + fetch_async: _FetchOnceAsync, + *, + max_concurrent: int | None, + retry_policy: RetryPolicy = _NO_RETRY, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Execute ``plan`` concurrently under one shared + :class:`httpx.AsyncClient`. + + The fan-out preserves the same resumability contract the serial + :class:`ChunkedCall` path provides: + + * **Resumable interruptions.** Sub-requests run under + ``asyncio.gather`` with ``return_exceptions=True`` so completed + sub-requests survive a sibling's transient failure. On a + recognized transient (:class:`RateLimited`, + :class:`ServiceUnavailable`) a :class:`ChunkInterrupted` + subclass is raised with ``.call`` set to a + :class:`ChunkedCall` carrying the sparse completed sub-args; + ``exc.call.resume()`` re-issues only the unfinished ones via + the sync ``fetch_once`` path. + + In-flight sub-requests are capped by an + :class:`asyncio.Semaphore`; ``max_concurrent=None`` ("unbounded") + uses ``sys.maxsize`` so every call site can take the same + ``async with semaphore`` path. The shared client is published on + :data:`_chunked_async_client` so async paginated-loop helpers + reuse its connection pool. + + Parameters + ---------- + plan : ChunkPlan + Pre-built plan whose sub-args sequence drives the fan-out. + fetch_once : Callable + Sync per-sub-request fetcher. Used to build the resumable + :class:`ChunkedCall` returned via ``ChunkInterrupted.call``; + never invoked by this function directly. + fetch_async : Callable + Async per-sub-request fetcher returning ``(df, response)``. + max_concurrent : int or None + Maximum in-flight sub-requests. ``None`` disables the cap. + + Returns + ------- + df : pandas.DataFrame + Combined data from every sub-request. + response : httpx.Response + Aggregated response (canonical URL, last sub-request's + headers, cumulative elapsed time). + + Raises + ------ + ChunkInterrupted + On a transient sub-request failure. ``.call`` is a + :class:`ChunkedCall` holding the sparse completed sub-requests; + ``.call.resume()`` re-issues the unfinished ones serially. + """ + sub_args_list = list(plan.iter_sub_args()) + + # ``httpx.Limits()`` defaults to ``max_connections=100`` — at + # higher concurrency the pool would silently bottleneck the + # fan-out behind the connection cap. Match it to the semaphore, + # or ``None`` for truly unbounded. + limits = httpx.Limits( + max_connections=max_concurrent, max_keepalive_connections=max_concurrent + ) + # ``sys.maxsize`` stands in for "unbounded": ``asyncio.Semaphore`` + # only decrements a counter, never preallocates slots. + semaphore = asyncio.Semaphore(max_concurrent or sys.maxsize) + call = ChunkedCall(plan, fetch_once, retry_policy) + + async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: + with _publish_async_client(client): + reporter = _progress.current() + if reporter is not None: + reporter.set_chunks(plan.total) + + async def track( + offset: int, args: dict[str, Any] + ) -> tuple[pd.DataFrame, httpx.Response]: + """One sub-request (with retry) + record + progress tick. + + The retry loop runs *inside* the semaphore, so a chunk + backing off holds its slot — effective concurrency shrinks + under throttling instead of re-bursting against it. + """ + async with semaphore: + result = await _retry_async(lambda: fetch_async(args), retry_policy) + call.record(offset, result) + if reporter is not None: + reporter.start_chunk(call.completed_chunks) + return result + + # Dispatch every sub-request concurrently. ``return_exceptions`` + # keeps completed pairs after a sibling fails, so partial state + # stays recoverable via ``ChunkedCall.resume()``. Failure + # precedence: + # 1. Cancellation / interrupt signals (CancelledError, + # KeyboardInterrupt, SystemExit — non-Exception) propagate + # unmodified; wrapping them as a transient would swallow the + # user's stop signal. + # 2. Recognized transients wrap as ChunkInterrupted so the user + # gets a resumable handle even when a non-transient failure + # landed earlier in submission order. + # 3. Otherwise re-raise the first failure, preserving its type. + results = await asyncio.gather( + *(track(i, args) for i, args in enumerate(sub_args_list)), + return_exceptions=True, + ) + failures = [r for r in results if isinstance(r, BaseException)] + for exc in failures: + if not isinstance(exc, Exception): + raise exc + for exc in failures: + if (interrupted := call.wrap_failure(exc)) is not None: + raise interrupted from exc + if failures: + raise failures[0] + + ordered = call._ordered_chunks() + return ( + _combine_chunk_frames([df for df, _ in ordered]), + _combine_chunk_responses([resp for _, resp in ordered], plan.canonical_url), + ) def multi_value_chunked( *, build_request: Callable[..., httpx.Request], + fetch_async: _FetchOnceAsync | None = None, url_limit: int | None = None, ) -> Callable[[_FetchOnce], _FetchOnce]: """ @@ -1161,12 +1631,24 @@ def multi_value_chunked( single-step plan, so the decorated function has one code path either way. + When ``API_USGS_CONCURRENT`` resolves to a parallelism greater than + 1 (the default), the decorator routes execution through + :func:`_fan_out_async` over the provided ``fetch_async``. The + wrapper falls back to the synchronous :class:`ChunkedCall` path + (with a ``UserWarning``) when ``fetch_async`` wasn't wired or + when an asyncio event loop is already running (Jupyter / IPython / + async apps where ``asyncio.run`` would raise ``RuntimeError``). + Parameters ---------- build_request : Callable[..., httpx.Request] Factory that turns a kwargs dict into a sized httpx request, e.g. ``_construct_api_requests``. Called during planning to measure each candidate plan. + fetch_async : Callable, optional + Async sibling of the decorated sync fetcher. Used when + ``API_USGS_CONCURRENT`` resolves to >1; if omitted, the + wrapper warns and stays on the serial path. url_limit : int, optional Byte budget for the request (URL + body). When ``None`` (default), the module-level ``_WATERDATA_URL_BYTE_LIMIT`` is @@ -1202,8 +1684,78 @@ def wrapper( ) -> tuple[pd.DataFrame, httpx.Response]: limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit plan = ChunkPlan(args, build_request, limit) - return plan.execute(fetch_once) + concurrency = _read_concurrency_env() + retry_policy = RetryPolicy.from_env() + + # Trivial plans and explicit opt-outs stay on the sync + # path; ``_execute_in_parallel`` owns the rest of the + # serial/parallel decision (async wiring, running loop). + if plan.total <= 1 or concurrency == 1: + return plan.execute(fetch_once, retry_policy) + return _execute_in_parallel( + plan, fetch_once, fetch_async, concurrency, retry_policy + ) return wrapper return decorator + + +def _execute_in_parallel( + plan: ChunkPlan, + fetch_once: _FetchOnce, + fetch_async: _FetchOnceAsync | None, + concurrency: int | None, + retry_policy: RetryPolicy = _NO_RETRY, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Run ``plan`` on the parallel async path, falling back to the + serial sync path when the runtime can't host an event loop. + + Falls back (with a one-time :class:`UserWarning`) when: + + * ``fetch_async`` wasn't wired into the decorator, or + * an asyncio event loop is already running (Jupyter / IPython + kernels, async apps — ``asyncio.run`` would raise). + + Otherwise opens a fresh event loop via :func:`asyncio.run` and + drives :func:`_fan_out_async`. + """ + if fetch_async is None: + warnings.warn( + f"{_CONCURRENCY_ENV} is set to {concurrency} but this " + f"call site has no async fetch sibling wired; falling " + f"back to the serial path. Either set " + f"{_CONCURRENCY_ENV}=1 to silence this warning or pass " + f"fetch_async= to @multi_value_chunked.", + UserWarning, + stacklevel=3, + ) + return plan.execute(fetch_once, retry_policy) + if _running_event_loop() is not None: + warnings.warn( + "Detected a running asyncio event loop; the parallel " + f"chunker path cannot run inside one. Falling back to " + f"the serial path. Set {_CONCURRENCY_ENV}=1 to silence " + f"this warning.", + UserWarning, + stacklevel=3, + ) + return plan.execute(fetch_once, retry_policy) + return asyncio.run( + _fan_out_async( + plan, + fetch_once, + fetch_async, + max_concurrent=concurrency, + retry_policy=retry_policy, + ) + ) + + +def _running_event_loop() -> asyncio.AbstractEventLoop | None: + """Return the active asyncio event loop, or ``None`` when none.""" + try: + return asyncio.get_running_loop() + except RuntimeError: + return None diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 66ed1723..f8475957 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -7,12 +7,14 @@ import os import re from collections.abc import ( + AsyncIterator, + Awaitable, Callable, Iterable, Iterator, Mapping, ) -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from datetime import datetime, timedelta from typing import Any, TypeVar, get_args from zoneinfo import ZoneInfo @@ -28,6 +30,7 @@ RateLimited, ServiceUnavailable, _safe_elapsed, + get_active_async_client, get_active_client, ) from dataretrieval.waterdata.types import ( @@ -837,6 +840,29 @@ def _client_for(client: httpx.Client | None) -> Iterator[httpx.Client]: yield new +@asynccontextmanager +async def _client_for_async( + client: httpx.AsyncClient | None, +) -> AsyncIterator[httpx.AsyncClient]: + """ + Yield a usable async client, picking the best available source. + Async sibling of :func:`_client_for`. + + Resolution order matches the sync version: explicit caller-owned + ``AsyncClient`` first, the chunker's shared async client next, a + fresh short-lived ``AsyncClient`` last. + """ + if client is not None: + yield client + return + shared = get_active_async_client() + if shared is not None: + yield shared + return + async with httpx.AsyncClient(**HTTPX_DEFAULTS) as new: + yield new + + def _aggregate_paginated_response( initial: httpx.Response, last: httpx.Response, @@ -998,14 +1024,86 @@ def _paginate( return pd.concat(dfs, ignore_index=True), final_response +async def _paginate_async( + initial_req: httpx.Request, + *, + parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], + follow_up: Callable[[_Cursor, httpx.AsyncClient], Awaitable[httpx.Response]], + client: httpx.AsyncClient | None = None, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Drive a paginated request to completion over an + :class:`httpx.AsyncClient`. Async sibling of :func:`_paginate`. + + Runs the same per-page loop but issues HTTP asynchronously so + multiple sub-requests of a chunked call can run concurrently from + :func:`_fan_out_async`. + """ + logger.debug("Requesting: %s", initial_req.url) + reporter = _progress.current() + async with _client_for_async(client) as sess: + resp = await sess.send(initial_req) + _raise_for_non_200(resp) + initial_response = resp + total_elapsed = _safe_elapsed(resp) + + try: + df, cursor = parse_response(resp) + except Exception as e: # noqa: BLE001 + # Mirror the sync path: initial-page parse failures + # (malformed JSON, missing ``features``, schema drift) + # get the same wrapped-message treatment as follow-up + # failures so callers see a consistent diagnostic + # regardless of which page broke. + logger.warning("Initial response parse failed.") + raise RuntimeError(_paginated_failure_message(0, e)) from e + dfs = [df] + if reporter is not None: + reporter.set_rate_remaining( + resp.headers.get(_QUOTA_HEADER), + limit=resp.headers.get("x-ratelimit-limit"), + ) + reporter.add_page(rows=len(df)) + while cursor is not None: + try: + resp = await follow_up(cursor, sess) + _raise_for_non_200(resp) + df, cursor = parse_response(resp) + dfs.append(df) + total_elapsed += _safe_elapsed(resp) + if reporter is not None: + reporter.set_rate_remaining( + resp.headers.get(_QUOTA_HEADER), + limit=resp.headers.get("x-ratelimit-limit"), + ) + reporter.add_page(rows=len(df)) + except Exception as e: # noqa: BLE001 + logger.warning( + "Request failed at cursor %r. Data download interrupted.", + cursor, + ) + raise RuntimeError(_paginated_failure_message(len(dfs), e)) from e + + # Aggregate headers / elapsed onto a COPY of the initial + # response so the user's caller never sees an in-place + # mutation of the response object they may have inspected + # mid-pagination via a hook or test fixture. + final_response = _aggregate_paginated_response( + initial_response, resp, total_elapsed + ) + return pd.concat(dfs, ignore_index=True), final_response + + def _ogc_parse_response( resp: httpx.Response, *, geopd: bool ) -> tuple[pd.DataFrame, str | None]: """Parse one OGC API page: extract the DataFrame and the next-page URL. - Coerces falsy cursors (empty href, etc.) to ``None`` so the - paginate loop's ``while cursor is not None`` terminates instead - of spinning on a meaningless value. + Shared between :func:`_walk_pages` and :func:`_walk_pages_async` + since the parse step is identical on either path. Coerces falsy + cursors (empty href, etc.) to ``None`` so the paginate loop's + ``while cursor is not None`` terminates instead of spinning on a + meaningless value. """ body = resp.json() return ( @@ -1069,6 +1167,31 @@ def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: ) +async def _walk_pages_async( + geopd: bool, + req: httpx.Request, + client: httpx.AsyncClient | None = None, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Iterate paginated OGC API responses asynchronously and aggregate + them into one DataFrame. Async sibling of :func:`_walk_pages`; + delegates to :func:`_paginate_async`. + """ + method = req.method + headers = req.headers + content = req.content if method == "POST" else None + + async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: + return await sess.request(method, cursor, headers=headers, content=content) + + return await _paginate_async( + req, + parse_response=functools.partial(_ogc_parse_response, geopd=geopd), + follow_up=follow_up, + client=client, + ) + + def _deal_with_empty( return_list: pd.DataFrame, properties: list[str] | None, service: str ) -> pd.DataFrame: @@ -1290,8 +1413,19 @@ def get_ogc_data( return return_list, BaseMetadata(response) +async def _fetch_once_async( + args: dict[str, Any], +) -> tuple[pd.DataFrame, httpx.Response]: + """Send one prepared-args OGC request asynchronously; return the + frame + response. Async sibling of :func:`_fetch_once` used by the + parallel chunker.""" + req = _construct_api_requests(**args) + return await _walk_pages_async(geopd=GEOPANDAS, req=req) + + @chunking.multi_value_chunked( build_request=_construct_api_requests, + fetch_async=_fetch_once_async, ) def _fetch_once( args: dict[str, Any], @@ -1302,7 +1436,10 @@ def _fetch_once( parameter and the cql-text filter as a chunkable axis, greedy-halves the biggest chunk across all axes until each sub-request URL fits, and iterates the cartesian product. With no chunkable inputs the - decorator passes args through unchanged. The return shape + decorator passes args through unchanged. When ``API_USGS_CONCURRENT`` + is >1 (the default), the decorator routes execution through + :func:`_fetch_once_async` so the sub-requests run concurrently under + one shared :class:`httpx.AsyncClient`. Either way the return shape is ``(frame, response)``. """ req = _construct_api_requests(**args) diff --git a/tests/conftest.py b/tests/conftest.py index afbdfec2..5eb46cb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,15 @@ """ Test scaffolding for the dataretrieval test suite. -Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and -unmatched real requests don't fail the suite (matches the historical -``requests-mock``-style permissiveness the test code was written -against, and keeps mocked-URL setup terse). +* Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and + unmatched real requests don't fail the suite (matches the historical + ``requests-mock``-style permissiveness the test code was written + against, and keeps mocked-URL setup terse). +* Pins ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` for every + test by default so the historical mocked suite stays on the + deterministic serial chunker path and a single transient surfaces + immediately (no backoff). Async-mode and retry tests opt in by + re-setting the env vars inside their body via ``monkeypatch.setenv``. """ from __future__ import annotations @@ -30,3 +35,21 @@ def non_mocked_hosts() -> list[str]: """No hosts are exempted from mocking; every HTTP call must hit a mock registered through the ``httpx_mock`` fixture.""" return [] + + +@pytest.fixture(autouse=True) +def _serial_chunker(monkeypatch): + """Default every test to the serial, no-retry chunker path. + + Production defaults ``API_USGS_CONCURRENT`` to 16 (parallel + fan-out) and ``API_USGS_RETRIES`` to 4, but the historical tests + assume sequential, deterministic sub-request ordering — and they + mock the sync ``_walk_pages`` rather than the async sibling, and + expect a single transient to surface immediately rather than be + retried. Pinning ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` + keeps the test surface focused on the planner / fetch contracts; + async-mode and retry tests opt in by overriding the env inside + their body. + """ + monkeypatch.setenv("API_USGS_CONCURRENT", "1") + monkeypatch.setenv("API_USGS_RETRIES", "0") diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 21b23757..ee129aaa 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -15,11 +15,13 @@ and then fail in production. """ +import asyncio import datetime import sys from unittest import mock from urllib.parse import quote_plus +import httpx import pandas as pd import pytest @@ -36,10 +38,14 @@ QuotaExhausted, RateLimited, RequestTooLarge, + RetryPolicy, ServiceInterrupted, ServiceUnavailable, _chunked_client, _extract_axes, + _retry_async, + _retry_sync, + _retryable, multi_value_chunked, ) from dataretrieval.waterdata.utils import _construct_api_requests @@ -1202,6 +1208,195 @@ def test_iter_sub_args_passthrough_yields_a_copy(): assert "new_key" not in plan.args +# --- async fan-out path ---------------------------------------------------- +# +# The conftest's ``_serial_chunker`` autouse pins ``API_USGS_CONCURRENT=1`` +# for the whole suite. Each test below overrides it so the wrapper takes +# the parallel branch. The decorator's ``fetch_async`` accepts any +# coroutine returning ``(df, response)`` — no real ``httpx.AsyncClient`` +# round-trip occurs, even though :func:`_fan_out_async` opens one for +# pool management. + + +def _async_chunked_fetch(monkeypatch, fetch_async, *, max_concurrent=16): + """Decorate a deterministic chunkable fetch with the parallel + path forced on via ``API_USGS_CONCURRENT``.""" + monkeypatch.setenv("API_USGS_CONCURRENT", str(max_concurrent)) + + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): + # Sync sibling — invoked on resume() after a parallel failure + # and never during the happy parallel path. + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + return fetch + + +def _atom_id(args): + """Build a deterministic id for a sub-args dict — used as the dedup key.""" + return ",".join(args["sites"]) if isinstance(args["sites"], list) else args["sites"] + + +def _ok_response(remaining=None): + headers = {} if remaining is None else {_QUOTA_HEADER: str(remaining)} + return mock.Mock(elapsed=datetime.timedelta(seconds=0.1), headers=headers) + + +def test_async_fan_out_emits_one_call_per_sub_request(monkeypatch): + """Parallel mode hits every sub-args once — same coverage as the + sync ``ChunkedCall`` path, just dispatched concurrently.""" + seen_args = [] + + async def fetch_async(args): + seen_args.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + # Planner halves the 4-site list, so 2 sub-args → 2 async calls. + assert len(seen_args) > 1 + # Every sub-args atom is union-recovered. + assert sorted({s for tup in seen_args for s in tup}) == sorted( + ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + ) + # Frames concat to one row per sub-request id, in deterministic order. + assert len(df) == len(seen_args) + + +def test_async_fan_out_failure_yields_resumable_call(monkeypatch): + """A transient 5xx mid-fan-out raises ``ServiceInterrupted`` whose + ``.call`` is a ``ChunkedCall`` holding the completed sub-requests + in a sparse index map. ``exc.call.resume()`` re-issues only the + unfinished sub-requests, via the sync ``fetch_once`` path.""" + call_count = {"async": 0, "sync": 0} + + async def fetch_async(args): + call_count["async"] += 1 + # First sub-request succeeds; siblings fail. + if call_count["async"] == 1: + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + raise ServiceUnavailable("503: simulated") + + monkeypatch.setenv("API_USGS_CONCURRENT", "16") + + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): + call_count["sync"] += 1 + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + + with pytest.raises(ServiceInterrupted) as exc_info: + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + interrupted = exc_info.value + assert interrupted.call is not None, "parallel-mode interruption must be resumable" + # First sub-request completed; the rest still owe. + assert interrupted.completed_chunks == 1 + assert interrupted.total_chunks > 1 + + # Resume on the sync path picks up only the missing sub-requests. + sync_before = call_count["sync"] + df, _ = interrupted.call.resume() + sync_calls_on_resume = call_count["sync"] - sync_before + assert sync_calls_on_resume == interrupted.total_chunks - 1 + # Final frame unions every sub-args. + assert len(df) == interrupted.total_chunks + + +@pytest.mark.parametrize( + "fallback_trigger,warning_match", + [ + pytest.param("running_loop", "running asyncio event loop", id="running-loop"), + pytest.param("no_fetch_async", "no async fetch sibling", id="missing-async"), + ], +) +def test_async_falls_back_to_serial_with_warning( + monkeypatch, fallback_trigger, warning_match +): + """The parallel path falls back to the serial ``ChunkedCall`` + (with a ``UserWarning``) in two situations: + + * a running asyncio event loop (Jupyter / IPython kernels, async + apps) — ``asyncio.run`` would otherwise raise ``RuntimeError``; + * the decorator wasn't wired with a ``fetch_async=`` sibling — + ``API_USGS_CONCURRENT`` would otherwise be a silent no-op. + """ + sync_calls = [] + monkeypatch.setenv("API_USGS_CONCURRENT", "16") + + if fallback_trigger == "running_loop": + + async def fetch_async(args): + raise AssertionError("parallel path must not run inside an active loop") + + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): + sync_calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + async def driver(): + with pytest.warns(UserWarning, match=warning_match): + return fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + df, _ = asyncio.run(driver()) + else: + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + sync_calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + with pytest.warns(UserWarning, match=warning_match): + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + assert len(sync_calls) > 1 + assert len(df) == len(sync_calls) + + +def test_async_fan_out_cancellation_wins_over_transient_sibling(monkeypatch): + """``asyncio.CancelledError`` raised by any sub-request must + propagate unmodified, even when a sibling raises a recognized + transient (which would otherwise wrap as a resumable + :class:`ChunkInterrupted`). Cancellation is asyncio's abort + signal — letting a transient-classification path consume it + would silently swallow the user's stop request. + + fetch_async has no ``await`` inside its body, so gather schedules + the tasks in submission order and each runs synchronously to its + raise — making ``call_count`` deterministic for this test: + 1 = probe, 2 = first fan-out task (transient), 3 = second + fan-out task (cancellation). + """ + call_count = {"async": 0} + + async def fetch_async(args): + call_count["async"] += 1 + if call_count["async"] == 1: + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + if call_count["async"] == 2: + raise ServiceUnavailable("503: transient sibling") + if call_count["async"] == 3: + raise asyncio.CancelledError("user cancel") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + # 8 × 20-byte sites force the planner to >=3 sub-args under + # url_limit=240, so the fan-out gather sees at least the + # transient (call 2) AND the cancellation (call 3). + sites = [f"S{i}" * 10 for i in range(1, 9)] + + with pytest.raises(asyncio.CancelledError): + fetch({"sites": sites}) + + def test_combine_chunk_responses_does_not_mutate_input_urls(): """Regression for the _set_response_url aliasing bug. @@ -1230,3 +1425,228 @@ def test_combine_chunk_responses_does_not_mutate_input_urls(): assert str(r2.url) == "https://example.com/chunk1" assert str(req1.url) == "https://example.com/chunk0" assert str(req2.url) == "https://example.com/chunk1" + + +# --------------------------------------------------------------------------- +# Retry-with-backoff: RetryPolicy + _retryable + drivers + decorator wiring. +# Conftest pins API_USGS_RETRIES=0, so these tests opt in explicitly and +# patch chunking._SLEEP / chunking._ASLEEP to no-ops (no real backoff). +# --------------------------------------------------------------------------- + + +def _wrap_cause(transport_exc): + """Wrap ``transport_exc`` the way ``_walk_pages`` does — a + ``RuntimeError`` with the typed transport error on ``__cause__`` — so + chain-walking is exercised realistically.""" + try: + raise RuntimeError("Paginated request failed") from transport_exc + except RuntimeError as wrapped: + return wrapped + + +# -- RetryPolicy (pure value object) ---------------------------------------- + + +def test_retry_policy_backoff_honors_retry_after(): + policy = RetryPolicy() + # A server Retry-After overrides the computed backoff verbatim. + assert policy.backoff(attempt=1, retry_after=7.5) == 7.5 + assert policy.backoff(attempt=4, retry_after=2.0) == 2.0 + + +def test_retry_policy_backoff_full_jitter_within_ceiling(): + policy = RetryPolicy(base_backoff=2.0, max_backoff=30.0) + for attempt, ceiling in [(1, 2.0), (2, 4.0), (3, 8.0), (5, 30.0)]: + samples = [policy.backoff(attempt, None) for _ in range(200)] + assert all(0.0 <= s <= ceiling for s in samples) + # Full jitter genuinely varies and reaches below the ceiling. + assert min(samples) < ceiling + + +def test_retry_policy_should_retry_exhaustion(): + policy = RetryPolicy(max_retries=2) + assert policy.should_retry(attempt=1, retry_after=None) + assert policy.should_retry(attempt=2, retry_after=None) + assert not policy.should_retry(attempt=3, retry_after=None) + + +def test_retry_policy_long_retry_after_escalates(): + policy = RetryPolicy(max_retries=5, retry_after_cap=60.0) + assert policy.should_retry(attempt=1, retry_after=30.0) # absorbed inline + assert not policy.should_retry(attempt=1, retry_after=120.0) # escalates + + +def test_retry_policy_from_env(monkeypatch): + monkeypatch.setenv("API_USGS_RETRIES", "2") + assert RetryPolicy.from_env().max_retries == 2 + monkeypatch.setenv("API_USGS_RETRIES", "0") + assert RetryPolicy.from_env().max_retries == 0 + monkeypatch.delenv("API_USGS_RETRIES", raising=False) + assert RetryPolicy.from_env().max_retries == _chunking._RETRIES_DEFAULT + monkeypatch.setenv("API_USGS_RETRIES", "-1") + with pytest.raises(ValueError): + RetryPolicy.from_env() + monkeypatch.setenv("API_USGS_RETRIES", "lots") + with pytest.raises(ValueError): + RetryPolicy.from_env() + + +# -- _retryable taxonomy ---------------------------------------------------- + + +def test_retryable_taxonomy(): + assert _retryable(RateLimited("429", retry_after=5.0)) == (True, 5.0) + assert _retryable(ServiceUnavailable("503")) == (True, None) + assert _retryable(httpx.ReadTimeout("slow")) == (True, None) + assert _retryable(httpx.ConnectError("down")) == (True, None) + # InvalidURL is resumable but NOT retryable (a too-long cursor won't fix). + assert _retryable(httpx.InvalidURL("too long")) == (False, None) + # Plain non-transient (e.g. a 4xx programmer error wrapped as RuntimeError). + assert _retryable(RuntimeError("400")) == (False, None) + + +def test_retryable_walks_cause_chain(): + assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (True, 3.0) + + +# -- sync driver ------------------------------------------------------------ + + +def test_retry_sync_transient_then_success(monkeypatch): + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + if calls["n"] <= 2: + raise RateLimited("429") + return "ok" + + assert _retry_sync(fn, RetryPolicy(max_retries=3, base_backoff=0.0)) == "ok" + assert calls["n"] == 3 # two failures + one success + + +def test_retry_sync_exhausted_reraises(monkeypatch): + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + raise ServiceUnavailable("503") + + with pytest.raises(ServiceUnavailable): + _retry_sync(fn, RetryPolicy(max_retries=2, base_backoff=0.0)) + assert calls["n"] == 3 # first attempt + 2 retries + + +def test_retry_sync_non_retryable_not_retried(monkeypatch): + slept: list[float] = [] + monkeypatch.setattr(_chunking, "_SLEEP", slept.append) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + raise RuntimeError("400: bad request") + + with pytest.raises(RuntimeError): + _retry_sync(fn, RetryPolicy(max_retries=3)) + assert calls["n"] == 1 and slept == [] + + +def test_retry_sync_long_retry_after_escalates(monkeypatch): + slept: list[float] = [] + monkeypatch.setattr(_chunking, "_SLEEP", slept.append) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + raise RateLimited("429", retry_after=999.0) + + with pytest.raises(RateLimited): + _retry_sync(fn, RetryPolicy(max_retries=5, retry_after_cap=60.0)) + assert calls["n"] == 1 and slept == [] # no inline wait for a long window + + +# -- async driver ----------------------------------------------------------- + + +def test_retry_async_transient_then_success(monkeypatch): + async def _noslept(_d): + return None + + monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + calls = {"n": 0} + + async def afn(): + calls["n"] += 1 + if calls["n"] == 1: + raise httpx.ReadTimeout("slow") + return "ok" + + out = asyncio.run(_retry_async(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) + assert out == "ok" and calls["n"] == 2 + + +# -- end-to-end through the decorator -------------------------------------- + + +def test_chunker_retries_transient_then_completes(monkeypatch): + """A transient on one sub-request is retried transparently; the + decorated call completes with no ChunkInterrupted.""" + monkeypatch.setenv("API_USGS_RETRIES", "3") + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + state = {"failed": False} + + def fetch(args): + # Fail the first sub-request once, then succeed everywhere. + if not state["failed"]: + state["failed"] = True + raise RateLimited("429: Too many requests made.") + return pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + df, _ = decorated({"sites": sites}) + assert sorted(df["sites"]) == sorted(sites) # all recovered despite the 429 + + +def test_chunker_exhausted_retries_still_resumable(monkeypatch): + """When retries are exhausted the failure still surfaces as a + resumable ChunkInterrupted — retries don't swallow the escape hatch.""" + monkeypatch.setenv("API_USGS_RETRIES", "2") + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + attempts = {"n": 0} + + def fetch(args): + sites = list(args["sites"]) + if "S1" * 10 in sites: + attempts["n"] += 1 + raise ServiceUnavailable("503: service unavailable") + return pd.DataFrame({"sites": sites}), _quota_response(500) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(ServiceInterrupted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert excinfo.value.call is not None + assert attempts["n"] == 3 # first attempt + 2 retries before giving up + + +def test_async_fan_out_retries_transient_then_completes(monkeypatch): + """The parallel path retries a transient sub-request and completes.""" + monkeypatch.setenv("API_USGS_RETRIES", "3") + + async def _noslept(_d): + return None + + monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + state = {"failed": False} + + async def fetch_async(args): + if not state["failed"]: + state["failed"] = True + raise ServiceUnavailable("503: transient") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert len(df) > 1 # every sub-args atom recovered after the retry diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index faa61630..30be56a2 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -65,6 +65,26 @@ def test_page_count_is_pluralized(): assert "2 pages" in stream.getvalue() +def test_note_retry_renders_then_clears_on_next_page(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_chunks(3) + reporter.start_chunk(1) + reporter.note_retry(attempt=2, wait=8.0) + assert "retrying (attempt 2, waiting 8s)" in stream.getvalue() + # The next page redraws without the note (last frame is after the + # final carriage return). + reporter.add_page(rows=5) + assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] + + +def test_note_retry_is_noop_when_disabled(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=False) + reporter.note_retry(attempt=1, wait=1.0) + assert stream.getvalue() == "" + + def test_chunk_segment_only_shown_when_multiple_chunks(): single = io.StringIO() reporter = ProgressReporter(stream=single, enabled=True) @@ -363,3 +383,124 @@ def test_broken_progress_stream_does_not_truncate_pagination(): df, _ = _walk_pages(geopd=False, req=req, client=client) assert len(df) == 2 # both pages returned despite the broken progress stream + + +# -- async path integration ---------------------------------------------------- + + +def test_paginate_async_reports_pages_through_active_reporter(monkeypatch): + """The async paginate path must drive the same progress reporter the + sync path does. Pages and rate-limit updates from each completed + page should land via the active ``ProgressReporter``, exactly as + they would on ``_walk_pages``.""" + import asyncio + + from dataretrieval.waterdata.utils import _paginate_async + + resp1 = _resp( + [{"id": "1", "properties": {"v": "a"}}], + next_url="https://example.com/p2", + rate_remaining="4999", + ) + resp2 = _resp([{"id": "2", "properties": {"v": "b"}}], rate_remaining="4998") + + async def parse_response(resp): + body = resp.json() + nxt = next( + (link["href"] for link in body["links"] if link["rel"] == "next"), None + ) + return mock.MagicMock(empty=False, __len__=lambda self: 1), nxt + + # _paginate_async expects parse_response to be sync, like the sync path. + def parse_sync(resp): + body = resp.json() + nxt = next( + (link["href"] for link in body["links"] if link["rel"] == "next"), None + ) + import pandas as pd + + return pd.DataFrame(body["features"]), nxt + + async def follow_up(cursor, sess): + return resp2 + + client = mock.AsyncMock(spec=httpx.AsyncClient) + client.send.return_value = resp1 + + req = mock.MagicMock(spec=httpx.Request) + req.method = "GET" + req.headers = {} + req.url = "https://example.com/p1" + + stream = io.StringIO() + + async def run(): + with progress_context(service="continuous", stream=stream, enabled=True): + df, _ = await _paginate_async( + req, + parse_response=parse_sync, + follow_up=follow_up, + client=client, + ) + return df + + df = asyncio.run(run()) + assert len(df) == 2 + out = stream.getvalue() + assert "Retrieving: continuous ·" in out + assert "2 pages" in out + assert "4,998 requests remaining" in out + assert out.endswith("\n") + + +def test_fan_out_async_sets_chunks_on_active_reporter(monkeypatch): + """``_fan_out_async`` records ``plan.total`` on the active reporter + so the progress line knows how many sub-requests are in flight. + It deliberately does NOT call ``start_chunk`` (which would be + misleading under parallel fan-out — chunks fire concurrently).""" + import asyncio + + import pandas as pd + + from dataretrieval.waterdata.chunking import ChunkPlan, _fan_out_async + + # Fake build_request whose URL length scales with the sites list, + # mirroring the planner's _request_bytes contract. _FakeReq has the + # same shape as httpx.Request for sizing purposes. + class _FakeReq: + __slots__ = ("url", "content") + + def __init__(self, url): + self.url = url + self.content = b"" + + def build(*, sites): + return _FakeReq("x" * (200 + len(",".join(sites)))) + + sites = ["S" * 10 + str(i) for i in range(4)] + plan = ChunkPlan({"sites": sites}, build, url_limit=240) + assert plan.total > 1, "test setup error: plan must fan out" + + async def fetch_async(args): + return pd.DataFrame({"id": [",".join(args["sites"])]}), mock.Mock( + elapsed=__import__("datetime").timedelta(seconds=0.01), + headers={"x-ratelimit-remaining": "999"}, + ) + + def fetch_once(args): # noqa: ARG001 — never invoked on the happy parallel path + raise AssertionError("sync fetch must not run in this test") + + stream = io.StringIO() + + async def run(): + with progress_context(service="daily", stream=stream, enabled=True) as rep: + await _fan_out_async(plan, fetch_once, fetch_async, max_concurrent=4) + return rep.total_chunks, rep.current_chunk + + total_recorded, current_recorded = asyncio.run(run()) + assert total_recorded == plan.total + # Each sub-request that completes bumps current_chunk via + # start_chunk(len(completed)), so by the time the gather finishes + # current_chunk reflects the total number of successful chunks — + # plan.total in the all-success case. + assert current_recorded == plan.total diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index bb5ece10..413f39c8 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -221,6 +221,37 @@ def test_walk_pages_wraps_initial_page_parse_error(): assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) +def test_walk_pages_async_wraps_initial_page_parse_error(): + """Async sibling of the above. ``_paginate_async`` must wrap an + initial-page parse failure with the same ``RuntimeError`` shape so + callers get a consistent diagnostic across sync and async paths.""" + import asyncio + + from dataretrieval.waterdata.utils import _walk_pages_async + + resp = mock.MagicMock() + resp.status_code = 200 + resp.url = "https://example.com/page1" + resp.json.side_effect = json.JSONDecodeError("Expecting value", "...", 0) + + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) + mock_client.send.return_value = resp + + mock_req = mock.MagicMock(spec=httpx.Request) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.content = b"" + mock_req.url = "https://example.com/page1" + + async def run(): + await _walk_pages_async(geopd=False, req=mock_req, client=mock_client) + + with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: + asyncio.run(run()) + + assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) + + def test_get_resp_data_handles_missing_features_key(): """Regression: a 200 with ``numberReturned > 0`` but no ``features`` key (real schema-drift shape) used to crash From 1d0acb7912cf7ffdeb2e5be884d1c619b0b2e9fe Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Tue, 26 May 2026 21:14:41 -0500 Subject: [PATCH 02/16] refactor(waterdata): maintainer polish for chunk retries --- dataretrieval/waterdata/_progress.py | 20 ++- dataretrieval/waterdata/chunking.py | 218 ++++++++++++++++----------- tests/waterdata_chunking_test.py | 68 ++++++++- tests/waterdata_progress_test.py | 20 +++ 4 files changed, 235 insertions(+), 91 deletions(-) diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py index 7104f3af..e529d6d3 100644 --- a/dataretrieval/waterdata/_progress.py +++ b/dataretrieval/waterdata/_progress.py @@ -157,10 +157,17 @@ def add_page(self, rows: int = 0) -> None: def note_retry(self, *, attempt: int, wait: float) -> None: """Show that a sub-request is backing off before retry ``attempt``. - Cleared by the next :meth:`add_page` / :meth:`start_chunk` so the - line returns to normal progress once the retry succeeds. + Cleared by the next :meth:`add_page` / :meth:`start_chunk` (or by + :meth:`close`) so the line returns to normal once the retry resolves. """ - self.retry_note = f"retrying (attempt {attempt}, waiting {wait:.0f}s)" + # Keep sub-second waits explicit (avoid misleading ``0s``) while + # rendering whole-second waits without unnecessary ``.0`` noise. + wait_1dp = round(wait, 1) + if wait_1dp < 1 or not wait_1dp.is_integer(): + secs = f"{wait_1dp:.1f}s" + else: + secs = f"{wait_1dp:.0f}s" + self.retry_note = f"retrying (attempt {attempt}, waiting {secs})" self._render() def set_rate_remaining( @@ -225,6 +232,13 @@ def close(self) -> None: """ if self._closed: return + # A retry note set during the final backoff would otherwise freeze as + # the persisted last line of a call that has since completed or given + # up; clear it and redraw (while still un-closed, so ``_render`` runs) + # so the final state isn't a stale "retrying". + if self.enabled and self._rendered and self.retry_note is not None: + self.retry_note = None + self._render() self._closed = True if not (self.enabled and self._rendered): return diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 1e3b429d..2d01614c 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -62,7 +62,7 @@ from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar from urllib.parse import quote_plus import httpx @@ -166,13 +166,14 @@ def _read_concurrency_env() -> int | None: return value -# Retry-with-backoff for transient sub-request failures (429 / 5xx / -# connect-read timeouts). The env var is read at call time so test -# ``monkeypatch.setenv`` takes effect; the timing constants are -# module-level so power users (and tests) can ``monkeypatch.setattr`` -# them. Defaults: 4 retries, 0.5s base doubling under full jitter up to -# a 30s per-attempt ceiling, and honor a server ``Retry-After`` up to -# 60s before escalating to a resumable interruption instead. +# Retry-with-backoff defaults for transient sub-request failures (429 / +# 5xx / connect-read timeouts). All four are resolved at call time by +# ``RetryPolicy.from_env`` (the env var via ``monkeypatch.setenv``, the +# timing constants via ``monkeypatch.setattr`` on this module), so both +# are overridable in tests and by power users. Defaults: 4 retries, 0.5s +# base doubling under full jitter up to a 30s per-attempt ceiling, and +# honor a server ``Retry-After`` up to 60s before escalating to a +# resumable interruption instead. _RETRIES_ENV = "API_USGS_RETRIES" _RETRIES_DEFAULT = 4 _RETRY_BASE_BACKOFF = 0.5 @@ -237,10 +238,31 @@ class RetryPolicy: max_backoff: float = _RETRY_MAX_BACKOFF retry_after_cap: float = _RETRY_AFTER_CAP + def __post_init__(self) -> None: + # Guard the value object's own invariants so a misconfiguration + # fails loudly at construction rather than as a downstream + # ``time.sleep`` ValueError (negative delay) or a silent + # asyncio.sleep-treats-negative-as-zero divergence. + if self.max_retries < 0: + raise ValueError(f"max_retries must be >= 0 (got {self.max_retries}).") + if self.base_backoff < 0 or self.max_backoff < 0 or self.retry_after_cap < 0: + raise ValueError("retry backoff settings must be non-negative.") + @classmethod def from_env(cls) -> RetryPolicy: - """Build a policy, resolving ``max_retries`` from ``API_USGS_RETRIES``.""" - return cls(max_retries=_read_retries_env()) + """Build a policy from the module-level defaults, resolved now. + + ``max_retries`` comes from ``API_USGS_RETRIES``; the timing knobs + are read from the ``_RETRY_*`` module constants at call time (not + the dataclass field defaults, which freeze at class definition) so + ``monkeypatch.setattr`` on those constants takes effect. + """ + return cls( + max_retries=_read_retries_env(), + base_backoff=_RETRY_BASE_BACKOFF, + max_backoff=_RETRY_MAX_BACKOFF, + retry_after_cap=_RETRY_AFTER_CAP, + ) def should_retry(self, attempt: int, retry_after: float | None) -> bool: """Whether a just-failed ``attempt`` (1-based) warrants another try. @@ -276,42 +298,36 @@ def backoff(self, attempt: int, retry_after: float | None) -> float: "_chunked_client", default=None ) -# Async sibling of ``_chunked_client``. Published by -# ``_publish_async_client`` during ``_fan_out_async`` so async -# paginated-loop helpers reuse one ``httpx.AsyncClient`` (and its -# connection pool) across every concurrent sub-request of a single -# chunked call. +# Async sibling of ``_chunked_client``. Published (via :func:`_publish`) +# during ``_fan_out_async`` so async paginated-loop helpers reuse one +# ``httpx.AsyncClient`` (and its connection pool) across every concurrent +# sub-request of a single chunked call. _chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar( "_chunked_async_client", default=None ) - -@contextmanager -def _publish_client(client: httpx.Client) -> Iterator[None]: - """ - Make ``client`` visible to :func:`get_active_client` for the - duration of the ``with`` block via the ``_chunked_client`` - ContextVar. Wraps the set/reset token dance so callers don't have to. - """ - token = _chunked_client.set(client) - try: - yield - finally: - _chunked_client.reset(token) +_ClientT = TypeVar("_ClientT") @contextmanager -def _publish_async_client(client: httpx.AsyncClient) -> Iterator[None]: +def _publish(var: ContextVar[_ClientT | None], client: _ClientT) -> Iterator[None]: """ - Make ``client`` visible to :func:`get_active_async_client` for the - duration of the ``with`` block. Async sibling of - :func:`_publish_client`. + Bind ``client`` to the ContextVar ``var`` for the duration of the + ``with`` block (wrapping the set/reset token dance), so paginated-loop + helpers can borrow the chunker's shared client via + :func:`get_active_client` / :func:`get_active_async_client`. + + Generic over the client type so the sync (:class:`httpx.Client` via + ``_chunked_client``) and async (:class:`httpx.AsyncClient` via + ``_chunked_async_client``) paths share one implementation, while the + ``_ClientT`` type var still lets a type checker reject a var/client + type mismatch. """ - token = _chunked_async_client.set(client) + token = var.set(client) try: yield finally: - _chunked_async_client.reset(token) + var.reset(token) def get_active_client() -> httpx.Client | None: @@ -325,8 +341,8 @@ def get_active_client() -> httpx.Client | None: Returns ------- httpx.Client or None - The client published by :func:`_publish_client` if currently - inside a :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. + The client published via :func:`_publish` if currently inside a + :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. """ return _chunked_client.get() @@ -1069,13 +1085,18 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: """ Decide whether ``exc`` is a transient worth an automatic retry. - Narrower than :func:`_classify_chunk_error`: it retries rate limits - (429), service errors (5xx), and genuine transport transients - (:class:`httpx.TransportError` — ``ConnectError``, ``ReadTimeout``, …) - but NOT :class:`httpx.InvalidURL` (a too-long server cursor URL won't - fix on retry, though it stays *resumable*). Walks the ``__cause__`` - chain because ``_walk_pages`` re-wraps mid-pagination failures as - ``RuntimeError``. + Inspects only the *top-level* exception, by design — and so is + deliberately narrower than :func:`_classify_chunk_error`, which walks + the ``__cause__`` chain for resumability. ``_paginate`` raises an + initial-request transient (429 / 5xx / :class:`httpx.TransportError` + such as ``ConnectError`` / ``ReadTimeout``) *raw*, but re-wraps any + mid-pagination failure as a ``RuntimeError``. Retrying only the raw, + top-level transient means we re-issue a sub-request that made no + progress (cheap), while a failure after partial pagination escalates + to the resumable :class:`ChunkInterrupted` instead of being re-walked + from page 1 — which would re-spend the very quota that was exhausted. + ``httpx.InvalidURL`` is excluded (a too-long cursor won't fix on + retry), and it only ever arises on a follow-up page anyway. Returns ------- @@ -1083,13 +1104,10 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: ``(retryable, retry_after)`` — the server ``Retry-After`` hint (seconds) when the transient carried one, else ``None``. """ - cur: BaseException | None = exc - while cur is not None: - if isinstance(cur, (RateLimited, ServiceUnavailable)): - return True, cur.retry_after - if isinstance(cur, httpx.TransportError): - return True, None - cur = cur.__cause__ + if isinstance(exc, (RateLimited, ServiceUnavailable)): + return True, exc.retry_after + if isinstance(exc, httpx.TransportError): + return True, None return False, None @@ -1334,6 +1352,10 @@ def __init__( # subsequent ``resume()`` only re-issues the missing indices. # On the serial path this fills contiguously from 0. self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} + # Explicit completion order for response-header aggregation. + # Keeping this separate from ``_chunks`` avoids coupling that + # behavior to dict insertion semantics or future write patterns. + self._completion_order: list[int] = [] def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: """Record a completed sub-request's ``(frame, response)`` pair @@ -1341,6 +1363,8 @@ def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: :meth:`resume` and the parallel fan-out in :func:`_fan_out_async` so the completion set stays encapsulated.""" + if index not in self._chunks: + self._completion_order.append(index) self._chunks[index] = pair def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: @@ -1369,6 +1393,27 @@ def completed_chunks(self) -> int: def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]: return [self._chunks[i] for i in sorted(self._chunks)] + def _responses_by_completion(self) -> list[httpx.Response]: + # The final element is the most-recently completed sub-request, whose + # headers carry the freshest ``x-ratelimit-remaining`` for aggregation. + return [self._chunks[i][1] for i in self._completion_order] + + def combined(self) -> tuple[pd.DataFrame, httpx.Response]: + """Combine every recorded sub-request into one ``(frame, response)``. + + Frames concatenate in sub-args *index* order (deterministic, + independent of parallel completion order); the aggregated response + takes its headers from the most-recently-*completed* sub-request, so + a fan-out that finished chunks out of index order still surfaces the + latest rate-limit state the server reported rather than a stale one. + """ + return ( + _combine_chunk_frames([frame for frame, _ in self._ordered_chunks()]), + _combine_chunk_responses( + self._responses_by_completion(), self.plan.canonical_url + ), + ) + @property def partial_frame(self) -> pd.DataFrame: """ @@ -1405,7 +1450,7 @@ def partial_response(self) -> httpx.Response | None: if not self._chunks: return None return _combine_chunk_responses( - [resp for _, resp in self._ordered_chunks()], self.plan.canonical_url + self._responses_by_completion(), self.plan.canonical_url ) def resume(self) -> tuple[pd.DataFrame, httpx.Response]: @@ -1443,23 +1488,18 @@ def resume(self) -> tuple[pd.DataFrame, httpx.Response]: is on ``exc.call`` — wait for the underlying condition to clear and call ``exc.call.resume()`` again. """ - with httpx.Client(**HTTPX_DEFAULTS) as client, _publish_client(client): - reporter = _progress.current() - if reporter is not None: - reporter.set_chunks(self.plan.total) - for i, sub_args in enumerate(self.plan.iter_sub_args()): - if i in self._chunks: - continue + with httpx.Client(**HTTPX_DEFAULTS) as client: + with _publish(_chunked_client, client): + reporter = _progress.current() if reporter is not None: - reporter.start_chunk(i + 1) - self._issue(i, sub_args) - ordered = self._ordered_chunks() - frames = [frame for frame, _ in ordered] - responses = [resp for _, resp in ordered] - return ( - _combine_chunk_frames(frames), - _combine_chunk_responses(responses, self.plan.canonical_url), - ) + reporter.set_chunks(self.plan.total) + for i, sub_args in enumerate(self.plan.iter_sub_args()): + if i in self._chunks: + continue + if reporter is not None: + reporter.start_chunk(i + 1) + self._issue(i, sub_args) + return self.combined() def _issue(self, index: int, sub_args: dict[str, Any]) -> None: """ @@ -1556,13 +1596,17 @@ async def _fan_out_async( limits = httpx.Limits( max_connections=max_concurrent, max_keepalive_connections=max_concurrent ) - # ``sys.maxsize`` stands in for "unbounded": ``asyncio.Semaphore`` - # only decrements a counter, never preallocates slots. - semaphore = asyncio.Semaphore(max_concurrent or sys.maxsize) + # ``None`` means "unbounded"; ``sys.maxsize`` stands in for it since + # ``asyncio.Semaphore`` only decrements a counter, never preallocates + # slots. Test ``is None`` explicitly so a stray ``0`` isn't silently + # promoted to unbounded by a falsy-``or``. + semaphore = asyncio.Semaphore( + sys.maxsize if max_concurrent is None else max_concurrent + ) call = ChunkedCall(plan, fetch_once, retry_policy) async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: - with _publish_async_client(client): + with _publish(_chunked_async_client, client): reporter = _progress.current() if reporter is not None: reporter.set_chunks(plan.total) @@ -1586,15 +1630,16 @@ async def track( # Dispatch every sub-request concurrently. ``return_exceptions`` # keeps completed pairs after a sibling fails, so partial state # stays recoverable via ``ChunkedCall.resume()``. Failure - # precedence: + # precedence, in order: # 1. Cancellation / interrupt signals (CancelledError, # KeyboardInterrupt, SystemExit — non-Exception) propagate # unmodified; wrapping them as a transient would swallow the # user's stop signal. - # 2. Recognized transients wrap as ChunkInterrupted so the user - # gets a resumable handle even when a non-transient failure - # landed earlier in submission order. - # 3. Otherwise re-raise the first failure, preserving its type. + # 2. A non-transient failure (a real bug — unrecognized by + # ``wrap_failure``) surfaces raw, so it isn't masked behind a + # resumable handle for a transient sibling that landed later. + # 3. Only when every failure is a recognized transient do we + # raise the first as a resumable ``ChunkInterrupted``. results = await asyncio.gather( *(track(i, args) for i, args in enumerate(sub_args_list)), return_exceptions=True, @@ -1603,17 +1648,18 @@ async def track( for exc in failures: if not isinstance(exc, Exception): raise exc + first_transient: tuple[ChunkInterrupted, BaseException] | None = None for exc in failures: - if (interrupted := call.wrap_failure(exc)) is not None: - raise interrupted from exc - if failures: - raise failures[0] - - ordered = call._ordered_chunks() - return ( - _combine_chunk_frames([df for df, _ in ordered]), - _combine_chunk_responses([resp for _, resp in ordered], plan.canonical_url), - ) + interrupted = call.wrap_failure(exc) + if interrupted is None: + raise exc + if first_transient is None: + first_transient = (interrupted, exc) + if first_transient is not None: + interrupted, exc = first_transient + raise interrupted from exc + + return call.combined() def multi_value_chunked( diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index ee129aaa..e9500bb4 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -1267,6 +1267,25 @@ async def fetch_async(args): assert len(df) == len(seen_args) +def test_async_fan_out_aggregates_headers_from_latest_completion(monkeypatch): + """Aggregated headers reflect the most recently completed chunk. + + Completion order can differ from index order in parallel mode, so + rate-limit headers should come from whichever sub-request finished + last, not from the highest sub-args index. + """ + + async def fetch_async(args): + if "S1" * 10 in args["sites"]: + await asyncio.sleep(0.02) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=11) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=77) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + _, response = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert response.headers[_QUOTA_HEADER] == "11" + + def test_async_fan_out_failure_yields_resumable_call(monkeypatch): """A transient 5xx mid-fan-out raises ``ServiceInterrupted`` whose ``.call`` is a ``ChunkedCall`` holding the completed sub-requests @@ -1491,6 +1510,24 @@ def test_retry_policy_from_env(monkeypatch): RetryPolicy.from_env() +def test_retry_policy_rejects_invalid_settings(): + with pytest.raises(ValueError): + RetryPolicy(max_retries=-1) + with pytest.raises(ValueError): + RetryPolicy(base_backoff=-0.5) + with pytest.raises(ValueError): + RetryPolicy(max_backoff=-1.0) + + +def test_retry_policy_from_env_honors_monkeypatched_constants(monkeypatch): + # The timing knobs are read from the module constants at call time, so + # monkeypatching them (as the module comment promises) takes effect. + monkeypatch.setattr(_chunking, "_RETRY_MAX_BACKOFF", 0.0) + monkeypatch.setattr(_chunking, "_RETRY_BASE_BACKOFF", 0.0) + policy = RetryPolicy.from_env() + assert policy.max_backoff == 0.0 and policy.base_backoff == 0.0 + + # -- _retryable taxonomy ---------------------------------------------------- @@ -1505,8 +1542,13 @@ def test_retryable_taxonomy(): assert _retryable(RuntimeError("400")) == (False, None) -def test_retryable_walks_cause_chain(): - assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (True, 3.0) +def test_retryable_skips_wrapped_midpagination_transient(): + # A transient surfaced mid-pagination is re-wrapped as RuntimeError by + # _paginate; it must NOT be auto-retried (re-walking from page 1 would + # re-spend quota) — it escalates to the resumable handle instead. Only + # the raw, top-level (initial-request) transient is retryable. + assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (False, None) + assert _retryable(RateLimited("429", retry_after=3.0)) == (True, 3.0) # -- sync driver ------------------------------------------------------------ @@ -1650,3 +1692,25 @@ async def fetch_async(args): fetch = _async_chunked_fetch(monkeypatch, fetch_async) df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) assert len(df) > 1 # every sub-args atom recovered after the retry + + +def test_async_fan_out_surfaces_fatal_over_transient(monkeypatch): + """A non-transient bug in one sub-request surfaces raw rather than + being masked behind a resumable interruption from a transient sibling.""" + monkeypatch.setenv("API_USGS_RETRIES", "2") + + async def _noslept(_d): + return None + + monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + + async def fetch_async(args): + # One chunk carries a deterministic programmer error; the rest are + # transient. The real bug must win over the resumable transient. + if "S1" * 10 in args["sites"]: + raise ValueError("deterministic bug") + raise ServiceUnavailable("503: transient") + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + with pytest.raises(ValueError, match="deterministic bug"): + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index 30be56a2..a98dc76a 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -78,6 +78,26 @@ def test_note_retry_renders_then_clears_on_next_page(): assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] +def test_note_retry_subsecond_wait_shows_decimal(): + # A sub-second backoff must not collapse to a misleading "0s". + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.note_retry(attempt=1, wait=0.3) + out = stream.getvalue() + assert "waiting 0.3s" in out and "waiting 0s" not in out + + +def test_note_retry_cleared_on_close(): + # An exhausted retry leaves retry_note set with no following page; + # close() must clear it so the persisted last line isn't a stale note. + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.add_page(rows=1) + reporter.note_retry(attempt=3, wait=5.0) + reporter.close() + assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] + + def test_note_retry_is_noop_when_disabled(): stream = io.StringIO() reporter = ProgressReporter(stream=stream, enabled=False) From e740a649339e09c298828b44bdb9aa4d42242ae3 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 10:53:21 -0500 Subject: [PATCH 03/16] fix(waterdata): finalize resumed chunked results; add reference-table row cap Post-review fixes layered on the async parallel chunker: - Funnel OGC post-processing through the chunker via an injected `finalize` hook so `ChunkInterrupted.call.resume()` returns the same type-coerced `(df, BaseMetadata)` as an un-interrupted call instead of the raw `(frame, httpx.Response)`. `partial_frame`/`partial_response` stay raw, so building the exception never triggers finalize's side effects (a schema network GET on an empty frame would otherwise fire inside the ctor). - Add `max_rows` to `get_reference_table`/`get_ogc_data` to preview large reference tables without downloading every page; enforced as the exact total in `_finalize_ogc` (after dedup) and validated as a positive integer (accepts numpy ints via `numbers.Integral`). - Co-locate the parallel fan-out into `ChunkedCall.resume_async`, sharing a `_pending()` generator with the serial `resume()` so the two execution paths can't drift. - Harden `ProgressReporter.note_retry` for Python 3.9-3.11 (int `wait` and `int.is_integer()`). Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/_progress.py | 4 +- dataretrieval/waterdata/api.py | 11 +- dataretrieval/waterdata/chunking.py | 518 +++++++++++++++------------ dataretrieval/waterdata/utils.py | 138 ++++++- tests/waterdata_chunking_test.py | 184 +++++++--- tests/waterdata_progress_test.py | 16 +- tests/waterdata_test.py | 19 + tests/waterdata_utils_test.py | 93 +++++ 8 files changed, 689 insertions(+), 294 deletions(-) diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py index e529d6d3..ce94effb 100644 --- a/dataretrieval/waterdata/_progress.py +++ b/dataretrieval/waterdata/_progress.py @@ -162,7 +162,9 @@ def note_retry(self, *, attempt: int, wait: float) -> None: """ # Keep sub-second waits explicit (avoid misleading ``0s``) while # rendering whole-second waits without unnecessary ``.0`` noise. - wait_1dp = round(wait, 1) + # ``float()`` to support Python 3.9-3.11: ``round(int, 1)`` returns an + # int and ``int.is_integer()`` (used below) only exists on 3.12+. + wait_1dp = round(float(wait), 1) if wait_1dp < 1 or not wait_1dp.is_integer(): secs = f"{wait_1dp:.1f}s" else: diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index 57fffc88..44550375 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -2022,6 +2022,7 @@ def get_reference_table( collection: str, limit: int | None = None, query: dict | None = None, + max_rows: int | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """Get metadata reference tables for the USGS Water Data API. @@ -2046,6 +2047,12 @@ def get_reference_table( query: dictionary, optional The optional args parameter can be used to pass a dictionary of query parameters to the collection API call. + max_rows : int, optional + Cap the total number of rows returned, stopping pagination early + instead of downloading the whole table. Useful for cheaply + previewing large tables (e.g. ``hydrologic-unit-codes`` has ~125k + rows). Unlike ``limit`` (the per-page size), this bounds the total + result. The default (None) downloads every page. Returns ------- @@ -2092,7 +2099,9 @@ def get_reference_table( query_args = dict(query) if query else {} if limit is not None: query_args["limit"] = limit - return get_ogc_data(args=query_args, output_id=output_id, service=collection) + return get_ogc_data( + args=query_args, output_id=output_id, service=collection, max_rows=max_rows + ) def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame: diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 2d01614c..75121cd8 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -15,10 +15,12 @@ instead of issuing sub-requests serially. ``N=1`` forces the synchronous path. The default (16) is the server-friendly sweet spot; higher values can trip USGS burst-protection 5xx in practice. -The wrapper falls back to the serial path (with a ``UserWarning``) -when an asyncio event loop is already running (Jupyter / IPython / -async apps) or when no async fetch sibling is wired into the -decorator. +The fan-out runs in a short-lived worker thread (an +``anyio`` blocking portal), so it works whether or not the caller is +already inside an event loop (Jupyter / IPython / async apps) — no +nested-loop error and no silent serial degradation. It falls back to +the serial path (with a ``UserWarning``) only when no async fetch +sibling is wired into the decorator. Retries: each sub-request is retried on a transient failure (429, 5xx, connect/read timeout) with exponential backoff + full jitter, @@ -67,6 +69,7 @@ import httpx import pandas as pd +from anyio.from_thread import start_blocking_portal from dataretrieval.utils import HTTPX_DEFAULTS @@ -299,7 +302,7 @@ def backoff(self, attempt: int, retry_after: float | None) -> float: ) # Async sibling of ``_chunked_client``. Published (via :func:`_publish`) -# during ``_fan_out_async`` so async paginated-loop helpers reuse one +# during ``ChunkedCall.resume_async`` so async paginated-loop helpers reuse one # ``httpx.AsyncClient`` (and its connection pool) across every concurrent # sub-request of a single chunked call. _chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar( @@ -368,6 +371,23 @@ def get_active_async_client() -> httpx.AsyncClient | None: [dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]] ] +# Caller-supplied transform applied to the *combined* chunk result. It lets a +# resumed call (:meth:`ChunkedCall.resume` / :attr:`~ChunkedCall.partial_frame` +# / :attr:`~ChunkedCall.partial_response`) return the same shape as the +# un-interrupted call instead of the chunker's raw ``(frame, httpx.Response)``. +# The chunker stays generic — it only knows "post-process the assembled +# result"; the OGC getters inject the actual type-coercion / column-arrangement +# / ``BaseMetadata`` pipeline (see ``utils._finalize_ogc``). The default is +# identity, so direct ``ChunkedCall`` use and the tests are unaffected. +_Finalize = Callable[[pd.DataFrame, httpx.Response], "tuple[pd.DataFrame, Any]"] + + +def _passthrough_result( + frame: pd.DataFrame, response: httpx.Response +) -> tuple[pd.DataFrame, Any]: + """Default :data:`_Finalize`: return the raw combined pair unchanged.""" + return frame, response + class _RetryableTransportError(RuntimeError): """ @@ -470,9 +490,10 @@ class ChunkInterrupted(RuntimeError): later ``call.resume()`` (use ``exc.call.partial_frame`` for the live view). partial_response : httpx.Response or None - Aggregated response covering the completed sub-requests at - raise time; ``None`` if nothing had completed yet. Same - snapshot semantics as ``partial_frame``. + Raw aggregate response covering the completed sub-requests at + raise time; ``None`` if nothing had completed yet. Same snapshot + semantics as ``partial_frame``. (Raw, not finalized — use + ``exc.call.resume()`` for the finalized ``(df, metadata)`` result.) Examples -------- @@ -994,10 +1015,13 @@ def iter_sub_args(self) -> Iterator[dict[str, Any]]: yield sub_args def execute( - self, fetch_once: _FetchOnce, retry_policy: RetryPolicy = _NO_RETRY - ) -> tuple[pd.DataFrame, httpx.Response]: + self, + fetch_once: _FetchOnce, + retry_policy: RetryPolicy = _NO_RETRY, + finalize: _Finalize = _passthrough_result, + ) -> tuple[pd.DataFrame, Any]: """ - Run the plan and return the combined ``(frame, response)``. + Run the plan and return the combined, finalized result. Thin wrapper around ``ChunkedCall(self, fetch_once).resume()``; see :class:`ChunkedCall` for the per-sub-request semantics. @@ -1010,14 +1034,17 @@ def execute( retry_policy : RetryPolicy, optional Per-sub-request retry-with-backoff policy. Defaults to :data:`_NO_RETRY`; the decorator passes ``RetryPolicy.from_env()``. + finalize : Callable, optional + Transform applied to the combined ``(frame, response)`` (see + :data:`_Finalize`). Defaults to :func:`_passthrough_result`. Returns ------- df : pandas.DataFrame Combined data from every successful sub-request. - response : httpx.Response - Aggregated response (canonical URL, last page's headers, - cumulative elapsed time). + response + The finalized aggregate — a raw :class:`httpx.Response` by + default, or whatever ``finalize`` produces. Raises ------ @@ -1027,7 +1054,7 @@ def execute( :class:`ServiceInterrupted` for 5xx). The resumable handle is on ``exc.call``. """ - return ChunkedCall(self, fetch_once, retry_policy).resume() + return ChunkedCall(self, fetch_once, retry_policy, finalize).resume() def _classify_chunk_error( @@ -1125,6 +1152,26 @@ def _note_retry(attempt: int, wait: float) -> None: reporter.note_retry(attempt=attempt, wait=wait) +def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float | None: + """ + Decide the backoff for a just-failed ``attempt`` (1-based), or ``None`` + to give up and re-raise. + + Returns ``None`` when the error isn't a retryable transient, the policy + is exhausted, or the server's ``Retry-After`` is too long to absorb + inline (so it escalates to a resumable :class:`ChunkInterrupted`). + Otherwise returns the seconds to wait and emits the progress-bar retry + note. This is the whole retry *decision* — the sync and async drivers + share it and differ only in how they call the fetch and how they sleep. + """ + retryable, retry_after = _retryable(exc) + if not retryable or not policy.should_retry(attempt, retry_after): + return None + delay = policy.backoff(attempt, retry_after) + _note_retry(attempt, delay) + return delay + + def _retry_sync( fn: Callable[[], tuple[pd.DataFrame, httpx.Response]], policy: RetryPolicy, @@ -1132,22 +1179,19 @@ def _retry_sync( """ Call ``fn`` with bounded retry-with-backoff on transient failures. - On a non-retryable error, or once ``policy`` is exhausted (or the - server's ``Retry-After`` is too long to absorb inline), the last - exception propagates unchanged so the caller's existing handling wraps - it as a resumable :class:`ChunkInterrupted`. + A non-retryable or policy-exhausted failure (see :func:`_retry_delay`) + propagates unchanged so the caller's existing handling wraps it as a + resumable :class:`ChunkInterrupted`. """ attempt = 0 while True: try: return fn() except Exception as exc: # noqa: BLE001 — re-raised unless retryable - retryable, retry_after = _retryable(exc) attempt += 1 - if not retryable or not policy.should_retry(attempt, retry_after): + delay = _retry_delay(exc, attempt, policy) + if delay is None: raise - delay = policy.backoff(attempt, retry_after) - _note_retry(attempt, delay) _SLEEP(delay) @@ -1161,12 +1205,10 @@ async def _retry_async( try: return await afn() except Exception as exc: # noqa: BLE001 — re-raised unless retryable - retryable, retry_after = _retryable(exc) attempt += 1 - if not retryable or not policy.should_retry(attempt, retry_after): + delay = _retry_delay(exc, attempt, policy) + if delay is None: raise - delay = policy.backoff(attempt, retry_after) - _note_retry(attempt, delay) await _ASLEEP(delay) @@ -1329,12 +1371,17 @@ class ChunkedCall: The plan being driven (read-only after construction). fetch_once : Callable The per-sub-request fetch function. + finalize : Callable + Transform applied to the combined result (see :data:`_Finalize`) at + the terminal :meth:`resume` / :meth:`resume_async` returns, so a + completed call yields the caller's finished shape. The ``partial_*`` + accessors deliberately skip it and stay raw. partial_frame : pandas.DataFrame - Combined frame of completed sub-requests (live; recomputed per - access). - partial_response : httpx.Response or None - Aggregated response with canonical URL restored, or ``None`` - when nothing has completed yet (live; recomputed per access). + Raw combined frame of completed sub-requests (live; recomputed per + access). Not finalized — see :attr:`partial_frame`. + partial_response + Raw aggregate response (canonical URL restored), or ``None`` when + nothing has completed yet (live; recomputed per access). """ def __init__( @@ -1342,10 +1389,12 @@ def __init__( plan: ChunkPlan, fetch_once: _FetchOnce, retry_policy: RetryPolicy = _NO_RETRY, + finalize: _Finalize = _passthrough_result, ) -> None: self.plan = plan self.fetch_once = fetch_once self.retry_policy = retry_policy + self.finalize = finalize # Completed (frame, response) pairs keyed by sub-args index. # Sparse so the parallel fan-out path can record scattered # completions (e.g. indices [0, 2, 5] when 1/3/4 failed) and a @@ -1361,7 +1410,7 @@ def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: """Record a completed sub-request's ``(frame, response)`` pair under its sub-args index. Used by both the serial loop in :meth:`resume` and the parallel fan-out in - :func:`_fan_out_async` so the completion set stays + :meth:`resume_async` so the completion set stays encapsulated.""" if index not in self._chunks: self._completion_order.append(index) @@ -1398,8 +1447,9 @@ def _responses_by_completion(self) -> list[httpx.Response]: # headers carry the freshest ``x-ratelimit-remaining`` for aggregation. return [self._chunks[i][1] for i in self._completion_order] - def combined(self) -> tuple[pd.DataFrame, httpx.Response]: - """Combine every recorded sub-request into one ``(frame, response)``. + def _combine_raw(self) -> tuple[pd.DataFrame, httpx.Response]: + """Assemble the raw ``(frame, response)`` from completed sub-requests, + before :attr:`finalize` runs. Frames concatenate in sub-args *index* order (deterministic, independent of parallel completion order); the aggregated response @@ -1414,14 +1464,33 @@ def combined(self) -> tuple[pd.DataFrame, httpx.Response]: ), ) + def combined(self) -> tuple[pd.DataFrame, Any]: + """Combine every recorded sub-request and apply :attr:`finalize`. + + The terminal *success* result: :meth:`resume` and + :meth:`resume_async` both return this, so a completed call (whether + serial or parallel, first run or resume) yields the same shape + ``finalize`` produces — a raw ``(frame, httpx.Response)`` by + default, or the OGC getters' type-coerced / column-arranged frame + plus ``BaseMetadata``. The ``partial_*`` accessors deliberately do + NOT go through here — they return the raw :meth:`_combine_raw` + snapshot to stay cheap and side-effect-free. + """ + return self.finalize(*self._combine_raw()) + @property def partial_frame(self) -> pd.DataFrame: """ - Concatenated, deduplicated frame of sub-requests that have - completed so far. + Raw combined frame of sub-requests that have completed so far. Live — recomputed on each access so it reflects current state - across resume attempts. + across resume attempts. Deliberately the *raw* combined frame + (``_combine_raw``), NOT the finalized result: this is a cheap, + side-effect-free snapshot for inspecting partial progress, so + reading it (or building a :class:`ChunkInterrupted` around it) + never triggers ``finalize`` work — which for OGC getters includes + a schema network fetch on an empty frame. Use ``call.resume()`` + for the finalized result. Returns ------- @@ -1431,15 +1500,17 @@ def partial_frame(self) -> pd.DataFrame: """ if not self._chunks: return pd.DataFrame() - return _combine_chunk_frames([frame for frame, _ in self._ordered_chunks()]) + return self._combine_raw()[0] @property def partial_response(self) -> httpx.Response | None: """ - Aggregated response with the canonical URL restored to the + Raw aggregate response with the canonical URL restored to the user's full original query. - Live — recomputed on each access. + Live — recomputed on each access. Like :attr:`partial_frame`, this + is the *raw* aggregate (an :class:`httpx.Response`), not the + finalized result, so inspecting it is side-effect-free. Returns ------- @@ -1449,11 +1520,22 @@ def partial_response(self) -> httpx.Response | None: """ if not self._chunks: return None - return _combine_chunk_responses( - self._responses_by_completion(), self.plan.canonical_url - ) + return self._combine_raw()[1] + + def _pending(self) -> Iterator[tuple[int, dict[str, Any]]]: + """Yield ``(index, sub_args)`` for sub-requests not yet completed. + + The single source of the "walk :meth:`ChunkPlan.iter_sub_args` in + deterministic order, skip any index already in ``self._chunks``" + rule, shared by the serial :meth:`resume` and the parallel + :meth:`resume_async` so the two execution paths can't drift on + *which* sub-requests they still owe. + """ + for index, sub_args in enumerate(self.plan.iter_sub_args()): + if index not in self._chunks: + yield index, sub_args - def resume(self) -> tuple[pd.DataFrame, httpx.Response]: + def resume(self) -> tuple[pd.DataFrame, Any]: """ Drive the chunked call to completion via the sync ``fetch_once``. @@ -1475,9 +1557,11 @@ def resume(self) -> tuple[pd.DataFrame, httpx.Response]: ------- df : pandas.DataFrame Combined data from every successful sub-request. - response : httpx.Response - Aggregated response (canonical URL, last page's headers, - cumulative elapsed time). + response + The finalized aggregate — a raw :class:`httpx.Response` + (canonical URL, last page's headers, cumulative elapsed time) + by default, or whatever :attr:`finalize` produces (e.g. + ``BaseMetadata`` for the OGC getters). Raises ------ @@ -1493,12 +1577,17 @@ def resume(self) -> tuple[pd.DataFrame, httpx.Response]: reporter = _progress.current() if reporter is not None: reporter.set_chunks(self.plan.total) - for i, sub_args in enumerate(self.plan.iter_sub_args()): - if i in self._chunks: - continue + for index, sub_args in self._pending(): + # Serial progress semantics: announce the chunk we're + # *about to* fetch (1-based), so the line reads + # "chunk k/total" while that fetch + its pages are in + # flight. (The parallel path can't do this — chunks fire + # at once and finish out of order — so :meth:`resume_async` + # instead ticks the completed *count*; the two are + # deliberately different, not drift.) if reporter is not None: - reporter.start_chunk(i + 1) - self._issue(i, sub_args) + reporter.start_chunk(index + 1) + self._issue(index, sub_args) return self.combined() def _issue(self, index: int, sub_args: dict[str, Any]) -> None: @@ -1526,140 +1615,130 @@ def _issue(self, index: int, sub_args: dict[str, Any]) -> None: raise interrupted from exc self.record(index, chunk) + async def resume_async( + self, fetch_async: _FetchOnceAsync, *, max_concurrent: int | None + ) -> tuple[pd.DataFrame, Any]: + """ + Drive the chunked call to completion concurrently over one shared + :class:`httpx.AsyncClient`. Async sibling of :meth:`resume`. + + Pending sub-requests (:meth:`_pending`) fan out under + ``asyncio.gather`` with ``return_exceptions=True`` so completed + sub-requests survive a sibling's transient failure. On a recognized + transient (:class:`RateLimited`, :class:`ServiceUnavailable`) a + :class:`ChunkInterrupted` subclass is raised carrying ``self`` on + ``.call``; ``exc.call.resume()`` then re-issues only the unfinished + indices via the serial sync ``fetch_once`` path. The per-sub-request + bookkeeping (:meth:`_pending`, :meth:`record`, :meth:`wrap_failure`, + :meth:`combined`) is shared with :meth:`resume`, so the two execution + paths differ only in serial ``for`` vs concurrent ``gather``. + + In-flight sub-requests are capped by an :class:`asyncio.Semaphore`; + ``max_concurrent=None`` ("unbounded") uses ``sys.maxsize`` so every + call site takes the same ``async with semaphore`` path. The shared + client is published on :data:`_chunked_async_client` so async + paginated-loop helpers reuse its connection pool. -async def _fan_out_async( - plan: ChunkPlan, - fetch_once: _FetchOnce, - fetch_async: _FetchOnceAsync, - *, - max_concurrent: int | None, - retry_policy: RetryPolicy = _NO_RETRY, -) -> tuple[pd.DataFrame, httpx.Response]: - """ - Execute ``plan`` concurrently under one shared - :class:`httpx.AsyncClient`. - - The fan-out preserves the same resumability contract the serial - :class:`ChunkedCall` path provides: - - * **Resumable interruptions.** Sub-requests run under - ``asyncio.gather`` with ``return_exceptions=True`` so completed - sub-requests survive a sibling's transient failure. On a - recognized transient (:class:`RateLimited`, - :class:`ServiceUnavailable`) a :class:`ChunkInterrupted` - subclass is raised with ``.call`` set to a - :class:`ChunkedCall` carrying the sparse completed sub-args; - ``exc.call.resume()`` re-issues only the unfinished ones via - the sync ``fetch_once`` path. - - In-flight sub-requests are capped by an - :class:`asyncio.Semaphore`; ``max_concurrent=None`` ("unbounded") - uses ``sys.maxsize`` so every call site can take the same - ``async with semaphore`` path. The shared client is published on - :data:`_chunked_async_client` so async paginated-loop helpers - reuse its connection pool. + Parameters + ---------- + fetch_async : Callable + Async per-sub-request fetcher returning ``(df, response)``. + max_concurrent : int or None + Maximum in-flight sub-requests. ``None`` disables the cap. - Parameters - ---------- - plan : ChunkPlan - Pre-built plan whose sub-args sequence drives the fan-out. - fetch_once : Callable - Sync per-sub-request fetcher. Used to build the resumable - :class:`ChunkedCall` returned via ``ChunkInterrupted.call``; - never invoked by this function directly. - fetch_async : Callable - Async per-sub-request fetcher returning ``(df, response)``. - max_concurrent : int or None - Maximum in-flight sub-requests. ``None`` disables the cap. + Returns + ------- + df : pandas.DataFrame + Combined data from every sub-request. + response + The finalized aggregate — a raw :class:`httpx.Response` + (canonical URL, most-recently-completed sub-request's headers, + cumulative elapsed time) by default, or whatever + :attr:`finalize` produces (e.g. ``BaseMetadata`` for OGC getters). - Returns - ------- - df : pandas.DataFrame - Combined data from every sub-request. - response : httpx.Response - Aggregated response (canonical URL, last sub-request's - headers, cumulative elapsed time). + Raises + ------ + ChunkInterrupted + On a transient sub-request failure. ``.call`` is ``self``, + holding the sparse completed sub-requests; ``.call.resume()`` + re-issues the unfinished ones serially. + """ + # ``httpx.Limits()`` defaults to ``max_connections=100`` — at + # higher concurrency the pool would silently bottleneck the + # fan-out behind the connection cap. Match it to the semaphore, + # or ``None`` for truly unbounded. + limits = httpx.Limits( + max_connections=max_concurrent, max_keepalive_connections=max_concurrent + ) + # ``None`` means "unbounded"; ``sys.maxsize`` stands in for it since + # ``asyncio.Semaphore`` only decrements a counter, never preallocates + # slots. Test ``is None`` explicitly so a stray ``0`` isn't silently + # promoted to unbounded by a falsy-``or``. + semaphore = asyncio.Semaphore( + sys.maxsize if max_concurrent is None else max_concurrent + ) - Raises - ------ - ChunkInterrupted - On a transient sub-request failure. ``.call`` is a - :class:`ChunkedCall` holding the sparse completed sub-requests; - ``.call.resume()`` re-issues the unfinished ones serially. - """ - sub_args_list = list(plan.iter_sub_args()) - - # ``httpx.Limits()`` defaults to ``max_connections=100`` — at - # higher concurrency the pool would silently bottleneck the - # fan-out behind the connection cap. Match it to the semaphore, - # or ``None`` for truly unbounded. - limits = httpx.Limits( - max_connections=max_concurrent, max_keepalive_connections=max_concurrent - ) - # ``None`` means "unbounded"; ``sys.maxsize`` stands in for it since - # ``asyncio.Semaphore`` only decrements a counter, never preallocates - # slots. Test ``is None`` explicitly so a stray ``0`` isn't silently - # promoted to unbounded by a falsy-``or``. - semaphore = asyncio.Semaphore( - sys.maxsize if max_concurrent is None else max_concurrent - ) - call = ChunkedCall(plan, fetch_once, retry_policy) - - async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: - with _publish(_chunked_async_client, client): - reporter = _progress.current() - if reporter is not None: - reporter.set_chunks(plan.total) - - async def track( - offset: int, args: dict[str, Any] - ) -> tuple[pd.DataFrame, httpx.Response]: - """One sub-request (with retry) + record + progress tick. - - The retry loop runs *inside* the semaphore, so a chunk - backing off holds its slot — effective concurrency shrinks - under throttling instead of re-bursting against it. - """ - async with semaphore: - result = await _retry_async(lambda: fetch_async(args), retry_policy) - call.record(offset, result) + async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: + with _publish(_chunked_async_client, client): + reporter = _progress.current() if reporter is not None: - reporter.start_chunk(call.completed_chunks) - return result - - # Dispatch every sub-request concurrently. ``return_exceptions`` - # keeps completed pairs after a sibling fails, so partial state - # stays recoverable via ``ChunkedCall.resume()``. Failure - # precedence, in order: - # 1. Cancellation / interrupt signals (CancelledError, - # KeyboardInterrupt, SystemExit — non-Exception) propagate - # unmodified; wrapping them as a transient would swallow the - # user's stop signal. - # 2. A non-transient failure (a real bug — unrecognized by - # ``wrap_failure``) surfaces raw, so it isn't masked behind a - # resumable handle for a transient sibling that landed later. - # 3. Only when every failure is a recognized transient do we - # raise the first as a resumable ``ChunkInterrupted``. - results = await asyncio.gather( - *(track(i, args) for i, args in enumerate(sub_args_list)), - return_exceptions=True, - ) - failures = [r for r in results if isinstance(r, BaseException)] - for exc in failures: - if not isinstance(exc, Exception): - raise exc - first_transient: tuple[ChunkInterrupted, BaseException] | None = None - for exc in failures: - interrupted = call.wrap_failure(exc) - if interrupted is None: - raise exc - if first_transient is None: - first_transient = (interrupted, exc) - if first_transient is not None: - interrupted, exc = first_transient - raise interrupted from exc - - return call.combined() + reporter.set_chunks(self.plan.total) + + async def track( + index: int, args: dict[str, Any] + ) -> tuple[pd.DataFrame, httpx.Response]: + """One sub-request (with retry) + record + progress tick. + + The retry loop runs *inside* the semaphore, so a chunk + backing off holds its slot — effective concurrency shrinks + under throttling instead of re-bursting against it. + """ + async with semaphore: + result = await _retry_async( + lambda: fetch_async(args), self.retry_policy + ) + self.record(index, result) + if reporter is not None: + # Parallel progress semantics: chunks finish out of + # order, so tick the completed *count* rather than a + # positional index (see :meth:`resume`). + reporter.start_chunk(self.completed_chunks) + return result + + # Dispatch every pending sub-request concurrently. + # ``return_exceptions`` keeps completed pairs after a sibling + # fails, so partial state stays recoverable via :meth:`resume`. + # Failure precedence, in order: + # 1. Cancellation / interrupt signals (CancelledError, + # KeyboardInterrupt, SystemExit — non-Exception) propagate + # unmodified; wrapping them as a transient would swallow + # the user's stop signal. + # 2. A non-transient failure (a real bug — unrecognized by + # ``wrap_failure``) surfaces raw, so it isn't masked behind + # a resumable handle for a transient sibling that landed + # later. + # 3. Only when every failure is a recognized transient do we + # raise the first as a resumable ``ChunkInterrupted``. + results = await asyncio.gather( + *(track(index, args) for index, args in self._pending()), + return_exceptions=True, + ) + failures = [r for r in results if isinstance(r, BaseException)] + for exc in failures: + if not isinstance(exc, Exception): + raise exc + first_transient: tuple[ChunkInterrupted, BaseException] | None = None + for exc in failures: + interrupted = self.wrap_failure(exc) + if interrupted is None: + raise exc + if first_transient is None: + first_transient = (interrupted, exc) + if first_transient is not None: + interrupted, exc = first_transient + raise interrupted from exc + + return self.combined() def multi_value_chunked( @@ -1679,11 +1758,11 @@ def multi_value_chunked( When ``API_USGS_CONCURRENT`` resolves to a parallelism greater than 1 (the default), the decorator routes execution through - :func:`_fan_out_async` over the provided ``fetch_async``. The - wrapper falls back to the synchronous :class:`ChunkedCall` path - (with a ``UserWarning``) when ``fetch_async`` wasn't wired or - when an asyncio event loop is already running (Jupyter / IPython / - async apps where ``asyncio.run`` would raise ``RuntimeError``). + :meth:`ChunkedCall.resume_async` over the provided ``fetch_async``, run in an + ``anyio`` worker-thread portal so it works whether or not the caller + is already inside an event loop (Jupyter / IPython / async apps). It + falls back to the synchronous :class:`ChunkedCall` path (with a + ``UserWarning``) only when ``fetch_async`` wasn't wired. Parameters ---------- @@ -1727,7 +1806,9 @@ def decorator(fetch_once: _FetchOnce) -> _FetchOnce: @functools.wraps(fetch_once) def wrapper( args: dict[str, Any], - ) -> tuple[pd.DataFrame, httpx.Response]: + *, + finalize: _Finalize = _passthrough_result, + ) -> tuple[pd.DataFrame, Any]: limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit plan = ChunkPlan(args, build_request, limit) concurrency = _read_concurrency_env() @@ -1737,9 +1818,9 @@ def wrapper( # path; ``_execute_in_parallel`` owns the rest of the # serial/parallel decision (async wiring, running loop). if plan.total <= 1 or concurrency == 1: - return plan.execute(fetch_once, retry_policy) + return plan.execute(fetch_once, retry_policy, finalize) return _execute_in_parallel( - plan, fetch_once, fetch_async, concurrency, retry_policy + plan, fetch_once, fetch_async, concurrency, retry_policy, finalize ) return wrapper @@ -1753,19 +1834,19 @@ def _execute_in_parallel( fetch_async: _FetchOnceAsync | None, concurrency: int | None, retry_policy: RetryPolicy = _NO_RETRY, -) -> tuple[pd.DataFrame, httpx.Response]: + finalize: _Finalize = _passthrough_result, +) -> tuple[pd.DataFrame, Any]: """ - Run ``plan`` on the parallel async path, falling back to the - serial sync path when the runtime can't host an event loop. - - Falls back (with a one-time :class:`UserWarning`) when: - - * ``fetch_async`` wasn't wired into the decorator, or - * an asyncio event loop is already running (Jupyter / IPython - kernels, async apps — ``asyncio.run`` would raise). - - Otherwise opens a fresh event loop via :func:`asyncio.run` and - drives :func:`_fan_out_async`. + Run ``plan`` on the parallel async path. + + Falls back to the serial sync path (with a one-time + :class:`UserWarning`) only when ``fetch_async`` wasn't wired into the + decorator. Otherwise it drives :meth:`ChunkedCall.resume_async` in a short-lived + worker thread via an ``anyio`` blocking portal, so the fan-out runs + whether or not the caller is already inside an event loop (Jupyter / + IPython / async apps) — no nested-``asyncio.run`` error and no silent + degradation to serial. The portal copies the calling context, so the + active progress reporter still reaches the fan-out. """ if fetch_async is None: warnings.warn( @@ -1777,31 +1858,10 @@ def _execute_in_parallel( UserWarning, stacklevel=3, ) - return plan.execute(fetch_once, retry_policy) - if _running_event_loop() is not None: - warnings.warn( - "Detected a running asyncio event loop; the parallel " - f"chunker path cannot run inside one. Falling back to " - f"the serial path. Set {_CONCURRENCY_ENV}=1 to silence " - f"this warning.", - UserWarning, - stacklevel=3, - ) - return plan.execute(fetch_once, retry_policy) - return asyncio.run( - _fan_out_async( - plan, - fetch_once, - fetch_async, - max_concurrent=concurrency, - retry_policy=retry_policy, - ) + return plan.execute(fetch_once, retry_policy, finalize) + call = ChunkedCall(plan, fetch_once, retry_policy, finalize) + fan_out = functools.partial( + call.resume_async, fetch_async, max_concurrent=concurrency ) - - -def _running_event_loop() -> asyncio.AbstractEventLoop | None: - """Return the active asyncio event loop, or ``None`` when none.""" - try: - return asyncio.get_running_loop() - except RuntimeError: - return None + with start_blocking_portal() as portal: + return portal.call(fan_out) diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index f8475957..c22c2cff 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -4,6 +4,7 @@ import functools import json import logging +import numbers import os import re from collections.abc import ( @@ -15,6 +16,7 @@ Mapping, ) from contextlib import asynccontextmanager, contextmanager +from contextvars import ContextVar from datetime import datetime, timedelta from typing import Any, TypeVar, get_args from zoneinfo import ZoneInfo @@ -905,6 +907,26 @@ def _aggregate_paginated_response( _Cursor = TypeVar("_Cursor") +# Optional cap on the total rows a single paginated call accumulates before it +# stops following ``next`` links. ``None`` (the default the data getters use) +# means "no cap — fetch the whole series". Set via :func:`_row_cap` so the deep +# ``_paginate`` loop can honor it without threading the value through the +# generic chunker; this mirrors the ``_progress`` ambient-reporter pattern. +_row_cap_var: ContextVar[int | None] = ContextVar("waterdata_row_cap", default=None) + + +@contextmanager +def _row_cap(max_rows: int | None) -> Iterator[None]: + """Cap the rows any :func:`_paginate` / :func:`_paginate_async` under this + context will accumulate (``None`` = uncapped). Used by + :func:`get_reference_table` to preview large tables without downloading + every page.""" + token = _row_cap_var.set(max_rows) + try: + yield + finally: + _row_cap_var.reset(token) + def _paginate( initial_req: httpx.Request, @@ -988,18 +1010,24 @@ def _paginate( logger.warning("Initial response parse failed.") raise RuntimeError(_paginated_failure_message(0, e)) from e dfs = [df] + # Stop following ``next`` links once the optional row cap is reached + # (see :func:`_row_cap`); ``None`` means uncapped. The concatenation is + # sliced to the cap below so a final over-budget page can't exceed it. + cap = _row_cap_var.get() + nrows = len(df) if reporter is not None: reporter.set_rate_remaining( resp.headers.get(_QUOTA_HEADER), limit=resp.headers.get("x-ratelimit-limit"), ) reporter.add_page(rows=len(df)) - while cursor is not None: + while cursor is not None and (cap is None or nrows < cap): try: resp = follow_up(cursor, client) _raise_for_non_200(resp) df, cursor = parse_response(resp) dfs.append(df) + nrows += len(df) total_elapsed += _safe_elapsed(resp) if reporter is not None: reporter.set_rate_remaining( @@ -1021,7 +1049,10 @@ def _paginate( final_response = _aggregate_paginated_response( initial_response, resp, total_elapsed ) - return pd.concat(dfs, ignore_index=True), final_response + result = pd.concat(dfs, ignore_index=True) + if cap is not None: + result = result.head(cap) + return result, final_response async def _paginate_async( @@ -1037,7 +1068,7 @@ async def _paginate_async( Runs the same per-page loop but issues HTTP asynchronously so multiple sub-requests of a chunked call can run concurrently from - :func:`_fan_out_async`. + :meth:`~dataretrieval.waterdata.chunking.ChunkedCall.resume_async`. """ logger.debug("Requesting: %s", initial_req.url) reporter = _progress.current() @@ -1058,18 +1089,24 @@ async def _paginate_async( logger.warning("Initial response parse failed.") raise RuntimeError(_paginated_failure_message(0, e)) from e dfs = [df] + # Stop following ``next`` links once the optional row cap is reached + # (see :func:`_row_cap`); ``None`` means uncapped. Mirrors the sync + # :func:`_paginate`; the concatenation is sliced to the cap below. + cap = _row_cap_var.get() + nrows = len(df) if reporter is not None: reporter.set_rate_remaining( resp.headers.get(_QUOTA_HEADER), limit=resp.headers.get("x-ratelimit-limit"), ) reporter.add_page(rows=len(df)) - while cursor is not None: + while cursor is not None and (cap is None or nrows < cap): try: resp = await follow_up(cursor, sess) _raise_for_non_200(resp) df, cursor = parse_response(resp) dfs.append(df) + nrows += len(df) total_elapsed += _safe_elapsed(resp) if reporter is not None: reporter.set_rate_remaining( @@ -1091,7 +1128,10 @@ async def _paginate_async( final_response = _aggregate_paginated_response( initial_response, resp, total_elapsed ) - return pd.concat(dfs, ignore_index=True), final_response + result = pd.concat(dfs, ignore_index=True) + if cap is not None: + result = result.head(cap) + return result, final_response def _ogc_parse_response( @@ -1356,8 +1396,50 @@ def _sort_rows(df: pd.DataFrame) -> pd.DataFrame: return df +def _finalize_ogc( + frame: pd.DataFrame, + response: httpx.Response, + *, + properties: list[str] | None, + output_id: str, + convert_type: bool, + service: str, + max_rows: int | None = None, +) -> tuple[pd.DataFrame, BaseMetadata]: + """Shape a combined OGC result into the user-facing ``(df, md)``. + + The single home for the OGC getters' result shaping: empties + normalized, types coerced (when ``convert_type``), the wire ``id`` + renamed and columns ordered, rows sorted, optionally truncated to + ``max_rows``, and the response wrapped as + :class:`~dataretrieval.utils.BaseMetadata`. + + Injected into the chunker as its ``finalize`` hook (see + :data:`~dataretrieval.waterdata.chunking._Finalize`) so the + un-interrupted return *and* a resumed ``ChunkInterrupted.call.resume()`` + produce the same shape — closing the gap where resume used to hand back + the chunker's raw frame and bare ``httpx.Response``. + + ``max_rows`` is applied here (after dedup/sort, on the *combined* frame) + rather than only per-sub-request, so a chunked call's total is bounded + to exactly ``max_rows`` and a resumed call honors the cap too — the + per-``_paginate`` ``_row_cap`` is only an early-stop download bound. + """ + frame = _deal_with_empty(frame, properties, service) + if convert_type: + frame = _type_cols(frame) + frame = _arrange_cols(frame, properties, output_id) + frame = _sort_rows(frame) + if max_rows is not None: + frame = frame.head(max_rows) + return frame, BaseMetadata(response) + + def get_ogc_data( - args: dict[str, Any], output_id: str, service: str + args: dict[str, Any], + output_id: str, + service: str, + max_rows: int | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """ Retrieves OGC (Open Geospatial Consortium) data from a specified @@ -1376,6 +1458,11 @@ def get_ogc_data( service : str The OGC API collection name (e.g., ``"daily"``, ``"monitoring-locations"``, ``"continuous"``). + max_rows : int, optional + Stop paginating once this many rows have been collected and + truncate the result to exactly ``max_rows``. ``None`` (default) + fetches the full result. Intended for cheap previews of large, + un-chunked tables (e.g. :func:`get_reference_table`). Returns ------- @@ -1390,6 +1477,19 @@ def get_ogc_data( - Handles optional arguments such as `convert_type`. - Applies column cleanup and reordering based on service and properties. """ + # Enforce a genuine positive integer: a float (even ``10.0``) or ``bool`` + # would pass a bare ``< 1`` check and then crash deep in + # ``pd.DataFrame.head`` with an opaque ``TypeError`` after HTTP I/O has + # already fired. ``numbers.Integral`` (not ``int``) so numpy integers — + # e.g. ``max_rows`` derived from a numpy/pandas computation — are accepted; + # ``bool`` is an ``Integral`` subtype, so exclude it explicitly. + if max_rows is not None and ( + not isinstance(max_rows, numbers.Integral) + or isinstance(max_rows, bool) + or max_rows < 1 + ): + raise ValueError(f"max_rows must be a positive integer (got {max_rows!r}).") + args = args.copy() args["service"] = service args = _switch_arg_id(args, id_name=output_id, service=service) @@ -1402,15 +1502,23 @@ def get_ogc_data( convert_type = args.pop("convert_type", False) args = {k: v for k, v in args.items() if v is not None} - with _progress.progress_context(service=service): - return_list, response = _fetch_once(args) - return_list = _deal_with_empty(return_list, properties, service) - if convert_type: - return_list = _type_cols(return_list) - return_list = _arrange_cols(return_list, properties, output_id) - return_list = _sort_rows(return_list) - - return return_list, BaseMetadata(response) + # Post-processing is injected into the chunker rather than applied here, + # so it runs on *every* exit: the normal return AND a later + # ``exc.call.resume()`` after a ChunkInterrupted (which never re-enters + # this function). ``_finalize_ogc`` is the single source of result shape; + # it also applies ``max_rows`` to the *combined* frame so the cap is the + # exact total even when the plan chunks or the call is resumed, while + # ``_row_cap`` below only early-stops each sub-request's pagination. + finalize = functools.partial( + _finalize_ogc, + properties=properties, + output_id=output_id, + convert_type=convert_type, + service=service, + max_rows=max_rows, + ) + with _progress.progress_context(service=service), _row_cap(max_rows): + return _fetch_once(args, finalize=finalize) async def _fetch_once_async( diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index e9500bb4..612ae154 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -16,6 +16,7 @@ """ import asyncio +import concurrent.futures import datetime import sys from unittest import mock @@ -1214,8 +1215,8 @@ def test_iter_sub_args_passthrough_yields_a_copy(): # for the whole suite. Each test below overrides it so the wrapper takes # the parallel branch. The decorator's ``fetch_async`` accepts any # coroutine returning ``(df, response)`` — no real ``httpx.AsyncClient`` -# round-trip occurs, even though :func:`_fan_out_async` opens one for -# pool management. +# round-trip occurs, even though :meth:`ChunkedCall.resume_async` opens one +# for pool management. def _async_chunked_fetch(monkeypatch, fetch_async, *, max_concurrent=16): @@ -1327,58 +1328,90 @@ def fetch(args): assert len(df) == interrupted.total_chunks -@pytest.mark.parametrize( - "fallback_trigger,warning_match", - [ - pytest.param("running_loop", "running asyncio event loop", id="running-loop"), - pytest.param("no_fetch_async", "no async fetch sibling", id="missing-async"), - ], -) -def test_async_falls_back_to_serial_with_warning( - monkeypatch, fallback_trigger, warning_match -): - """The parallel path falls back to the serial ``ChunkedCall`` - (with a ``UserWarning``) in two situations: - - * a running asyncio event loop (Jupyter / IPython kernels, async - apps) — ``asyncio.run`` would otherwise raise ``RuntimeError``; - * the decorator wasn't wired with a ``fetch_async=`` sibling — - ``API_USGS_CONCURRENT`` would otherwise be a silent no-op. - """ - sync_calls = [] +def test_async_fan_out_resume_applies_finalize(monkeypatch): + """The ``finalize`` injected for a PARALLEL call survives the interruption + (carried on the ``ChunkedCall`` through the anyio portal), so a serial + ``exc.call.resume()`` still returns the finalized shape — guarding the + parallel resume_async -> resume -> finalize path the serial-pinned finalize + test can't reach. Partials stay raw (no finalize in the exception ctor).""" + + def finalize(frame, response): + return frame.assign(finalized=True), ("MD", response) + + call_count = {"async": 0} + + async def fetch_async(args): + call_count["async"] += 1 + if call_count["async"] == 1: + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + raise ServiceUnavailable("503: simulated") + monkeypatch.setenv("API_USGS_CONCURRENT", "16") - if fallback_trigger == "running_loop": + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) - async def fetch_async(args): - raise AssertionError("parallel path must not run inside an active loop") + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + with pytest.raises(ServiceInterrupted) as exc_info: + fetch({"sites": sites}, finalize=finalize) - @multi_value_chunked( - build_request=_fake_build, fetch_async=fetch_async, url_limit=240 - ) - def fetch(args): - sync_calls.append(tuple(args["sites"])) - return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + # Partial snapshot stays raw — building the exception must not finalize. + assert "finalized" not in exc_info.value.partial_frame.columns + # Resume applies the finalize carried on the ChunkedCall. + df, md = exc_info.value.call.resume() + assert "finalized" in df.columns + assert md[0] == "MD" - async def driver(): - with pytest.warns(UserWarning, match=warning_match): - return fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - df, _ = asyncio.run(driver()) - else: +def test_async_falls_back_to_serial_when_no_fetch_async(monkeypatch): + """With no ``fetch_async=`` sibling wired, a parallel + ``API_USGS_CONCURRENT`` can't be honored, so the call falls back to + the serial ``ChunkedCall`` path with a one-time ``UserWarning`` + rather than silently no-op'ing the env var.""" + sync_calls = [] + monkeypatch.setenv("API_USGS_CONCURRENT", "16") - @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): - sync_calls.append(tuple(args["sites"])) - return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + sync_calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() - with pytest.warns(UserWarning, match=warning_match): - df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + with pytest.warns(UserWarning, match="no async fetch sibling"): + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) assert len(sync_calls) > 1 assert len(df) == len(sync_calls) +def test_async_fan_out_runs_inside_running_event_loop(monkeypatch): + """The parallel fan-out works even when the caller is already inside a + running event loop (Jupyter / async apps): the anyio blocking portal + runs it in a worker thread, so it neither raises a nested + ``asyncio.run`` error nor silently degrades to the serial path.""" + monkeypatch.setenv("API_USGS_CONCURRENT", "16") + async_calls = [] + + async def fetch_async(args): + async_calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): # sync sibling must NOT run — the async path handles it + raise AssertionError("serial fallback must not run inside a live loop") + + async def driver(): # call the sync getter from within a running loop + return fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + df, _ = asyncio.run(driver()) + assert len(async_calls) > 1 # every sub-request ran on the async path + assert len(df) == len(async_calls) + + def test_async_fan_out_cancellation_wins_over_transient_sibling(monkeypatch): """``asyncio.CancelledError`` raised by any sub-request must propagate unmodified, even when a sibling raises a recognized @@ -1387,11 +1420,16 @@ def test_async_fan_out_cancellation_wins_over_transient_sibling(monkeypatch): signal — letting a transient-classification path consume it would silently swallow the user's stop request. - fetch_async has no ``await`` inside its body, so gather schedules + ``fetch_async`` has no ``await`` in its body, so the gather schedules the tasks in submission order and each runs synchronously to its - raise — making ``call_count`` deterministic for this test: - 1 = probe, 2 = first fan-out task (transient), 3 = second - fan-out task (cancellation). + raise — making ``call_count`` deterministic: 1 = first chunk + (success), 2 = second chunk (transient), 3 = third chunk (cancel). + + Through the sync→async blocking portal an in-flight cancellation + surfaces to the caller as ``concurrent.futures.CancelledError`` (the + thread-boundary cancellation type) rather than ``asyncio.CancelledError`` + — either way it propagates unmodified rather than being swallowed and + wrapped as a resumable ``ChunkInterrupted``. """ call_count = {"async": 0} @@ -1412,7 +1450,7 @@ async def fetch_async(args): # transient (call 2) AND the cancellation (call 3). sites = [f"S{i}" * 10 for i in range(1, 9)] - with pytest.raises(asyncio.CancelledError): + with pytest.raises((asyncio.CancelledError, concurrent.futures.CancelledError)): fetch({"sites": sites}) @@ -1714,3 +1752,57 @@ async def fetch_async(args): fetch = _async_chunked_fetch(monkeypatch, fetch_async) with pytest.raises(ValueError, match="deterministic bug"): fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + +# --- finalize hook (resume finalizes; partials stay raw) ------------------- +# +# Regression for the bug where ``exc.call.resume()`` returned the chunker's +# raw ``(frame, httpx.Response)`` instead of the post-processed shape a normal +# getter call yields. The fix injects a ``finalize`` transform applied at the +# terminal resume()/resume_async() returns. The partial_* accessors stay RAW +# so building/inspecting a ChunkInterrupted never triggers finalize's side +# effects (for OGC, _deal_with_empty issues a schema network GET on an empty +# frame — that must NOT fire inside the exception constructor). + + +def test_resume_finalizes_but_partials_stay_raw(monkeypatch): + """resume() applies the injected ``finalize``; ``partial_frame`` / + ``partial_response`` are the raw snapshot, and constructing the + ``ChunkInterrupted`` must not invoke ``finalize`` at all (no side effects + such as a schema fetch in the exception ctor).""" + calls = {"finalize": 0} + + def finalize(frame, response): + # Stand in for the OGC pipeline: mark the frame and wrap the response. + calls["finalize"] += 1 + return frame.assign(finalized=True), ("METADATA", response) + + # Fail the 2nd issued sub-request once (the 1st completes, so partial + # state is non-empty), then succeed on resume. Conftest pins the serial, + # no-retry path, so the failure surfaces immediately. + state = {"n": 0} + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + state["n"] += 1 + if state["n"] == 2: + raise ServiceUnavailable("503: simulated") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + with pytest.raises(ServiceInterrupted) as exc_info: + fetch({"sites": sites}, finalize=finalize) + + interrupted = exc_info.value + assert interrupted.completed_chunks >= 1 + # Building the exception did NOT run finalize — no network/side effects. + assert calls["finalize"] == 0 + # Partial snapshot is the RAW combined frame/response (not finalized). + assert "finalized" not in interrupted.partial_frame.columns + assert not isinstance(interrupted.partial_response, tuple) + + # Resume DOES finalize and yields the same shape a normal call would. + df, md = interrupted.call.resume() + assert "finalized" in df.columns + assert md[0] == "METADATA" + assert calls["finalize"] >= 1 diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index a98dc76a..6efba212 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -105,6 +105,16 @@ def test_note_retry_is_noop_when_disabled(): assert stream.getvalue() == "" +def test_note_retry_accepts_integer_wait(): + # An int ``wait`` (e.g. whole seconds) must render without raising: + # ``round(int, 1)`` returns an int and ``int.is_integer()`` only exists + # on Python 3.12+, while the package floor is 3.9. Renders like the float. + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.note_retry(attempt=1, wait=5) + assert "retrying (attempt 1, waiting 5s)" in stream.getvalue() + + def test_chunk_segment_only_shown_when_multiple_chunks(): single = io.StringIO() reporter = ProgressReporter(stream=single, enabled=True) @@ -482,7 +492,7 @@ def test_fan_out_async_sets_chunks_on_active_reporter(monkeypatch): import pandas as pd - from dataretrieval.waterdata.chunking import ChunkPlan, _fan_out_async + from dataretrieval.waterdata.chunking import ChunkedCall, ChunkPlan # Fake build_request whose URL length scales with the sites list, # mirroring the planner's _request_bytes contract. _FakeReq has the @@ -514,7 +524,9 @@ def fetch_once(args): # noqa: ARG001 — never invoked on the happy parallel pa async def run(): with progress_context(service="daily", stream=stream, enabled=True) as rep: - await _fan_out_async(plan, fetch_once, fetch_async, max_concurrent=4) + await ChunkedCall(plan, fetch_once).resume_async( + fetch_async, max_concurrent=4 + ) return rep.total_chunks, rep.current_chunk total_recorded, current_recorded = asyncio.run(run()) diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index 09f66aa5..ea91ae9e 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -551,6 +551,25 @@ def test_get_reference_table_wrong_name(): get_reference_table("agency-cod") +@pytest.mark.parametrize("bad", [0, -1, 2.5, 10.0, True]) +def test_get_reference_table_rejects_bad_max_rows(bad): + # max_rows must be a genuine positive int; a non-positive value, a float + # (even integral like 10.0), or a bool must raise ValueError up front — + # not crash later inside pandas .head(). Raises before any HTTP request. + with pytest.raises(ValueError, match="positive integer"): + get_reference_table("agency-codes", max_rows=bad) + + +def test_get_reference_table_accepts_numpy_int_max_rows(): + # numpy integers are valid caps: isinstance(np.int64, int) is False, so the + # validation must accept numbers.Integral (not just int) — otherwise a cap + # derived from a numpy/pandas computation is wrongly rejected. + import numpy as np + + df, _ = get_reference_table("agency-codes", max_rows=np.int64(2)) + assert len(df) == 2 + + def test_get_stats_por(): df, _ = get_stats_por( monitoring_location_id="USGS-12451000", diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index 413f39c8..6aa2728b 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -96,6 +96,99 @@ def test_walk_pages_multiple_mocked(): assert mock_client.request.call_args[0][1] == "https://example.com/page2" +def test_row_cap_truncates_and_stops_within_first_page(): + # Regression for BUG 2: ``_row_cap`` bounds the TOTAL rows. A first page + # already over the cap is truncated to exactly ``max_rows`` and the + # ``next`` link is never followed. + from dataretrieval.waterdata.utils import _row_cap + + resp1 = mock.MagicMock() + resp1.json.return_value = { + "numberReturned": 3, + "features": [{"id": str(i), "properties": {"val": i}} for i in range(3)], + "links": [{"rel": "next", "href": "https://example.com/page2"}], + } + resp1.headers = {} + resp1.status_code = 200 + resp1.url = "https://example.com/page1" + + mock_client = mock.MagicMock(spec=httpx.Client) + mock_client.send.return_value = resp1 + + mock_req = mock.MagicMock(spec=httpx.Request) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.url = "https://example.com/page1" + + with _row_cap(2): + df, _ = _walk_pages(geopd=False, req=mock_req, client=mock_client) + + assert len(df) == 2 # truncated to the cap, not the page's 3 rows + assert not mock_client.request.called # ``next`` link never followed + + +def test_row_cap_stops_across_pages(): + # The cap accumulates across pages: page 1 (1 row) is under the cap so + # page 2 is fetched; once the cap (2) is met the third page is NOT. + from dataretrieval.waterdata.utils import _row_cap + + def _page(idx, *, has_next): + resp = mock.MagicMock() + nxt = f"https://example.com/page{idx + 1}" + resp.json.return_value = { + "numberReturned": 1, + "features": [{"id": str(idx), "properties": {"val": idx}}], + "links": [{"rel": "next", "href": nxt}] if has_next else [], + } + resp.headers = {} + resp.status_code = 200 + resp.url = f"https://example.com/page{idx}" + return resp + + mock_client = mock.MagicMock(spec=httpx.Client) + mock_client.send.return_value = _page(1, has_next=True) + # page 2 still advertises a ``next`` (page 3) that must never be fetched. + mock_client.request.return_value = _page(2, has_next=True) + + mock_req = mock.MagicMock(spec=httpx.Request) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.url = "https://example.com/page1" + + with _row_cap(2): + df, _ = _walk_pages(geopd=False, req=mock_req, client=mock_client) + + assert len(df) == 2 + assert mock_client.request.call_count == 1 # fetched page 2, stopped before 3 + + +def test_finalize_ogc_truncates_combined_to_max_rows(): + # max_rows is enforced on the *combined* frame in _finalize_ogc (after + # dedup/sort), so it bounds the total exactly even when a chunked call's + # per-sub-request pages overshoot the per-_paginate early-stop. + import datetime + + from dataretrieval.waterdata.utils import _finalize_ogc + + frame = pd.DataFrame({"id": [str(i) for i in range(10)]}) + resp = mock.MagicMock() + resp.url = "https://example.com/q" + resp.elapsed = datetime.timedelta(seconds=0.1) + resp.headers = {} + + df, md = _finalize_ogc( + frame, + resp, + properties=None, + output_id="thing_id", + convert_type=False, + service="things", + max_rows=3, + ) + assert len(df) == 3 + assert hasattr(md, "url") # wrapped as BaseMetadata + + def _resp_ok(features): """Build a 200-OK mock response carrying the given features list.""" links = [{"rel": "next", "href": "https://example.com/page2"}] if features else [] From f094d42fd19f873ee10d2d53338939fb5d2913ea Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 13:06:26 -0500 Subject: [PATCH 04/16] refactor(waterdata): consolidate trivial chunker helpers; NumPy docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit No behavior change — tightening and doc consistency on chunking.py: - Inline single-use helpers into their sole callers: `_note_retry` into `_retry_delay`, and `_ordered_chunks`/`_responses_by_completion` into `_combine_raw` (the two distinct orderings — frames by index, responses by completion — are now documented inline). - Drop the redundant `_completion_order` shadow list: `record` is the only writer of `_chunks` and `dict` preserves insertion order, so completion order is just iteration order. - Use the `completed_chunks` accessor consistently (was `len(self._chunks)` in `wrap_failure`). - Un-quote the `_Finalize` alias (`tuple[...]` is a valid runtime expression on the >=3.9 floor). - Reformat 15 prose/undocumented PR functions to NumPy docstring style (Parameters/Returns/Raises/Yields), matching the rest of the package and fixing sibling inconsistencies (e.g. `get_active_async_client`, `combined`, `_pending`). Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 263 +++++++++++++++++++++++----- 1 file changed, 215 insertions(+), 48 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 75121cd8..b2645a26 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -253,12 +253,19 @@ def __post_init__(self) -> None: @classmethod def from_env(cls) -> RetryPolicy: - """Build a policy from the module-level defaults, resolved now. + """ + Build a policy from the module-level defaults, resolved now. ``max_retries`` comes from ``API_USGS_RETRIES``; the timing knobs are read from the ``_RETRY_*`` module constants at call time (not the dataclass field defaults, which freeze at class definition) so ``monkeypatch.setattr`` on those constants takes effect. + + Returns + ------- + RetryPolicy + A policy built from the module-level defaults resolved at + call time. """ return cls( max_retries=_read_retries_env(), @@ -268,18 +275,47 @@ def from_env(cls) -> RetryPolicy: ) def should_retry(self, attempt: int, retry_after: float | None) -> bool: - """Whether a just-failed ``attempt`` (1-based) warrants another try. + """ + Whether a just-failed ``attempt`` (1-based) warrants another try. A ``Retry-After`` longer than ``retry_after_cap`` is *not* slept off inline — it returns ``False`` so the failure escalates to a resumable interruption instead of blocking the call for minutes. + + Parameters + ---------- + attempt : int + The just-failed attempt number (1-based). + retry_after : float or None + Seconds the server suggested waiting (``Retry-After`` hint), + or ``None`` when no hint was given. + + Returns + ------- + bool + ``True`` if another try is warranted, ``False`` otherwise. """ if attempt > self.max_retries: return False return retry_after is None or retry_after <= self.retry_after_cap def backoff(self, attempt: int, retry_after: float | None) -> float: - """Seconds to wait before retry ``attempt`` (1-based).""" + """ + Seconds to wait before retry ``attempt`` (1-based). + + Parameters + ---------- + attempt : int + The retry attempt number (1-based). + retry_after : float or None + Seconds the server suggested waiting (``Retry-After`` hint), + or ``None`` to use the computed exponential backoff instead. + + Returns + ------- + float + Seconds to wait before the retry. + """ if retry_after is not None: return retry_after ceiling = min(self.max_backoff, self.base_backoff * 2 ** (attempt - 1)) @@ -325,6 +361,19 @@ def _publish(var: ContextVar[_ClientT | None], client: _ClientT) -> Iterator[Non ``_chunked_async_client``) paths share one implementation, while the ``_ClientT`` type var still lets a type checker reject a var/client type mismatch. + + Parameters + ---------- + var : ContextVar + The ContextVar to bind ``client`` to for the duration of the + ``with`` block. + client + The client to publish on ``var``. + + Yields + ------ + None + Yields once, for the duration of the bind. """ token = var.set(client) try: @@ -356,6 +405,12 @@ def get_active_async_client() -> httpx.AsyncClient | None: Async sibling of :func:`get_active_client`. Used by async paginated-loop helpers to reuse the per-call AsyncClient pool. + + Returns + ------- + httpx.AsyncClient or None + The client published via :func:`_publish` if currently inside a + :class:`ChunkedCall` ``resume_async`` block; ``None`` otherwise. """ return _chunked_async_client.get() @@ -379,7 +434,7 @@ def get_active_async_client() -> httpx.AsyncClient | None: # result"; the OGC getters inject the actual type-coercion / column-arrangement # / ``BaseMetadata`` pipeline (see ``utils._finalize_ogc``). The default is # identity, so direct ``ChunkedCall`` use and the tests are unaffected. -_Finalize = Callable[[pd.DataFrame, httpx.Response], "tuple[pd.DataFrame, Any]"] +_Finalize = Callable[[pd.DataFrame, httpx.Response], tuple[pd.DataFrame, Any]] def _passthrough_result( @@ -1145,13 +1200,6 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: _ASLEEP = asyncio.sleep -def _note_retry(attempt: int, wait: float) -> None: - """Surface an imminent retry on the active progress reporter, if any.""" - reporter = _progress.current() - if reporter is not None: - reporter.note_retry(attempt=attempt, wait=wait) - - def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float | None: """ Decide the backoff for a just-failed ``attempt`` (1-based), or ``None`` @@ -1163,12 +1211,30 @@ def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float Otherwise returns the seconds to wait and emits the progress-bar retry note. This is the whole retry *decision* — the sync and async drivers share it and differ only in how they call the fetch and how they sleep. + + Parameters + ---------- + exc : BaseException + The exception raised by the just-failed attempt. + attempt : int + The just-failed attempt number (1-based). + policy : RetryPolicy + The retry-with-backoff policy governing the decision. + + Returns + ------- + float or None + Seconds to wait before retrying, or ``None`` to give up and + re-raise. """ retryable, retry_after = _retryable(exc) if not retryable or not policy.should_retry(attempt, retry_after): return None delay = policy.backoff(attempt, retry_after) - _note_retry(attempt, delay) + # Surface the imminent retry on the active progress reporter, if any. + reporter = _progress.current() + if reporter is not None: + reporter.note_retry(attempt=attempt, wait=delay) return delay @@ -1182,6 +1248,19 @@ def _retry_sync( A non-retryable or policy-exhausted failure (see :func:`_retry_delay`) propagates unchanged so the caller's existing handling wraps it as a resumable :class:`ChunkInterrupted`. + + Parameters + ---------- + fn : Callable + Zero-arg callable that issues a single sub-request and returns + ``(frame, response)``. + policy : RetryPolicy + The retry-with-backoff policy governing the retries. + + Returns + ------- + tuple of (pandas.DataFrame, httpx.Response) + The ``(frame, response)`` pair from the first successful call. """ attempt = 0 while True: @@ -1199,7 +1278,22 @@ async def _retry_async( afn: Callable[[], Awaitable[tuple[pd.DataFrame, httpx.Response]]], policy: RetryPolicy, ) -> tuple[pd.DataFrame, httpx.Response]: - """Async sibling of :func:`_retry_sync` (awaits :func:`asyncio.sleep`).""" + """ + Async sibling of :func:`_retry_sync` (awaits :func:`asyncio.sleep`). + + Parameters + ---------- + afn : Callable + Zero-arg awaitable callable that issues a single sub-request and + returns ``(frame, response)``. + policy : RetryPolicy + The retry-with-backoff policy governing the retries. + + Returns + ------- + tuple of (pandas.DataFrame, httpx.Response) + The ``(frame, response)`` pair from the first successful call. + """ attempt = 0 while True: try: @@ -1378,8 +1472,8 @@ class ChunkedCall: accessors deliberately skip it and stay raw. partial_frame : pandas.DataFrame Raw combined frame of completed sub-requests (live; recomputed per - access). Not finalized — see :attr:`partial_frame`. - partial_response + access). Not finalized — call :meth:`resume` for the finished shape. + partial_response : httpx.Response or None Raw aggregate response (canonical URL restored), or ``None`` when nothing has completed yet (live; recomputed per access). """ @@ -1401,34 +1495,52 @@ def __init__( # subsequent ``resume()`` only re-issues the missing indices. # On the serial path this fills contiguously from 0. self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} - # Explicit completion order for response-header aggregation. - # Keeping this separate from ``_chunks`` avoids coupling that - # behavior to dict insertion semantics or future write patterns. - self._completion_order: list[int] = [] def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: - """Record a completed sub-request's ``(frame, response)`` pair - under its sub-args index. Used by both the serial loop in - :meth:`resume` and the parallel fan-out in - :meth:`resume_async` so the completion set stays - encapsulated.""" - if index not in self._chunks: - self._completion_order.append(index) + """ + Record a completed sub-request's ``(frame, response)`` pair under + its sub-args index. + + The single writer of ``self._chunks`` — used by both the serial + loop in :meth:`resume` and the parallel fan-out in + :meth:`resume_async` — so ``dict`` insertion order is completion + order (see :meth:`_responses_by_completion`). + + Parameters + ---------- + index : int + The sub-args index this completed pair belongs to. + pair : tuple of (pandas.DataFrame, httpx.Response) + The completed sub-request's ``(frame, response)`` pair. + """ self._chunks[index] = pair def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: - """Build the matching :class:`ChunkInterrupted` carrying this + """ + Build the matching :class:`ChunkInterrupted` carrying this call when ``exc`` is a recognized transient transport failure; return ``None`` for unrecognized failures so the caller can re-raise. Encapsulates the ``classify → instantiate-with-call-state`` recipe so - :class:`ChunkedCall`'s private fields stay private.""" + :class:`ChunkedCall`'s private fields stay private. + + Parameters + ---------- + exc : BaseException + The exception raised by a sub-request. + + Returns + ------- + ChunkInterrupted or None + The matching :class:`ChunkInterrupted` subclass carrying this + call for a recognized transient failure; ``None`` otherwise. + """ classification = _classify_chunk_error(exc) if classification is None: return None interrupted_class, retry_after = classification return interrupted_class( - completed_chunks=len(self._chunks), + completed_chunks=self.completed_chunks, total_chunks=self.plan.total, call=self, retry_after=retry_after, @@ -1437,35 +1549,44 @@ def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: @property def completed_chunks(self) -> int: - return len(self._chunks) - - def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]: - return [self._chunks[i] for i in sorted(self._chunks)] + """ + Number of sub-requests completed so far. - def _responses_by_completion(self) -> list[httpx.Response]: - # The final element is the most-recently completed sub-request, whose - # headers carry the freshest ``x-ratelimit-remaining`` for aggregation. - return [self._chunks[i][1] for i in self._completion_order] + Returns + ------- + int + The count of completed sub-requests. + """ + return len(self._chunks) def _combine_raw(self) -> tuple[pd.DataFrame, httpx.Response]: """Assemble the raw ``(frame, response)`` from completed sub-requests, before :attr:`finalize` runs. - Frames concatenate in sub-args *index* order (deterministic, - independent of parallel completion order); the aggregated response - takes its headers from the most-recently-*completed* sub-request, so - a fan-out that finished chunks out of index order still surfaces the - latest rate-limit state the server reported rather than a stale one. + Frames concatenate in sub-args *index* order (``sorted`` keys — + deterministic, independent of parallel completion order). The + aggregated response takes its headers from the most-recently- + *completed* sub-request: ``record`` is the only writer of + ``self._chunks`` and ``dict`` preserves insertion order, so the + chunks' natural order is completion order and the last one carries + the freshest ``x-ratelimit-remaining``. + + Returns + ------- + tuple of (pandas.DataFrame, httpx.Response) + The concatenated frame and the aggregated response, before + :attr:`finalize` is applied. """ + frames = [self._chunks[i][0] for i in sorted(self._chunks)] + responses = [response for _, response in self._chunks.values()] return ( - _combine_chunk_frames([frame for frame, _ in self._ordered_chunks()]), - _combine_chunk_responses( - self._responses_by_completion(), self.plan.canonical_url - ), + _combine_chunk_frames(frames), + _combine_chunk_responses(responses, self.plan.canonical_url), ) def combined(self) -> tuple[pd.DataFrame, Any]: - """Combine every recorded sub-request and apply :attr:`finalize`. + """ + Combine every recorded sub-request and apply :attr:`finalize`. The terminal *success* result: :meth:`resume` and :meth:`resume_async` both return this, so a completed call (whether @@ -1475,6 +1596,12 @@ def combined(self) -> tuple[pd.DataFrame, Any]: plus ``BaseMetadata``. The ``partial_*`` accessors deliberately do NOT go through here — they return the raw :meth:`_combine_raw` snapshot to stay cheap and side-effect-free. + + Returns + ------- + tuple of (pandas.DataFrame, finalized response) + The combined frame and the finalized aggregate response / + metadata that :attr:`finalize` produces. """ return self.finalize(*self._combine_raw()) @@ -1523,13 +1650,20 @@ def partial_response(self) -> httpx.Response | None: return self._combine_raw()[1] def _pending(self) -> Iterator[tuple[int, dict[str, Any]]]: - """Yield ``(index, sub_args)`` for sub-requests not yet completed. + """ + Yield ``(index, sub_args)`` for sub-requests not yet completed. The single source of the "walk :meth:`ChunkPlan.iter_sub_args` in deterministic order, skip any index already in ``self._chunks``" rule, shared by the serial :meth:`resume` and the parallel :meth:`resume_async` so the two execution paths can't drift on *which* sub-requests they still owe. + + Yields + ------ + tuple of (int, dict) + The sub-args ``index`` and its ``sub_args`` dict for each + sub-request not yet completed. """ for index, sub_args in enumerate(self.plan.iter_sub_args()): if index not in self._chunks: @@ -1605,6 +1739,13 @@ def _issue(self, index: int, sub_args: dict[str, Any]) -> None: ``TimeoutException``), and :class:`httpx.InvalidURL` (which inherits directly from ``Exception``, not ``HTTPError``); all three feed :func:`_classify_chunk_error`. + + Parameters + ---------- + index : int + The sub-args index this sub-request belongs to. + sub_args : dict + The substituted args dict for this sub-request. """ try: chunk = _retry_sync(lambda: self.fetch_once(sub_args), self.retry_policy) @@ -1847,6 +1988,32 @@ def _execute_in_parallel( IPython / async apps) — no nested-``asyncio.run`` error and no silent degradation to serial. The portal copies the calling context, so the active progress reporter still reaches the fan-out. + + Parameters + ---------- + plan : ChunkPlan + The chunking plan to execute. + fetch_once : Callable + Sync per-sub-request fetcher returning ``(df, response)``, used + on the serial fallback path. + fetch_async : Callable or None + Async per-sub-request fetcher returning ``(df, response)``. When + ``None``, the call falls back to the serial sync path with a + :class:`UserWarning`. + concurrency : int or None + Maximum in-flight sub-requests. ``None`` disables the cap. + retry_policy : RetryPolicy, optional + Per-sub-request retry-with-backoff policy. Defaults to + :data:`_NO_RETRY`. + finalize : Callable, optional + Transform applied to the combined ``(frame, response)`` (see + :data:`_Finalize`). Defaults to :func:`_passthrough_result`. + + Returns + ------- + tuple of (pandas.DataFrame, finalized response) + The combined frame and the finalized aggregate response that + ``finalize`` produces. """ if fetch_async is None: warnings.warn( From 7deedd74d7dcfcc467c866bbd790a418c1fb722f Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 15:18:23 -0500 Subject: [PATCH 05/16] refactor(waterdata): collapse sync/async chunker into one async core MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The chunker carried full sync/async twins. Collapse to a single async implementation behind a synchronous facade (the public getters and `resume()` stay sync, same signatures/returns, by driving the async core through the anyio blocking portal). Removed twins: sync `_paginate`, `_walk_pages`, `_retry_sync`, `_client_for`, `get_active_client`, the `_chunked_client` ContextVar, and the `_fan_out_async`/`_execute_in_parallel`/`resume_async` split (folded into `ChunkedCall.resume()` -> `_run`). `multi_value_chunked` now decorates an `async def` fetcher (drops the `fetch_async=` param). `get_stats_data` drives `_paginate_async` through the portal. Concurrency is now bounded purely by the httpx connection pool (`httpx.Limits(max_connections=N)`) — the explicit `asyncio.Semaphore` is gone; `gather` dispatches every pending sub-request and the pool throttles (N=1 is a single-connection gather, total<=1 a one-element gather). Behavior note: because execution is now `gather(..., return_exceptions=True)` over all pending sub-requests, an interruption completes every non-failing sub-request before surfacing (even at concurrency=1) rather than stopping at the first failure; `resume()` then re-issues only the still-failed chunks. The public API, `resume()` contract, ChunkInterrupted/partials, finalize hook, max_rows cap, retries, and progress reporting are unchanged. Net ~-216 lines. Offline suite (265) + live getter suite (63) green. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 522 +++++++++------------------- dataretrieval/waterdata/utils.py | 322 ++++++----------- tests/conftest.py | 35 +- tests/waterdata_chunking_test.py | 399 ++++++++++++--------- tests/waterdata_filters_test.py | 25 +- tests/waterdata_progress_test.py | 40 ++- tests/waterdata_utils_test.py | 29 +- 7 files changed, 578 insertions(+), 794 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index b2645a26..48da4159 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -9,18 +9,20 @@ sub-request URL fits. Requests that already fit get a trivial single-step plan — ``ChunkedCall`` has one code path either way. -Concurrency: when ``API_USGS_CONCURRENT`` is set to an integer N > 1 -(or the literal ``unbounded``), ``multi_value_chunked`` fans the plan -out across ``N`` async coroutines sharing one ``httpx.AsyncClient`` -instead of issuing sub-requests serially. ``N=1`` forces the -synchronous path. The default (16) is the server-friendly sweet -spot; higher values can trip USGS burst-protection 5xx in practice. -The fan-out runs in a short-lived worker thread (an -``anyio`` blocking portal), so it works whether or not the caller is -already inside an event loop (Jupyter / IPython / async apps) — no -nested-loop error and no silent serial degradation. It falls back to -the serial path (with a ``UserWarning``) only when no async fetch -sibling is wired into the decorator. +Concurrency: the execution core is async-only. ``multi_value_chunked`` +fans every pending sub-request out under one ``asyncio.gather`` sharing +a single ``httpx.AsyncClient``; concurrency is bounded purely by the +client's connection pool (``httpx.Limits(max_connections=N, +max_keepalive_connections=N)``), so the pool — not a semaphore — +throttles. ``API_USGS_CONCURRENT`` resolves ``N``: an integer N > 1 +caps connections at N; ``1`` pins a single connection (effectively +serial); the literal ``unbounded`` removes the cap (``N=None``). The +default (16) is the server-friendly sweet spot; higher values can trip +USGS burst-protection 5xx in practice. The fan-out runs in a +short-lived worker thread (an ``anyio`` blocking portal), so the +synchronous public API drives it whether or not the caller is already +inside an event loop (Jupyter / IPython / async apps) — no nested-loop +error and no silent serial degradation. Retries: each sub-request is retried on a transient failure (429, 5xx, connect/read timeout) with exponential backoff + full jitter, @@ -34,11 +36,11 @@ as a ``ChunkInterrupted`` subclass — ``QuotaExhausted`` for 429, ``ServiceInterrupted`` for 5xx. The exception carries ``.call``, a ``ChunkedCall`` handle that owns the already-completed sub-request -state (sparse-indexed on the parallel path, contiguous-prefix on -the serial path). Call ``.call.resume()`` once the underlying -condition clears; only the still-pending sub-requests are -re-issued, via the serial sync path. ``Retry-After`` (when the -server sets it) is surfaced on the exception as ``.retry_after``. +state (sparse-indexed, since gathered sub-requests complete out of +order). Call ``.call.resume()`` once the underlying condition clears; +only the still-pending sub-requests are re-issued (``resume()`` is a +synchronous facade over the same async runner). ``Retry-After`` (when +the server sets it) is surfaced on the exception as ``.retry_after``. Dedup: list-axis chunks don't overlap; filter-axis chunks can, so ``_combine_chunk_frames`` dedupes by feature ``id``. ``properties``, @@ -56,15 +58,12 @@ import math import os import random -import sys -import time -import warnings from collections.abc import Awaitable, Callable, Iterator from contextlib import contextmanager, suppress from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta -from typing import Any, ClassVar, TypeVar +from typing import Any, ClassVar from urllib.parse import quote_plus import httpx @@ -328,89 +327,57 @@ def backoff(self, attempt: int, retry_after: float | None) -> float: _NO_RETRY = RetryPolicy(max_retries=0) -# Client shared across all sub-requests of a single chunked call so -# paginated-loop helpers downstream (``_walk_pages``) reuse one -# connection pool across the whole fan-out. ``None`` when not inside a -# chunked call — paginated helpers fall back to their own short-lived -# client in that case. -_chunked_client: ContextVar[httpx.Client | None] = ContextVar( - "_chunked_client", default=None -) - -# Async sibling of ``_chunked_client``. Published (via :func:`_publish`) -# during ``ChunkedCall.resume_async`` so async paginated-loop helpers reuse one -# ``httpx.AsyncClient`` (and its connection pool) across every concurrent -# sub-request of a single chunked call. +# The single shared ``httpx.AsyncClient`` of an in-flight chunked call, +# published (via :func:`_publish`) during ``ChunkedCall._run`` so async +# paginated-loop helpers downstream (``_walk_pages_async``) reuse one +# connection pool across every gathered sub-request of the call. ``None`` +# when not inside a chunked call — paginated helpers fall back to their +# own short-lived client in that case. _chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar( "_chunked_async_client", default=None ) -_ClientT = TypeVar("_ClientT") - @contextmanager -def _publish(var: ContextVar[_ClientT | None], client: _ClientT) -> Iterator[None]: +def _publish(client: httpx.AsyncClient) -> Iterator[None]: """ - Bind ``client`` to the ContextVar ``var`` for the duration of the - ``with`` block (wrapping the set/reset token dance), so paginated-loop - helpers can borrow the chunker's shared client via - :func:`get_active_client` / :func:`get_active_async_client`. - - Generic over the client type so the sync (:class:`httpx.Client` via - ``_chunked_client``) and async (:class:`httpx.AsyncClient` via - ``_chunked_async_client``) paths share one implementation, while the - ``_ClientT`` type var still lets a type checker reject a var/client - type mismatch. + Bind ``client`` to the ``_chunked_async_client`` ContextVar for the + duration of the ``with`` block (wrapping the set/reset token dance), + so async paginated-loop helpers can borrow the chunker's shared + client via :func:`get_active_async_client`. Parameters ---------- - var : ContextVar - The ContextVar to bind ``client`` to for the duration of the - ``with`` block. - client - The client to publish on ``var``. + client : httpx.AsyncClient + The client to publish on ``_chunked_async_client``. Yields ------ None Yields once, for the duration of the bind. """ - token = var.set(client) + token = _chunked_async_client.set(client) try: yield finally: - var.reset(token) - - -def get_active_client() -> httpx.Client | None: - """ - Return the chunker's currently-published sync client, or ``None``. - - Public accessor for the ``_chunked_client`` ContextVar so - sibling modules (notably :func:`dataretrieval.waterdata.utils._client_for`) - don't have to reach into the private ContextVar directly. - - Returns - ------- - httpx.Client or None - The client published via :func:`_publish` if currently inside a - :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. - """ - return _chunked_client.get() + _chunked_async_client.reset(token) def get_active_async_client() -> httpx.AsyncClient | None: """ Return the chunker's currently-published async client, or ``None``. - Async sibling of :func:`get_active_client`. Used by async + Public accessor for the ``_chunked_async_client`` ContextVar so + sibling modules (notably + :func:`dataretrieval.waterdata.utils._client_for_async`) don't have + to reach into the private ContextVar directly. Used by async paginated-loop helpers to reuse the per-call AsyncClient pool. Returns ------- httpx.AsyncClient or None The client published via :func:`_publish` if currently inside a - :class:`ChunkedCall` ``resume_async`` block; ``None`` otherwise. + :class:`ChunkedCall` run; ``None`` otherwise. """ return _chunked_async_client.get() @@ -421,10 +388,10 @@ def get_active_async_client() -> httpx.AsyncClient | None: _LIST_SEP = "," _OR_SEP = " OR " -_FetchOnce = Callable[[dict[str, Any]], tuple[pd.DataFrame, httpx.Response]] -_FetchOnceAsync = Callable[ - [dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]] -] +# The chunker's execution core is async-only: the decorated fetcher, the +# ``ChunkedCall`` it drives, and the per-sub-request runner are all +# coroutines. ``_Fetch`` is an ``async def fetch(args) -> (df, response)``. +_Fetch = Callable[[dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]]] # Caller-supplied transform applied to the *combined* chunk result. It lets a # resumed call (:meth:`ChunkedCall.resume` / :attr:`~ChunkedCall.partial_frame` @@ -938,7 +905,7 @@ def __init__( axes = _extract_axes(args) # No chunkable axes → skip ``build_request`` entirely; the # common Water Data call shape shouldn't pay for an unused - # request prep on the passthrough hot path. ``fetch_once`` + # request prep on the passthrough hot path. The fetcher # will run with the user's args verbatim; if that produces # an over-budget URL, the server (or httpx itself) rejects. if not axes: @@ -1071,20 +1038,20 @@ def iter_sub_args(self) -> Iterator[dict[str, Any]]: def execute( self, - fetch_once: _FetchOnce, + fetch: _Fetch, retry_policy: RetryPolicy = _NO_RETRY, finalize: _Finalize = _passthrough_result, ) -> tuple[pd.DataFrame, Any]: """ Run the plan and return the combined, finalized result. - Thin wrapper around ``ChunkedCall(self, fetch_once).resume()``; + Thin wrapper around ``ChunkedCall(self, fetch).resume()``; see :class:`ChunkedCall` for the per-sub-request semantics. Parameters ---------- - fetch_once : Callable - Function that issues a single sub-request, given the + fetch : Callable + ``async def`` that issues a single sub-request, given the substituted args dict, and returns ``(frame, response)``. retry_policy : RetryPolicy, optional Per-sub-request retry-with-backoff policy. Defaults to @@ -1109,7 +1076,7 @@ def execute( :class:`ServiceInterrupted` for 5xx). The resumable handle is on ``exc.call``. """ - return ChunkedCall(self, fetch_once, retry_policy, finalize).resume() + return ChunkedCall(self, fetch, retry_policy, finalize).resume() def _classify_chunk_error( @@ -1169,7 +1136,7 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: Inspects only the *top-level* exception, by design — and so is deliberately narrower than :func:`_classify_chunk_error`, which walks - the ``__cause__`` chain for resumability. ``_paginate`` raises an + the ``__cause__`` chain for resumability. ``_paginate_async`` raises an initial-request transient (429 / 5xx / :class:`httpx.TransportError` such as ``ConnectError`` / ``ReadTimeout``) *raw*, but re-wraps any mid-pagination failure as a ``RuntimeError``. Retrying only the raw, @@ -1193,10 +1160,9 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: return False, None -# Sleep hooks, indirected through module globals so tests can -# ``monkeypatch.setattr`` them to no-ops instead of waiting for real -# backoff. Production uses the stdlib calls. -_SLEEP = time.sleep +# Sleep hook, indirected through a module global so tests can +# ``monkeypatch.setattr`` it to a no-op instead of waiting for real +# backoff. Production uses the stdlib call. _ASLEEP = asyncio.sleep @@ -1238,48 +1204,18 @@ def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float return delay -def _retry_sync( - fn: Callable[[], tuple[pd.DataFrame, httpx.Response]], +async def _retry_async( + afn: Callable[[], Awaitable[tuple[pd.DataFrame, httpx.Response]]], policy: RetryPolicy, ) -> tuple[pd.DataFrame, httpx.Response]: """ - Call ``fn`` with bounded retry-with-backoff on transient failures. + Call ``afn`` with bounded retry-with-backoff on transient failures. A non-retryable or policy-exhausted failure (see :func:`_retry_delay`) propagates unchanged so the caller's existing handling wraps it as a - resumable :class:`ChunkInterrupted`. - - Parameters - ---------- - fn : Callable - Zero-arg callable that issues a single sub-request and returns - ``(frame, response)``. - policy : RetryPolicy - The retry-with-backoff policy governing the retries. - - Returns - ------- - tuple of (pandas.DataFrame, httpx.Response) - The ``(frame, response)`` pair from the first successful call. - """ - attempt = 0 - while True: - try: - return fn() - except Exception as exc: # noqa: BLE001 — re-raised unless retryable - attempt += 1 - delay = _retry_delay(exc, attempt, policy) - if delay is None: - raise - _SLEEP(delay) - - -async def _retry_async( - afn: Callable[[], Awaitable[tuple[pd.DataFrame, httpx.Response]]], - policy: RetryPolicy, -) -> tuple[pd.DataFrame, httpx.Response]: - """ - Async sibling of :func:`_retry_sync` (awaits :func:`asyncio.sleep`). + resumable :class:`ChunkInterrupted`. The whole retry *decision* lives + in :func:`_retry_delay`; this driver only awaits the sleep between + attempts. Parameters ---------- @@ -1390,7 +1326,7 @@ def _combine_chunk_responses( One response per completed sub-request, in execution order. canonical_url : str or None URL of the unchunked original request. ``None`` skips the URL - override — used by the passthrough path (``fetch_once``'s + override — used by the passthrough path (the fetcher's response already carries the original-query URL) and by the worst-case overflow path (no buildable canonical URL exists). @@ -1433,43 +1369,51 @@ class ChunkedCall: Stateful handle for a chunked call. Holds the in-flight state (per-sub-request frames and responses) - and exposes a single :meth:`resume` entry point that drives the - call from wherever it is to completion — used both for the first - invocation (from :meth:`ChunkPlan.execute`) and for subsequent - retries after a :class:`ChunkInterrupted`. + and the async fetcher, and exposes a single :meth:`resume` entry + point that drives the call from wherever it is to completion — used + both for the first invocation (from :meth:`ChunkPlan.execute`) and + for subsequent retries after a :class:`ChunkInterrupted`. + + The execution core is the async :meth:`_run` (gather every pending + sub-request over one shared :class:`httpx.AsyncClient`, apply the + failure-precedence rules, combine); :meth:`resume` is a thin + synchronous facade that drives :meth:`_run` through an ``anyio`` + blocking portal so it works whether or not the caller is already + inside an event loop. There is no separate serial path: concurrency + is bounded purely by the client's connection pool, so a single + connection (``API_USGS_CONCURRENT=1``) is just a degenerate gather. A ``ChunkedCall`` is created internally when a :class:`ChunkPlan` executes; callers reach it via :attr:`ChunkInterrupted.call` on the exception raised by a mid-stream failure. - :meth:`resume` is idempotent: it iterates + :meth:`resume` is idempotent: :meth:`_run` iterates :meth:`ChunkPlan.iter_sub_args` (deterministic order) and skips any index whose result is already in ``self._chunks``. The completion set is a sparse ``dict[int, (df, response)]`` so the - parallel path can record scattered completions (e.g. indices - [0, 2, 5] after siblings [1, 3, 4] failed) and a subsequent - ``resume`` only re-issues the missing indices — via the serial - sync ``fetch_once`` path. + gather can record scattered completions (e.g. indices [0, 2, 5] + after siblings [1, 3, 4] failed) and a subsequent ``resume`` only + re-issues the missing indices. Parameters ---------- plan : ChunkPlan The chunking plan to execute. - fetch_once : Callable - Function that issues a single sub-request, given the + fetch : Callable + ``async def`` that issues a single sub-request, given the substituted args dict, and returns ``(frame, response)``. Attributes ---------- plan : ChunkPlan The plan being driven (read-only after construction). - fetch_once : Callable - The per-sub-request fetch function. + fetch : Callable + The async per-sub-request fetch function. finalize : Callable Transform applied to the combined result (see :data:`_Finalize`) at - the terminal :meth:`resume` / :meth:`resume_async` returns, so a - completed call yields the caller's finished shape. The ``partial_*`` - accessors deliberately skip it and stay raw. + the terminal :meth:`_run` return, so a completed call yields the + caller's finished shape. The ``partial_*`` accessors deliberately + skip it and stay raw. partial_frame : pandas.DataFrame Raw combined frame of completed sub-requests (live; recomputed per access). Not finalized — call :meth:`resume` for the finished shape. @@ -1481,19 +1425,18 @@ class ChunkedCall: def __init__( self, plan: ChunkPlan, - fetch_once: _FetchOnce, + fetch: _Fetch, retry_policy: RetryPolicy = _NO_RETRY, finalize: _Finalize = _passthrough_result, ) -> None: self.plan = plan - self.fetch_once = fetch_once + self.fetch = fetch self.retry_policy = retry_policy self.finalize = finalize # Completed (frame, response) pairs keyed by sub-args index. - # Sparse so the parallel fan-out path can record scattered - # completions (e.g. indices [0, 2, 5] when 1/3/4 failed) and a - # subsequent ``resume()`` only re-issues the missing indices. - # On the serial path this fills contiguously from 0. + # Sparse so the gather can record scattered completions (e.g. + # indices [0, 2, 5] when 1/3/4 failed) and a subsequent + # ``resume()`` only re-issues the missing indices. self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: @@ -1501,10 +1444,9 @@ def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: Record a completed sub-request's ``(frame, response)`` pair under its sub-args index. - The single writer of ``self._chunks`` — used by both the serial - loop in :meth:`resume` and the parallel fan-out in - :meth:`resume_async` — so ``dict`` insertion order is completion - order (see :meth:`_responses_by_completion`). + The single writer of ``self._chunks`` — used by the gather in + :meth:`_run` — so ``dict`` insertion order is completion order + (see :meth:`_combine_raw`). Parameters ---------- @@ -1588,9 +1530,8 @@ def combined(self) -> tuple[pd.DataFrame, Any]: """ Combine every recorded sub-request and apply :attr:`finalize`. - The terminal *success* result: :meth:`resume` and - :meth:`resume_async` both return this, so a completed call (whether - serial or parallel, first run or resume) yields the same shape + The terminal *success* result: :meth:`_run` returns this, so a + completed call (first run or resume) yields the same shape ``finalize`` produces — a raw ``(frame, httpx.Response)`` by default, or the OGC getters' type-coerced / column-arranged frame plus ``BaseMetadata``. The ``partial_*`` accessors deliberately do @@ -1655,9 +1596,8 @@ def _pending(self) -> Iterator[tuple[int, dict[str, Any]]]: The single source of the "walk :meth:`ChunkPlan.iter_sub_args` in deterministic order, skip any index already in ``self._chunks``" - rule, shared by the serial :meth:`resume` and the parallel - :meth:`resume_async` so the two execution paths can't drift on - *which* sub-requests they still owe. + rule that :meth:`_run` uses to decide *which* sub-requests it + still owes (first run and every resume alike). Yields ------ @@ -1671,21 +1611,19 @@ def _pending(self) -> Iterator[tuple[int, dict[str, Any]]]: def resume(self) -> tuple[pd.DataFrame, Any]: """ - Drive the chunked call to completion via the sync ``fetch_once``. + Drive the chunked call to completion. Synchronous facade. - Opens one ``httpx.Client`` for the run and publishes it on - the ``_chunked_client`` ``ContextVar`` so paginated-loop - helpers downstream (``_walk_pages``) reuse the same connection - pool across every sub-request instead of handshaking fresh on - each. The client is closed when ``resume`` returns or raises; - a follow-up ``resume`` call (after a ``ChunkInterrupted``) - opens a new one. + Runs the async core :meth:`_run` through an ``anyio`` blocking + portal (a short-lived worker thread), so the synchronous public + API works whether or not the caller is already inside an event + loop (Jupyter / IPython / async apps) — no nested-``asyncio.run`` + error. The portal copies the calling context, so the active + progress reporter still reaches the fan-out. Idempotent: only sub-requests whose index isn't already in ``self._chunks`` are re-issued. Sub-args order matches :meth:`ChunkPlan.iter_sub_args` and is deterministic, so a - parallel-mode partial completion (sparse indices) resumes - correctly via the sync path. + partial completion (sparse indices) resumes correctly. Returns ------- @@ -1693,9 +1631,10 @@ def resume(self) -> tuple[pd.DataFrame, Any]: Combined data from every successful sub-request. response The finalized aggregate — a raw :class:`httpx.Response` - (canonical URL, last page's headers, cumulative elapsed time) - by default, or whatever :attr:`finalize` produces (e.g. - ``BaseMetadata`` for the OGC getters). + (canonical URL, most-recently-completed sub-request's headers, + cumulative elapsed time) by default, or whatever + :attr:`finalize` produces (e.g. ``BaseMetadata`` for the OGC + getters). Raises ------ @@ -1706,62 +1645,15 @@ def resume(self) -> tuple[pd.DataFrame, Any]: is on ``exc.call`` — wait for the underlying condition to clear and call ``exc.call.resume()`` again. """ - with httpx.Client(**HTTPX_DEFAULTS) as client: - with _publish(_chunked_client, client): - reporter = _progress.current() - if reporter is not None: - reporter.set_chunks(self.plan.total) - for index, sub_args in self._pending(): - # Serial progress semantics: announce the chunk we're - # *about to* fetch (1-based), so the line reads - # "chunk k/total" while that fetch + its pages are in - # flight. (The parallel path can't do this — chunks fire - # at once and finish out of order — so :meth:`resume_async` - # instead ticks the completed *count*; the two are - # deliberately different, not drift.) - if reporter is not None: - reporter.start_chunk(index + 1) - self._issue(index, sub_args) - return self.combined() + concurrency = _read_concurrency_env() + with start_blocking_portal() as portal: + return portal.call(functools.partial(self._run, concurrency)) - def _issue(self, index: int, sub_args: dict[str, Any]) -> None: - """ - Issue one sub-request and record its ``(frame, response)`` pair - under ``index``. - - On failure, classify the exception and either wrap it as a - resumable :class:`ChunkInterrupted` carrying this call, or - re-raise it unchanged to preserve its type. Catches - ``RuntimeError`` (the layer's typed contract: - :class:`RateLimited`, :class:`ServiceUnavailable`, or the - mid-pagination wrapper), :class:`httpx.HTTPError` - (transport-level failures like ``ConnectError`` / - ``TimeoutException``), and :class:`httpx.InvalidURL` (which - inherits directly from ``Exception``, not ``HTTPError``); all - three feed :func:`_classify_chunk_error`. - - Parameters - ---------- - index : int - The sub-args index this sub-request belongs to. - sub_args : dict - The substituted args dict for this sub-request. - """ - try: - chunk = _retry_sync(lambda: self.fetch_once(sub_args), self.retry_policy) - except (RuntimeError, httpx.HTTPError, httpx.InvalidURL) as exc: - interrupted = self.wrap_failure(exc) - if interrupted is None: - raise - raise interrupted from exc - self.record(index, chunk) - - async def resume_async( - self, fetch_async: _FetchOnceAsync, *, max_concurrent: int | None - ) -> tuple[pd.DataFrame, Any]: + async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: """ - Drive the chunked call to completion concurrently over one shared - :class:`httpx.AsyncClient`. Async sibling of :meth:`resume`. + The async execution core: gather every pending sub-request over + one shared :class:`httpx.AsyncClient` and return the combined, + finalized result. Pending sub-requests (:meth:`_pending`) fan out under ``asyncio.gather`` with ``return_exceptions=True`` so completed @@ -1769,23 +1661,22 @@ async def resume_async( transient (:class:`RateLimited`, :class:`ServiceUnavailable`) a :class:`ChunkInterrupted` subclass is raised carrying ``self`` on ``.call``; ``exc.call.resume()`` then re-issues only the unfinished - indices via the serial sync ``fetch_once`` path. The per-sub-request - bookkeeping (:meth:`_pending`, :meth:`record`, :meth:`wrap_failure`, - :meth:`combined`) is shared with :meth:`resume`, so the two execution - paths differ only in serial ``for`` vs concurrent ``gather``. + indices through this same runner. - In-flight sub-requests are capped by an :class:`asyncio.Semaphore`; - ``max_concurrent=None`` ("unbounded") uses ``sys.maxsize`` so every - call site takes the same ``async with semaphore`` path. The shared - client is published on :data:`_chunked_async_client` so async - paginated-loop helpers reuse its connection pool. + Concurrency is bounded purely by the client's connection pool — + ``httpx.Limits(max_connections=N, max_keepalive_connections=N)`` + where ``N = max_concurrent`` (``None`` for unbounded). There is no + semaphore: the gather dispatches *every* pending sub-request and the + pool throttles, so ``N=1`` is just a single-connection gather + (effectively serial) and ``total <= 1`` is just a one-element gather. + The shared client is published on :data:`_chunked_async_client` so + async paginated-loop helpers reuse its connection pool. Parameters ---------- - fetch_async : Callable - Async per-sub-request fetcher returning ``(df, response)``. max_concurrent : int or None - Maximum in-flight sub-requests. ``None`` disables the cap. + Maximum simultaneous connections (the pool cap). ``None`` + disables the cap. Returns ------- @@ -1802,25 +1693,19 @@ async def resume_async( ChunkInterrupted On a transient sub-request failure. ``.call`` is ``self``, holding the sparse completed sub-requests; ``.call.resume()`` - re-issues the unfinished ones serially. + re-issues the unfinished ones. """ - # ``httpx.Limits()`` defaults to ``max_connections=100`` — at - # higher concurrency the pool would silently bottleneck the - # fan-out behind the connection cap. Match it to the semaphore, - # or ``None`` for truly unbounded. + # ``httpx.Limits()`` defaults to ``max_connections=100`` — at higher + # concurrency the pool would silently bottleneck the fan-out behind + # that cap. Set it to the resolved concurrency so the pool *is* the + # throttle (``None`` for truly unbounded). No semaphore: we gather + # every pending sub-request and let the pool serialize. limits = httpx.Limits( max_connections=max_concurrent, max_keepalive_connections=max_concurrent ) - # ``None`` means "unbounded"; ``sys.maxsize`` stands in for it since - # ``asyncio.Semaphore`` only decrements a counter, never preallocates - # slots. Test ``is None`` explicitly so a stray ``0`` isn't silently - # promoted to unbounded by a falsy-``or``. - semaphore = asyncio.Semaphore( - sys.maxsize if max_concurrent is None else max_concurrent - ) async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: - with _publish(_chunked_async_client, client): + with _publish(client): reporter = _progress.current() if reporter is not None: reporter.set_chunks(self.plan.total) @@ -1828,25 +1713,19 @@ async def resume_async( async def track( index: int, args: dict[str, Any] ) -> tuple[pd.DataFrame, httpx.Response]: - """One sub-request (with retry) + record + progress tick. - - The retry loop runs *inside* the semaphore, so a chunk - backing off holds its slot — effective concurrency shrinks - under throttling instead of re-bursting against it. - """ - async with semaphore: - result = await _retry_async( - lambda: fetch_async(args), self.retry_policy - ) + """One sub-request (with retry) + record + progress tick.""" + result = await _retry_async( + lambda: self.fetch(args), self.retry_policy + ) self.record(index, result) if reporter is not None: - # Parallel progress semantics: chunks finish out of - # order, so tick the completed *count* rather than a - # positional index (see :meth:`resume`). + # Chunks finish out of order under gather, so tick the + # completed *count* rather than a positional index. reporter.start_chunk(self.completed_chunks) return result - # Dispatch every pending sub-request concurrently. + # Dispatch every pending sub-request concurrently; the + # connection pool (``limits``) is the only throttle. # ``return_exceptions`` keeps completed pairs after a sibling # fails, so partial state stays recoverable via :meth:`resume`. # Failure precedence, in order: @@ -1885,11 +1764,10 @@ async def track( def multi_value_chunked( *, build_request: Callable[..., httpx.Request], - fetch_async: _FetchOnceAsync | None = None, url_limit: int | None = None, -) -> Callable[[_FetchOnce], _FetchOnce]: +) -> Callable[[_Fetch], Callable[..., tuple[pd.DataFrame, Any]]]: """ - Decorate a fetch function to transparently chunk over-budget requests. + Decorate an async fetcher to transparently chunk over-budget requests. Splits multi-value list params and cql-text filters across sub-requests so each fits the URL byte limit. Builds a @@ -1897,13 +1775,18 @@ def multi_value_chunked( single-step plan, so the decorated function has one code path either way. - When ``API_USGS_CONCURRENT`` resolves to a parallelism greater than - 1 (the default), the decorator routes execution through - :meth:`ChunkedCall.resume_async` over the provided ``fetch_async``, run in an - ``anyio`` worker-thread portal so it works whether or not the caller - is already inside an event loop (Jupyter / IPython / async apps). It - falls back to the synchronous :class:`ChunkedCall` path (with a - ``UserWarning``) only when ``fetch_async`` wasn't wired. + The decorated function is an ``async def fetch(args) -> (df, + response)``; the *returned wrapper is synchronous*. The wrapper builds + the :class:`ChunkPlan`, constructs a :class:`ChunkedCall` over the + async fetcher, and drives it to completion via + :meth:`ChunkedCall.resume` — which runs the async core in an ``anyio`` + worker-thread portal so it works whether or not the caller is already + inside an event loop (Jupyter / IPython / async apps). Every pending + sub-request is gathered under one :class:`httpx.AsyncClient`; + concurrency is bounded purely by the connection pool, sized from + ``API_USGS_CONCURRENT``. ``API_USGS_CONCURRENT=1`` is just a + single-connection gather and ``plan.total <= 1`` a one-element gather + — neither is a special-cased path. Parameters ---------- @@ -1911,10 +1794,6 @@ def multi_value_chunked( Factory that turns a kwargs dict into a sized httpx request, e.g. ``_construct_api_requests``. Called during planning to measure each candidate plan. - fetch_async : Callable, optional - Async sibling of the decorated sync fetcher. Used when - ``API_USGS_CONCURRENT`` resolves to >1; if omitted, the - wrapper warns and stays on the serial path. url_limit : int, optional Byte budget for the request (URL + body). When ``None`` (default), the module-level ``_WATERDATA_URL_BYTE_LIMIT`` is @@ -1924,9 +1803,9 @@ def multi_value_chunked( Returns ------- Callable - A decorator that wraps a ``fetch_once(args) -> (df, response)`` - callable into one that accepts the same shape but executes the - underlying plan transparently. + A *synchronous* wrapper ``wrapper(args, *, finalize=...) -> + (df, response)`` that executes the underlying plan transparently + over the decorated async fetcher. Raises ------ @@ -1943,8 +1822,8 @@ def multi_value_chunked( ChunkedCall : Per-sub-request execution and resume semantics. """ - def decorator(fetch_once: _FetchOnce) -> _FetchOnce: - @functools.wraps(fetch_once) + def decorator(fetch: _Fetch) -> Callable[..., tuple[pd.DataFrame, Any]]: + @functools.wraps(fetch) def wrapper( args: dict[str, Any], *, @@ -1952,83 +1831,12 @@ def wrapper( ) -> tuple[pd.DataFrame, Any]: limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit plan = ChunkPlan(args, build_request, limit) - concurrency = _read_concurrency_env() retry_policy = RetryPolicy.from_env() - - # Trivial plans and explicit opt-outs stay on the sync - # path; ``_execute_in_parallel`` owns the rest of the - # serial/parallel decision (async wiring, running loop). - if plan.total <= 1 or concurrency == 1: - return plan.execute(fetch_once, retry_policy, finalize) - return _execute_in_parallel( - plan, fetch_once, fetch_async, concurrency, retry_policy, finalize - ) + # The connection-pool cap is resolved inside ``resume()`` from + # ``API_USGS_CONCURRENT``; ``1`` is a single-connection gather, + # ``total <= 1`` a one-element gather — no special branch. + return plan.execute(fetch, retry_policy, finalize) return wrapper return decorator - - -def _execute_in_parallel( - plan: ChunkPlan, - fetch_once: _FetchOnce, - fetch_async: _FetchOnceAsync | None, - concurrency: int | None, - retry_policy: RetryPolicy = _NO_RETRY, - finalize: _Finalize = _passthrough_result, -) -> tuple[pd.DataFrame, Any]: - """ - Run ``plan`` on the parallel async path. - - Falls back to the serial sync path (with a one-time - :class:`UserWarning`) only when ``fetch_async`` wasn't wired into the - decorator. Otherwise it drives :meth:`ChunkedCall.resume_async` in a short-lived - worker thread via an ``anyio`` blocking portal, so the fan-out runs - whether or not the caller is already inside an event loop (Jupyter / - IPython / async apps) — no nested-``asyncio.run`` error and no silent - degradation to serial. The portal copies the calling context, so the - active progress reporter still reaches the fan-out. - - Parameters - ---------- - plan : ChunkPlan - The chunking plan to execute. - fetch_once : Callable - Sync per-sub-request fetcher returning ``(df, response)``, used - on the serial fallback path. - fetch_async : Callable or None - Async per-sub-request fetcher returning ``(df, response)``. When - ``None``, the call falls back to the serial sync path with a - :class:`UserWarning`. - concurrency : int or None - Maximum in-flight sub-requests. ``None`` disables the cap. - retry_policy : RetryPolicy, optional - Per-sub-request retry-with-backoff policy. Defaults to - :data:`_NO_RETRY`. - finalize : Callable, optional - Transform applied to the combined ``(frame, response)`` (see - :data:`_Finalize`). Defaults to :func:`_passthrough_result`. - - Returns - ------- - tuple of (pandas.DataFrame, finalized response) - The combined frame and the finalized aggregate response that - ``finalize`` produces. - """ - if fetch_async is None: - warnings.warn( - f"{_CONCURRENCY_ENV} is set to {concurrency} but this " - f"call site has no async fetch sibling wired; falling " - f"back to the serial path. Either set " - f"{_CONCURRENCY_ENV}=1 to silence this warning or pass " - f"fetch_async= to @multi_value_chunked.", - UserWarning, - stacklevel=3, - ) - return plan.execute(fetch_once, retry_policy, finalize) - call = ChunkedCall(plan, fetch_once, retry_policy, finalize) - fan_out = functools.partial( - call.resume_async, fetch_async, max_concurrent=concurrency - ) - with start_blocking_portal() as portal: - return portal.call(fan_out) diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index c22c2cff..8fabca34 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -23,6 +23,7 @@ import httpx import pandas as pd +from anyio.from_thread import start_blocking_portal from dataretrieval import __version__ from dataretrieval.utils import HTTPX_DEFAULTS, BaseMetadata @@ -33,7 +34,6 @@ ServiceUnavailable, _safe_elapsed, get_active_async_client, - get_active_client, ) from dataretrieval.waterdata.types import ( PROFILE_LOOKUP, @@ -529,8 +529,9 @@ def _paginated_failure_message(pages_collected: int, cause: BaseException) -> st Returns ------- str - A message suitable for the ``RuntimeError`` that ``_walk_pages`` - and ``get_stats_data`` raise from the original exception. + A message suitable for the ``RuntimeError`` that + ``_walk_pages_async`` and ``get_stats_data`` raise from the + original exception. """ cause_str = str(cause).removesuffix(".") # Some ``httpx`` exceptions (e.g. ``TimeoutException()`` with no args) @@ -773,7 +774,7 @@ def _get_resp_data( # ``features`` is a real schema-drift shape (mirrors the guard in # ``_handle_stats_nesting``). Treat as empty rather than crash with # ``KeyError`` — the wrapped failure would otherwise look like a - # transient transport error to ``_paginate``'s exception handler. + # transient transport error to ``_paginate_async``'s exception handler. features = body.get("features") or [] if not features: return gpd.GeoDataFrame() if geopd else pd.DataFrame() @@ -804,56 +805,35 @@ def _get_resp_data( return df -@contextmanager -def _client_for(client: httpx.Client | None) -> Iterator[httpx.Client]: +@asynccontextmanager +async def _client_for_async( + client: httpx.AsyncClient | None, +) -> AsyncIterator[httpx.AsyncClient]: """ - Yield a usable client, picking the best available source. + Yield a usable async client, picking the best available source. Resolution order: 1. ``client`` if the caller supplied one (borrowed; not closed here — the caller owns its lifecycle). - 2. The chunker's shared client if we're inside a - ``ChunkedCall.resume()`` block (per - :func:`chunking.get_active_client`). Borrowed; - ``ChunkedCall.resume`` closes it on exit. - 3. A fresh short-lived ``httpx.Client`` opened here and closed + 2. The chunker's shared async client if we're inside a + :class:`~dataretrieval.waterdata.chunking.ChunkedCall` run (per + :func:`chunking.get_active_async_client`). Borrowed; the chunker + closes it on exit. + 3. A fresh short-lived ``httpx.AsyncClient`` opened here and closed on context exit. Parameters ---------- - client : httpx.Client or None + client : httpx.AsyncClient or None A caller-owned client to borrow, or ``None`` to defer to the chunker's shared client or a temporary one. Yields ------ - httpx.Client + httpx.AsyncClient The chosen client. """ - if client is not None: - yield client - return - shared = get_active_client() - if shared is not None: - yield shared - return - with httpx.Client(**HTTPX_DEFAULTS) as new: - yield new - - -@asynccontextmanager -async def _client_for_async( - client: httpx.AsyncClient | None, -) -> AsyncIterator[httpx.AsyncClient]: - """ - Yield a usable async client, picking the best available source. - Async sibling of :func:`_client_for`. - - Resolution order matches the sync version: explicit caller-owned - ``AsyncClient`` first, the chunker's shared async client next, a - fresh short-lived ``AsyncClient`` last. - """ if client is not None: yield client return @@ -910,17 +890,16 @@ def _aggregate_paginated_response( # Optional cap on the total rows a single paginated call accumulates before it # stops following ``next`` links. ``None`` (the default the data getters use) # means "no cap — fetch the whole series". Set via :func:`_row_cap` so the deep -# ``_paginate`` loop can honor it without threading the value through the +# ``_paginate_async`` loop can honor it without threading the value through the # generic chunker; this mirrors the ``_progress`` ambient-reporter pattern. _row_cap_var: ContextVar[int | None] = ContextVar("waterdata_row_cap", default=None) @contextmanager def _row_cap(max_rows: int | None) -> Iterator[None]: - """Cap the rows any :func:`_paginate` / :func:`_paginate_async` under this - context will accumulate (``None`` = uncapped). Used by - :func:`get_reference_table` to preview large tables without downloading - every page.""" + """Cap the rows any :func:`_paginate_async` under this context will + accumulate (``None`` = uncapped). Used by :func:`get_reference_table` + to preview large tables without downloading every page.""" token = _row_cap_var.set(max_rows) try: yield @@ -928,22 +907,26 @@ def _row_cap(max_rows: int | None) -> Iterator[None]: _row_cap_var.reset(token) -def _paginate( +async def _paginate_async( initial_req: httpx.Request, *, parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], - follow_up: Callable[[_Cursor, httpx.Client], httpx.Response], - client: httpx.Client | None = None, + follow_up: Callable[[_Cursor, httpx.AsyncClient], Awaitable[httpx.Response]], + client: httpx.AsyncClient | None = None, ) -> tuple[pd.DataFrame, httpx.Response]: """ - Drive a paginated request to completion. - - Common shape behind :func:`_walk_pages` and :func:`get_stats_data`: - send the initial request, then loop calling ``follow_up`` until - ``parse_response`` reports a ``None`` cursor, accumulating frames - and elapsed time. Any mid-pagination failure raises - ``RuntimeError`` wrapping the cause — the API exposes no resume - cursor, so the caller's only recovery is to retry the whole call. + Drive a paginated request to completion over an + :class:`httpx.AsyncClient`. + + The common shape behind :func:`_walk_pages_async` and + :func:`get_stats_data`: send the initial request, then loop calling + ``follow_up`` until ``parse_response`` reports a ``None`` cursor, + accumulating frames and elapsed time. Any mid-pagination failure + raises ``RuntimeError`` wrapping the cause — the API exposes no + resume cursor, so the caller's only recovery is to retry the whole + call. Issuing HTTP asynchronously lets the multiple sub-requests of a + chunked call run concurrently under + :meth:`~dataretrieval.waterdata.chunking.ChunkedCall._run`. Parameters ---------- @@ -954,9 +937,9 @@ def _paginate( DataFrame and the cursor (URL, token, …) used to drive ``follow_up`` for the next page; ``None`` terminates the loop. follow_up : callable - ``(cursor, client) -> httpx.Response``. Builds and sends - the next-page request. - client : httpx.Client, optional + ``(cursor, client) -> Awaitable[httpx.Response]``. Builds and + sends the next-page request. + client : httpx.AsyncClient, optional Caller-borrowed client. ``None`` (default) means use the chunker's shared client (if inside a chunked call) or open a temporary one. @@ -967,10 +950,10 @@ def _paginate( Concatenation of every page's parsed frame. response : httpx.Response A shallow copy of the first-page response, with ``.headers`` - rebuilt as a fresh ``httpx.Headers`` reflecting the last - page and ``.elapsed`` set to cumulative wall-clock. The - canonical URL is preserved from the first page. The original - first-page response is not mutated. + rebuilt as a fresh ``httpx.Headers`` reflecting the last page and + ``.elapsed`` set to cumulative wall-clock. The canonical URL is + preserved from the first page. The original first-page response + is not mutated. Raises ------ @@ -991,87 +974,6 @@ def _paginate( """ logger.debug("Requesting: %s", initial_req.url) reporter = _progress.current() - with _client_for(client) as client: - resp = client.send(initial_req) - _raise_for_non_200(resp) - # Keep the original-request response as the "canonical" one for - # ``md.url`` reproducibility; ``.headers`` and ``.elapsed`` get - # overwritten with latest/cumulative values below. - initial_response = resp - total_elapsed = _safe_elapsed(resp) - - try: - df, cursor = parse_response(resp) - except Exception as e: # noqa: BLE001 - # Initial-page parse failures (malformed JSON, missing - # ``features``, schema drift) get the same wrapped-message - # treatment as follow-up failures so callers see a - # consistent diagnostic regardless of which page broke. - logger.warning("Initial response parse failed.") - raise RuntimeError(_paginated_failure_message(0, e)) from e - dfs = [df] - # Stop following ``next`` links once the optional row cap is reached - # (see :func:`_row_cap`); ``None`` means uncapped. The concatenation is - # sliced to the cap below so a final over-budget page can't exceed it. - cap = _row_cap_var.get() - nrows = len(df) - if reporter is not None: - reporter.set_rate_remaining( - resp.headers.get(_QUOTA_HEADER), - limit=resp.headers.get("x-ratelimit-limit"), - ) - reporter.add_page(rows=len(df)) - while cursor is not None and (cap is None or nrows < cap): - try: - resp = follow_up(cursor, client) - _raise_for_non_200(resp) - df, cursor = parse_response(resp) - dfs.append(df) - nrows += len(df) - total_elapsed += _safe_elapsed(resp) - if reporter is not None: - reporter.set_rate_remaining( - resp.headers.get(_QUOTA_HEADER), - limit=resp.headers.get("x-ratelimit-limit"), - ) - reporter.add_page(rows=len(df)) - except Exception as e: # noqa: BLE001 - logger.warning( - "Request failed at cursor %r. Data download interrupted.", - cursor, - ) - raise RuntimeError(_paginated_failure_message(len(dfs), e)) from e - - # Aggregate headers / elapsed onto a COPY of the initial - # response so the user's caller never sees an in-place - # mutation of the response object they may have inspected - # mid-pagination via a hook or test fixture. - final_response = _aggregate_paginated_response( - initial_response, resp, total_elapsed - ) - result = pd.concat(dfs, ignore_index=True) - if cap is not None: - result = result.head(cap) - return result, final_response - - -async def _paginate_async( - initial_req: httpx.Request, - *, - parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], - follow_up: Callable[[_Cursor, httpx.AsyncClient], Awaitable[httpx.Response]], - client: httpx.AsyncClient | None = None, -) -> tuple[pd.DataFrame, httpx.Response]: - """ - Drive a paginated request to completion over an - :class:`httpx.AsyncClient`. Async sibling of :func:`_paginate`. - - Runs the same per-page loop but issues HTTP asynchronously so - multiple sub-requests of a chunked call can run concurrently from - :meth:`~dataretrieval.waterdata.chunking.ChunkedCall.resume_async`. - """ - logger.debug("Requesting: %s", initial_req.url) - reporter = _progress.current() async with _client_for_async(client) as sess: resp = await sess.send(initial_req) _raise_for_non_200(resp) @@ -1081,17 +983,16 @@ async def _paginate_async( try: df, cursor = parse_response(resp) except Exception as e: # noqa: BLE001 - # Mirror the sync path: initial-page parse failures - # (malformed JSON, missing ``features``, schema drift) - # get the same wrapped-message treatment as follow-up - # failures so callers see a consistent diagnostic - # regardless of which page broke. + # Initial-page parse failures (malformed JSON, missing + # ``features``, schema drift) get the same wrapped-message + # treatment as follow-up failures so callers see a consistent + # diagnostic regardless of which page broke. logger.warning("Initial response parse failed.") raise RuntimeError(_paginated_failure_message(0, e)) from e dfs = [df] # Stop following ``next`` links once the optional row cap is reached - # (see :func:`_row_cap`); ``None`` means uncapped. Mirrors the sync - # :func:`_paginate`; the concatenation is sliced to the cap below. + # (see :func:`_row_cap`); ``None`` means uncapped. The concatenation + # is sliced to the cap below so a final over-budget page can't exceed it. cap = _row_cap_var.get() nrows = len(df) if reporter is not None: @@ -1139,11 +1040,10 @@ def _ogc_parse_response( ) -> tuple[pd.DataFrame, str | None]: """Parse one OGC API page: extract the DataFrame and the next-page URL. - Shared between :func:`_walk_pages` and :func:`_walk_pages_async` - since the parse step is identical on either path. Coerces falsy - cursors (empty href, etc.) to ``None`` so the paginate loop's - ``while cursor is not None`` terminates instead of spinning on a - meaningless value. + The parse strategy :func:`_walk_pages_async` hands to + :func:`_paginate_async`. Coerces falsy cursors (empty href, etc.) to + ``None`` so the paginate loop's ``while cursor is not None`` + terminates instead of spinning on a meaningless value. """ body = resp.json() return ( @@ -1152,19 +1052,19 @@ def _ogc_parse_response( ) -def _walk_pages( +async def _walk_pages_async( geopd: bool, req: httpx.Request, - client: httpx.Client | None = None, + client: httpx.AsyncClient | None = None, ) -> tuple[pd.DataFrame, httpx.Response]: """ - Iterate through paginated OGC API responses and aggregate into one - DataFrame. + Iterate paginated OGC API responses asynchronously and aggregate + them into one DataFrame. - Thin wrapper that hands off to :func:`_paginate` with OGC-specific - strategies: pages are parsed via :func:`_get_resp_data` and the - next-page cursor is the URL from the response's ``links`` array - (per :func:`_next_req_url`). + Thin wrapper that hands off to :func:`_paginate_async` with + OGC-specific strategies: pages are parsed via :func:`_get_resp_data` + (through :func:`_ogc_parse_response`) and the next-page cursor is the + URL from the response's ``links`` array (per :func:`_next_req_url`). Parameters ---------- @@ -1172,9 +1072,9 @@ def _walk_pages( Whether geopandas is installed (drives geometry handling). req : httpx.Request The initial HTTP request to send. - client : httpx.Client, optional + client : httpx.AsyncClient, optional Caller-borrowed client; ``None`` defers client management to - :func:`_paginate`. + :func:`_paginate_async`. Returns ------- @@ -1188,39 +1088,14 @@ def _walk_pages( Raises ------ RuntimeError - See :func:`_paginate`. + See :func:`_paginate_async`. httpx.HTTPError - See :func:`_paginate`. + See :func:`_paginate_async`. """ method = req.method # ``httpx.Request.method`` is already upper-cased. headers = req.headers content = req.content if method == "POST" else None - def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: - return client.request(method, cursor, headers=headers, content=content) - - return _paginate( - req, - parse_response=functools.partial(_ogc_parse_response, geopd=geopd), - follow_up=follow_up, - client=client, - ) - - -async def _walk_pages_async( - geopd: bool, - req: httpx.Request, - client: httpx.AsyncClient | None = None, -) -> tuple[pd.DataFrame, httpx.Response]: - """ - Iterate paginated OGC API responses asynchronously and aggregate - them into one DataFrame. Async sibling of :func:`_walk_pages`; - delegates to :func:`_paginate_async`. - """ - method = req.method - headers = req.headers - content = req.content if method == "POST" else None - async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: return await sess.request(method, cursor, headers=headers, content=content) @@ -1423,7 +1298,7 @@ def _finalize_ogc( ``max_rows`` is applied here (after dedup/sort, on the *combined* frame) rather than only per-sub-request, so a chunked call's total is bounded to exactly ``max_rows`` and a resumed call honors the cap too — the - per-``_paginate`` ``_row_cap`` is only an early-stop download bound. + per-``_paginate_async`` ``_row_cap`` is only an early-stop download bound. """ frame = _deal_with_empty(frame, properties, service) if convert_type: @@ -1521,37 +1396,26 @@ def get_ogc_data( return _fetch_once(args, finalize=finalize) -async def _fetch_once_async( +@chunking.multi_value_chunked(build_request=_construct_api_requests) +async def _fetch_once( args: dict[str, Any], ) -> tuple[pd.DataFrame, httpx.Response]: """Send one prepared-args OGC request asynchronously; return the - frame + response. Async sibling of :func:`_fetch_once` used by the - parallel chunker.""" - req = _construct_api_requests(**args) - return await _walk_pages_async(geopd=GEOPANDAS, req=req) - - -@chunking.multi_value_chunked( - build_request=_construct_api_requests, - fetch_async=_fetch_once_async, -) -def _fetch_once( - args: dict[str, Any], -) -> tuple[pd.DataFrame, httpx.Response]: - """Send one prepared-args OGC request; return the frame + response. + frame + response. ``@chunking.multi_value_chunked`` models every multi-value list parameter and the cql-text filter as a chunkable axis, greedy-halves the biggest chunk across all axes until each sub-request URL fits, and iterates the cartesian product. With no chunkable inputs the - decorator passes args through unchanged. When ``API_USGS_CONCURRENT`` - is >1 (the default), the decorator routes execution through - :func:`_fetch_once_async` so the sub-requests run concurrently under - one shared :class:`httpx.AsyncClient`. Either way the return shape - is ``(frame, response)``. + decorator passes args through unchanged. The decorator gathers every + sub-request over one shared :class:`httpx.AsyncClient` (concurrency + bounded by the connection pool, sized from ``API_USGS_CONCURRENT``) + and returns a *synchronous* wrapper, so ``get_ogc_data`` keeps calling + ``_fetch_once(args, finalize=...)`` synchronously. The return shape is + ``(frame, response)``. """ req = _construct_api_requests(**args) - return _walk_pages(geopd=GEOPANDAS, req=req) + return await _walk_pages_async(geopd=GEOPANDAS, req=req) def _handle_stats_nesting( @@ -1714,7 +1578,7 @@ def get_stats_data( args: dict[str, Any], service: str, expand_percentiles: bool, - client: httpx.Client | None = None, + client: httpx.AsyncClient | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """ Retrieves statistical data from a specified endpoint and returns it @@ -1724,6 +1588,15 @@ def get_stats_data( handles pagination, processes results, and formats output according to the specified parameters. + Synchronous facade: the stats path doesn't go through + ``multi_value_chunked`` (its query shape has no chunkable list axes), + so it drives :func:`_paginate_async` directly through an ``anyio`` + blocking portal — the same async-only core the chunked getters use — + while keeping a synchronous signature and ``(df, BaseMetadata)`` + return. The portal runs the pagination loop in a short-lived worker + thread, so this works whether or not the caller is already inside an + event loop. + Parameters ---------- args : Dict[str, Any] @@ -1736,6 +1609,9 @@ def get_stats_data( each percentile gets its own row in the returned dataframe. If True and user requests a computation_type other than percentiles, a percentile column is still returned. + client : httpx.AsyncClient, optional + Caller-borrowed async client. ``None`` (default) opens a + temporary one inside the portal. Primarily a test seam. Returns ------- @@ -1757,28 +1633,34 @@ def get_stats_data( def parse_response(resp: httpx.Response) -> tuple[pd.DataFrame, str | None]: body = resp.json() - # Coerce falsy cursors ("", 0) to None so _paginate terminates. + # Coerce falsy cursors ("", 0) to None so _paginate_async terminates. # USGS uses "next": null at end-of-stream, but defensive coerce # protects against any "" sentinel a future schema might use. return _handle_stats_nesting(body, geopd=GEOPANDAS), body.get("next") or None - def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: + async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: # Build a fresh params dict per page so the caller's ``args`` # is never mutated. - return client.request( + return await sess.request( method, url=url, params={**args, "next_token": cursor}, headers=headers ) - # The stats path doesn't go through ``multi_value_chunked``, so it opens - # its own progress context; ``_paginate`` reports pages/rate-limit into it. - with _progress.progress_context(service=service): - df, response = _paginate( + async def _run() -> tuple[pd.DataFrame, httpx.Response]: + return await _paginate_async( req, parse_response=parse_response, follow_up=follow_up, client=client, ) + # The stats path opens its own progress context (it doesn't go through + # ``multi_value_chunked``); ``_paginate_async`` reports pages/rate-limit + # into it. The portal copies the calling context, so the reporter still + # reaches the worker thread. + with _progress.progress_context(service=service): + with start_blocking_portal() as portal: + df, response = portal.call(_run) + if expand_percentiles: df = _expand_percentiles(df) return df, BaseMetadata(response) diff --git a/tests/conftest.py b/tests/conftest.py index 5eb46cb8..93baeb1a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,10 +6,13 @@ ``requests-mock``-style permissiveness the test code was written against, and keeps mocked-URL setup terse). * Pins ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` for every - test by default so the historical mocked suite stays on the - deterministic serial chunker path and a single transient surfaces - immediately (no backoff). Async-mode and retry tests opt in by - re-setting the env vars inside their body via ``monkeypatch.setenv``. + test by default. The chunker core is async-only, so + ``API_USGS_CONCURRENT=1`` now means a single pooled connection (a + one-connection ``asyncio.gather``) rather than a separate serial code + path — deterministic enough for the mocked suite while a single + transient surfaces immediately (no backoff). Async-fan-out and retry + tests opt in by re-setting the env vars inside their body via + ``monkeypatch.setenv``. """ from __future__ import annotations @@ -39,17 +42,19 @@ def non_mocked_hosts() -> list[str]: @pytest.fixture(autouse=True) def _serial_chunker(monkeypatch): - """Default every test to the serial, no-retry chunker path. - - Production defaults ``API_USGS_CONCURRENT`` to 16 (parallel - fan-out) and ``API_USGS_RETRIES`` to 4, but the historical tests - assume sequential, deterministic sub-request ordering — and they - mock the sync ``_walk_pages`` rather than the async sibling, and - expect a single transient to surface immediately rather than be - retried. Pinning ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` - keeps the test surface focused on the planner / fetch contracts; - async-mode and retry tests opt in by overriding the env inside - their body. + """Default every test to the single-connection, no-retry chunker path. + + Production defaults ``API_USGS_CONCURRENT`` to 16 (a wide pooled + fan-out) and ``API_USGS_RETRIES`` to 4. The chunker core is async-only + now — there is no separate serial path — so ``API_USGS_CONCURRENT=1`` + means a single pooled connection (a one-connection ``asyncio.gather``), + which keeps sub-request dispatch deterministic enough for the mocked + suite. ``API_USGS_RETRIES=0`` makes a single transient surface + immediately rather than be retried. The mocked tests drive the async + ``_walk_pages_async`` (via ``asyncio.run`` / an ``AsyncMock`` client), + not a sync sibling. Pinning both keeps the test surface focused on the + planner / fetch contracts; async-fan-out and retry tests opt in by + overriding the env inside their body. """ monkeypatch.setenv("API_USGS_CONCURRENT", "1") monkeypatch.setenv("API_USGS_RETRIES", "0") diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 612ae154..6e616c1c 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -42,16 +42,25 @@ RetryPolicy, ServiceInterrupted, ServiceUnavailable, - _chunked_client, + _chunked_async_client, _extract_axes, _retry_async, - _retry_sync, _retryable, multi_value_chunked, ) from dataretrieval.waterdata.utils import _construct_api_requests +def _aiozero(_d): + """An async no-op sleep — monkeypatched over ``chunking._ASLEEP`` so + retry backoff doesn't actually wait in tests.""" + + async def _noop(): + return None + + return _noop() + + class _FakeReq: """Stand-in for ``httpx.Request`` whose ``_request_bytes`` shape is ``len(str(url)) + len(content)``.""" @@ -235,7 +244,7 @@ def test_multi_value_chunked_passes_through_when_url_fits(): calls = [] @multi_value_chunked(build_request=_fake_build, url_limit=8000) - def fetch(args): + async def fetch(args): calls.append(args) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} @@ -254,7 +263,7 @@ def test_multi_value_chunked_emits_3d_cartesian_product(): calls = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): calls.append(tuple(tuple(args[k]) for k in ("sites", "pcodes", "stats"))) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} @@ -303,7 +312,7 @@ def test_multi_value_chunked_lazy_url_limit(monkeypatch): calls = [] @multi_value_chunked(build_request=_fake_build) # url_limit defaults to None - def fetch(args): + async def fetch(args): calls.append(args) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} @@ -317,20 +326,20 @@ def fetch(args): def test_chunked_session_shared_across_sub_requests(): """Every sub-request of one chunked call sees the same - ``httpx.Client`` on the ``_chunked_client`` ContextVar, so - downstream paginated helpers (``_walk_pages``) can reuse the + ``httpx.AsyncClient`` on the ``_chunked_async_client`` ContextVar, so + downstream paginated helpers (``_walk_pages_async``) can reuse the connection pool instead of handshaking fresh on each sub-request.""" sessions_seen = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): - sessions_seen.append(_chunked_client.get()) + async def fetch(args): + sessions_seen.append(_chunked_async_client.get()) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} ) - # Outside a chunked call: no session published. - assert _chunked_client.get() is None + # Outside a chunked call: no session published (in this thread/context). + assert _chunked_async_client.get() is None fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) @@ -341,19 +350,22 @@ def fetch(args): assert all(s is not None for s in sessions_seen) # And it was the same object every time. assert len({id(s) for s in sessions_seen}) == 1 - # On exit the ContextVar is reset to its default. - assert _chunked_client.get() is None + # The portal's worker context is torn down on exit, so the calling + # thread's ContextVar still reads its default. + assert _chunked_async_client.get() is None def test_chunked_session_isolated_per_resume(): """A follow-up ``resume`` after an interruption opens a fresh session — the previous one was closed when its ``resume`` returned. - The ContextVar is reset between calls so leakage can't carry + The ContextVar is reset between runs so leakage can't carry a closed session into the retry.""" state = {"i": 0, "blow_up": True} + sessions_seen = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): + sessions_seen.append(_chunked_async_client.get()) i = state["i"] state["i"] += 1 if i == 1 and state["blow_up"]: @@ -369,13 +381,23 @@ def fetch(args): with pytest.raises(QuotaExhausted) as excinfo: fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - # First resume's session is closed; ContextVar is reset. - assert _chunked_client.get() is None + # First run published a shared client to its sub-requests; the calling + # thread's ContextVar is unaffected (reads its default). + assert _chunked_async_client.get() is None + first_run_sessions = list(sessions_seen) + assert first_run_sessions and all(s is not None for s in first_run_sessions) state["blow_up"] = False excinfo.value.call.resume() - # Second resume's session is also cleaned up. - assert _chunked_client.get() is None + # Second run's ContextVar is also reset in the calling thread. + assert _chunked_async_client.get() is None + # The resume opened a FRESH client, distinct from the first run's, so no + # closed client leaks across runs. + resume_sessions = sessions_seen[len(first_run_sessions) :] + assert resume_sessions and all(s is not None for s in resume_sessions) + assert {id(s) for s in resume_sessions}.isdisjoint( + {id(s) for s in first_run_sessions} + ) def _quota_response(remaining: int | str | None) -> mock.Mock: @@ -392,7 +414,7 @@ def test_quota_exhausted_on_mid_call_429(): offset so callers can resume after the window resets.""" state = {"i": 0} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2: @@ -415,10 +437,12 @@ def fetch(args): decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) err = excinfo.value - assert err.completed_chunks == 2 # chunks 0 and 1 completed; 429 hit on i=2 + # Async fan-out: every non-failing sub-request completes (the gather + # runs all of them; only i==2 raises), so 4 of 5 complete. + assert err.completed_chunks == 4 # only the i==2 sub-request failed assert err.total_chunks == 5 assert err.partial_frame is not None - assert set(err.partial_frame["i"]) == {0, 1} + assert set(err.partial_frame["i"]) == {0, 1, 3, 4} def test_quota_exhausted_on_first_chunk_429_has_no_partial_response(): @@ -427,7 +451,7 @@ def test_quota_exhausted_on_first_chunk_429_has_no_partial_response(): is empty) so callers can branch on that to distinguish "abort before any data arrived" from "abort after partial collection".""" - def fetch(args): + async def fetch(args): raise RateLimited("429: Too many requests made.") decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) @@ -444,14 +468,18 @@ def test_quota_exhausted_resume_picks_up_where_429_stopped(): once the window resets, ``e.call.resume()`` re-issues only the sub-requests that hadn't completed and returns the full combined result. Chunks completed before the 429 are not re-fetched.""" - # The fake fetch 429s on the third call, then succeeds on every - # subsequent call. We track which sub-args have been issued so we - # can assert chunks 0/1 aren't re-fetched on resume. + # One sub-request (the chunk containing the failing site) 429s on the + # first gather, then succeeds once the window resets. Under the async + # fan-out every OTHER sub-request completes on the first gather, so + # resume re-issues only the single still-pending chunk. We track which + # sub-args have been issued to assert the completed chunks aren't + # re-fetched. fetched_sites: list[tuple[str, ...]] = [] + failing_site = "S3" * 10 rate_limited_once = {"fired": False} - def fetch(args): - if len(fetched_sites) == 2 and not rate_limited_once["fired"]: + async def fetch(args): + if failing_site in args["sites"] and not rate_limited_once["fired"]: rate_limited_once["fired"] = True raise RateLimited("429: Too many requests made.") site_tuple = tuple(args["sites"]) @@ -462,23 +490,24 @@ def fetch(args): ) decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10] + sites = ["S1" * 10, "S2" * 10, failing_site, "S4" * 10, "S5" * 10] - # First attempt: 429 on the third sub-request. + # First attempt: 429 on the chunk carrying the failing site; the other + # four sub-requests complete. with pytest.raises(QuotaExhausted) as excinfo: decorated({"sites": sites}) err = excinfo.value - assert err.completed_chunks == 2 + assert err.completed_chunks == 4 pre_resume_count = len(fetched_sites) - assert pre_resume_count == 2 # chunks 0 and 1 completed + assert pre_resume_count == 4 # every chunk but the failing one completed - # Resume: re-issues only the still-pending sub-requests. + # Resume: re-issues only the still-pending sub-request. df, _ = err.call.resume() - # Three more fetches happened on resume (chunks 2, 3, 4); chunks 0 - # and 1 were not re-fetched. - assert len(fetched_sites) - pre_resume_count == 3, ( - f"expected 3 new fetches on resume (chunks 2, 3, 4); got " + # Exactly one more fetch happened on resume (the chunk that 429'd); + # the four already-completed chunks were not re-fetched. + assert len(fetched_sites) - pre_resume_count == 1, ( + f"expected 1 new fetch on resume (the failing chunk); got " f"{len(fetched_sites) - pre_resume_count}" ) # Every original site appears in the combined frame exactly once. @@ -490,33 +519,33 @@ def test_quota_exhausted_resume_can_reraise_on_persistent_429(): ``call.resume()`` raises ``QuotaExhausted`` again — the ``ChunkedCall``'s in-flight state carries forward, so a subsequent resume after a longer wait still picks up cleanly.""" - state = {"attempts": 0} - - def fetch(args): - i = state["attempts"] - state["attempts"] += 1 - # First attempt 429s on chunk 2. Resume attempt 429s on what - # would be chunk 2 again (still the first un-completed - # sub-request). - if i == 2 or i == 3: + # Key the failure on the chunk's CONTENT (one persistently-429ing + # site) rather than a global call counter: under the async fan-out + # every other sub-request completes, and the same still-pending + # sub-request re-fails on resume — so the completed count is stable. + failing_site = "S3" * 10 + + async def fetch(args): + if failing_site in args["sites"]: raise RateLimited("429: Too many requests made.") return ( - pd.DataFrame({"i": [i], "sites": list(args["sites"])}), + pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500), ) decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10] + sites = ["S1" * 10, "S2" * 10, failing_site, "S4" * 10, "S5" * 10] with pytest.raises(QuotaExhausted) as first: decorated({"sites": sites}) with pytest.raises(QuotaExhausted) as second: first.value.call.resume() - # Both exceptions report the same completed_chunks count — the - # second resume didn't make progress (it 429'd on the same chunk). - assert first.value.completed_chunks == 2 - assert second.value.completed_chunks == 2 + # Both exceptions report the same completed_chunks count — every + # sub-request but the persistently-429ing one completed on the first + # gather, and the resume re-issued only that one, which 429'd again. + assert first.value.completed_chunks == 4 + assert second.value.completed_chunks == 4 def test_resume_produces_dataset_identical_to_uninterrupted_run(): @@ -534,7 +563,7 @@ def make_fetch(rate_limit_at_call: int | None): keyed by the sub-args's sites.""" state = {"calls": 0, "tripped": False} - def fetch(args): + async def fetch(args): state["calls"] += 1 if state["calls"] == rate_limit_at_call and not state["tripped"]: state["tripped"] = True @@ -592,7 +621,7 @@ def test_chunker_passes_through_non_429_runtime_error(): it must propagate unchanged so callers see the real cause.""" state = {"i": 0} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2: @@ -615,7 +644,7 @@ def test_chunker_wraps_service_unavailable_as_resumable(): ``.call.resume()`` resumes only the still-pending sub-requests.""" state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: @@ -634,7 +663,9 @@ def fetch(args): err = excinfo.value # Resumable: handle on .call with already-completed work preserved. assert err.call is not None - assert err.completed_chunks == 2 + # Async fan-out: only the i==2 sub-request fails; the gather completes + # the other four, so 4 of 5 are recorded before the failure surfaces. + assert err.completed_chunks == 4 assert err.total_chunks == 5 assert not err.call.partial_frame.empty # Upstream recovers; resuming completes the call. @@ -667,7 +698,7 @@ def test_connection_error_wrapped_as_service_interrupted(): state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: @@ -682,7 +713,8 @@ def fetch(args): decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) err = excinfo.value - assert err.completed_chunks == 2 + # Async fan-out: only the i==2 sub-request fails; the other four complete. + assert err.completed_chunks == 4 assert err.call is not None # The transport exception is on __cause__ so callers can drill in if needed. assert isinstance(err.__cause__, _httpx.ConnectError) @@ -702,7 +734,7 @@ def test_invalid_url_wrapped_as_service_interrupted(): state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: @@ -717,7 +749,8 @@ def fetch(args): decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) err = excinfo.value - assert err.completed_chunks == 2 + # Async fan-out: only the i==2 sub-request fails; the other four complete. + assert err.completed_chunks == 4 assert err.call is not None assert isinstance(err.__cause__, _httpx.InvalidURL) # The top-level message must surface the underlying cause text so @@ -738,7 +771,7 @@ def test_service_interrupted_exposes_partial_frame_and_response(): crashed with AttributeError on 5xx.""" state = {"i": 0} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2: @@ -771,7 +804,7 @@ def test_partial_frame_snapshot_stable_across_resume(): a name that promises pre-resume state.""" state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: @@ -808,7 +841,7 @@ def test_partial_frame_snapshot_is_a_copy_when_single_chunk(): ``pd.concat`` (which already produces a fresh frame).""" state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 1 and state["blow_up"]: @@ -818,12 +851,13 @@ def fetch(args): _quota_response(500), ) - # 4 sites at url_limit=240 → 2 sub-requests. The 429 fires on the - # SECOND sub-request, so the exception captures exactly ONE - # completed chunk — the path where _combine_chunk_frames aliases. + # 2 sites at url_limit=240 → 2 singleton sub-requests. The 429 fires + # on the SECOND sub-request and the gather completes the other, so the + # exception captures exactly ONE completed chunk — the path where + # _combine_chunk_frames aliases its single non-empty frame. decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) with pytest.raises(QuotaExhausted) as excinfo: - decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + decorated({"sites": ["S1" * 10, "S2" * 10]}) err = excinfo.value assert err.completed_chunks == 1 @@ -887,7 +921,7 @@ def test_paginate_terminates_on_empty_string_cursor(): resp.headers = {} resp.json.return_value = body_with_empty_next - client = _mock.MagicMock(spec=_httpx.Client) + client = _mock.AsyncMock(spec=_httpx.AsyncClient) client.send.return_value = resp req = _mock.MagicMock(spec=_httpx.Request) @@ -896,7 +930,9 @@ def test_paginate_terminates_on_empty_string_cursor(): req.content = b"" req.url = "https://example.com/items?limit=1" - df, final = _utils._walk_pages(geopd=False, req=req, client=client) + df, final = asyncio.run( + _utils._walk_pages_async(geopd=False, req=req, client=client) + ) # Single send + zero follow-ups: the loop terminated on the empty cursor. assert client.send.called @@ -943,7 +979,7 @@ def test_retry_after_surfaces_on_quota_exhausted(): can honor the server's hint instead of guessing a wait.""" state = {"i": 0} - def fetch(args): + async def fetch(args): state["i"] += 1 if state["i"] >= 3: try: @@ -1050,7 +1086,7 @@ def test_multi_value_chunked_restores_canonical_url(): sub_urls: list[str] = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): # Each sub-response carries the chunked sub_args's URL, so # without canonical restoration the first chunk's URL would # leak through to md.url. @@ -1211,28 +1247,23 @@ def test_iter_sub_args_passthrough_yields_a_copy(): # --- async fan-out path ---------------------------------------------------- # -# The conftest's ``_serial_chunker`` autouse pins ``API_USGS_CONCURRENT=1`` -# for the whole suite. Each test below overrides it so the wrapper takes -# the parallel branch. The decorator's ``fetch_async`` accepts any -# coroutine returning ``(df, response)`` — no real ``httpx.AsyncClient`` -# round-trip occurs, even though :meth:`ChunkedCall.resume_async` opens one +# The chunker is async-only: every sub-request is gathered over one +# ``httpx.AsyncClient`` and concurrency is bounded purely by that client's +# connection pool, sized from ``API_USGS_CONCURRENT``. The conftest's +# ``_serial_chunker`` autouse pins ``API_USGS_CONCURRENT=1`` (a single +# connection) for the whole suite; each test below raises it so the gather +# can dispatch sub-requests under a wider pool. The decorated async fetcher +# is the SAME one used on both first-run and resume — there is no sync +# sibling anymore. No real ``httpx.AsyncClient`` round-trip occurs (the +# fakes return mock data), even though :meth:`ChunkedCall._run` opens one # for pool management. def _async_chunked_fetch(monkeypatch, fetch_async, *, max_concurrent=16): - """Decorate a deterministic chunkable fetch with the parallel - path forced on via ``API_USGS_CONCURRENT``.""" + """Decorate a deterministic chunkable async fetch with a wide-pool + gather forced on via ``API_USGS_CONCURRENT``.""" monkeypatch.setenv("API_USGS_CONCURRENT", str(max_concurrent)) - - @multi_value_chunked( - build_request=_fake_build, fetch_async=fetch_async, url_limit=240 - ) - def fetch(args): - # Sync sibling — invoked on resume() after a parallel failure - # and never during the happy parallel path. - return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() - - return fetch + return multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch_async) def _atom_id(args): @@ -1291,69 +1322,75 @@ def test_async_fan_out_failure_yields_resumable_call(monkeypatch): """A transient 5xx mid-fan-out raises ``ServiceInterrupted`` whose ``.call`` is a ``ChunkedCall`` holding the completed sub-requests in a sparse index map. ``exc.call.resume()`` re-issues only the - unfinished sub-requests, via the sync ``fetch_once`` path.""" - call_count = {"async": 0, "sync": 0} + unfinished sub-requests — through the same async fetcher and the same + async runner, just on a fresh gather.""" + # One async fetcher serves both first-run and resume. On the first + # gather it lets exactly one sub-request succeed and fails the rest + # transiently; once ``blow_up`` is cleared the resume gather completes + # every still-pending sub-request. ``calls`` counts every invocation + # across both gathers so we can assert resume only re-issued the owed + # sub-requests. + state = {"first_success": False, "blow_up": True} + calls = {"n": 0} async def fetch_async(args): - call_count["async"] += 1 - # First sub-request succeeds; siblings fail. - if call_count["async"] == 1: - return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) - raise ServiceUnavailable("503: simulated") - - monkeypatch.setenv("API_USGS_CONCURRENT", "16") - - @multi_value_chunked( - build_request=_fake_build, fetch_async=fetch_async, url_limit=240 - ) - def fetch(args): - call_count["sync"] += 1 + calls["n"] += 1 + if state["blow_up"]: + # Let the first dispatched sub-request through, fail the rest. + if not state["first_success"]: + state["first_success"] = True + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response( + remaining=99 + ) + raise ServiceUnavailable("503: simulated") return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + with pytest.raises(ServiceInterrupted) as exc_info: fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) interrupted = exc_info.value - assert interrupted.call is not None, "parallel-mode interruption must be resumable" - # First sub-request completed; the rest still owe. + assert interrupted.call is not None, "interruption must be resumable" + # Exactly one sub-request completed; the rest still owe. assert interrupted.completed_chunks == 1 assert interrupted.total_chunks > 1 - # Resume on the sync path picks up only the missing sub-requests. - sync_before = call_count["sync"] + # Resume re-issues only the missing sub-requests, via the same async + # runner the first run used. + state["blow_up"] = False + calls_before = calls["n"] df, _ = interrupted.call.resume() - sync_calls_on_resume = call_count["sync"] - sync_before - assert sync_calls_on_resume == interrupted.total_chunks - 1 + calls_on_resume = calls["n"] - calls_before + assert calls_on_resume == interrupted.total_chunks - 1 # Final frame unions every sub-args. assert len(df) == interrupted.total_chunks def test_async_fan_out_resume_applies_finalize(monkeypatch): - """The ``finalize`` injected for a PARALLEL call survives the interruption - (carried on the ``ChunkedCall`` through the anyio portal), so a serial - ``exc.call.resume()`` still returns the finalized shape — guarding the - parallel resume_async -> resume -> finalize path the serial-pinned finalize - test can't reach. Partials stay raw (no finalize in the exception ctor).""" + """The ``finalize`` injected for a wide-pool call survives the + interruption (carried on the ``ChunkedCall`` through the anyio portal), + so ``exc.call.resume()`` still returns the finalized shape — guarding + the run -> resume -> finalize path. Partials stay raw (no finalize in + the exception ctor).""" def finalize(frame, response): return frame.assign(finalized=True), ("MD", response) - call_count = {"async": 0} + state = {"first_success": False, "blow_up": True} async def fetch_async(args): - call_count["async"] += 1 - if call_count["async"] == 1: - return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) - raise ServiceUnavailable("503: simulated") - - monkeypatch.setenv("API_USGS_CONCURRENT", "16") - - @multi_value_chunked( - build_request=_fake_build, fetch_async=fetch_async, url_limit=240 - ) - def fetch(args): + if state["blow_up"]: + if not state["first_success"]: + state["first_success"] = True + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response( + remaining=99 + ) + raise ServiceUnavailable("503: simulated") return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] with pytest.raises(ServiceInterrupted) as exc_info: fetch({"sites": sites}, finalize=finalize) @@ -1361,29 +1398,35 @@ def fetch(args): # Partial snapshot stays raw — building the exception must not finalize. assert "finalized" not in exc_info.value.partial_frame.columns # Resume applies the finalize carried on the ChunkedCall. + state["blow_up"] = False df, md = exc_info.value.call.resume() assert "finalized" in df.columns assert md[0] == "MD" -def test_async_falls_back_to_serial_when_no_fetch_async(monkeypatch): - """With no ``fetch_async=`` sibling wired, a parallel - ``API_USGS_CONCURRENT`` can't be honored, so the call falls back to - the serial ``ChunkedCall`` path with a one-time ``UserWarning`` - rather than silently no-op'ing the env var.""" - sync_calls = [] +def test_wide_concurrency_uses_async_fetcher_with_no_warning(monkeypatch): + """Re-expresses the old "falls back to serial when no async sibling + wired" intent for the async-only core: there is no async sibling to + wire and no serial fallback, so a wide ``API_USGS_CONCURRENT`` is + honored directly by the single async fetcher — every sub-request runs + on it and NO ``UserWarning`` is emitted (the env var is never silently + no-op'd, the previous regression that warning guarded).""" + import warnings + + calls = [] monkeypatch.setenv("API_USGS_CONCURRENT", "16") @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): - sync_calls.append(tuple(args["sites"])) + async def fetch(args): + calls.append(tuple(args["sites"])) return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() - with pytest.warns(UserWarning, match="no async fetch sibling"): + with warnings.catch_warnings(): + warnings.simplefilter("error") # any UserWarning would fail the test df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - assert len(sync_calls) > 1 - assert len(df) == len(sync_calls) + assert len(calls) > 1 # the gather fanned out across every sub-request + assert len(df) == len(calls) def test_async_fan_out_runs_inside_running_event_loop(monkeypatch): @@ -1394,17 +1437,15 @@ def test_async_fan_out_runs_inside_running_event_loop(monkeypatch): monkeypatch.setenv("API_USGS_CONCURRENT", "16") async_calls = [] - async def fetch_async(args): + @multi_value_chunked(build_request=_fake_build, url_limit=240) + async def fetch(args): # the single async fetcher — there is no sync sibling async_calls.append(tuple(args["sites"])) return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() - @multi_value_chunked( - build_request=_fake_build, fetch_async=fetch_async, url_limit=240 - ) - def fetch(args): # sync sibling must NOT run — the async path handles it - raise AssertionError("serial fallback must not run inside a live loop") - async def driver(): # call the sync getter from within a running loop + # The sync wrapper drives the async core through the anyio portal in + # a worker thread, so it works even inside a running event loop — + # neither a nested-``asyncio.run`` error nor a silent serial fallback. return fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) df, _ = asyncio.run(driver()) @@ -1582,72 +1623,90 @@ def test_retryable_taxonomy(): def test_retryable_skips_wrapped_midpagination_transient(): # A transient surfaced mid-pagination is re-wrapped as RuntimeError by - # _paginate; it must NOT be auto-retried (re-walking from page 1 would - # re-spend quota) — it escalates to the resumable handle instead. Only - # the raw, top-level (initial-request) transient is retryable. + # _paginate_async; it must NOT be auto-retried (re-walking from page 1 + # would re-spend quota) — it escalates to the resumable handle instead. + # Only the raw, top-level (initial-request) transient is retryable. assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (False, None) assert _retryable(RateLimited("429", retry_after=3.0)) == (True, 3.0) -# -- sync driver ------------------------------------------------------------ +# -- async driver (the single retry driver; sync facade drives it) ---------- +# +# The chunker is async-only now, so the retry loop lives solely in +# ``_retry_async``. These tests pin the same behavioral contracts the old +# ``_retry_sync`` tests asserted (transient-then-success, exhausted-reraises, +# non-retryable-not-retried, long-retry-after-escalates), re-expressed on the +# async driver and run via ``asyncio.run``; the sleep is patched to a no-op so +# backoff doesn't actually wait. -def test_retry_sync_transient_then_success(monkeypatch): - monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) +def test_retry_async_transient_then_recovers(monkeypatch): + monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) calls = {"n": 0} - def fn(): + async def afn(): calls["n"] += 1 if calls["n"] <= 2: raise RateLimited("429") return "ok" - assert _retry_sync(fn, RetryPolicy(max_retries=3, base_backoff=0.0)) == "ok" + out = asyncio.run(_retry_async(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) + assert out == "ok" assert calls["n"] == 3 # two failures + one success -def test_retry_sync_exhausted_reraises(monkeypatch): - monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) +def test_retry_async_exhausted_reraises(monkeypatch): + monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) calls = {"n": 0} - def fn(): + async def afn(): calls["n"] += 1 raise ServiceUnavailable("503") with pytest.raises(ServiceUnavailable): - _retry_sync(fn, RetryPolicy(max_retries=2, base_backoff=0.0)) + asyncio.run(_retry_async(afn, RetryPolicy(max_retries=2, base_backoff=0.0))) assert calls["n"] == 3 # first attempt + 2 retries -def test_retry_sync_non_retryable_not_retried(monkeypatch): +def test_retry_async_non_retryable_not_retried(monkeypatch): slept: list[float] = [] - monkeypatch.setattr(_chunking, "_SLEEP", slept.append) + + def _record(delay): + slept.append(delay) + return _aiozero(delay) + + monkeypatch.setattr(_chunking, "_ASLEEP", _record) calls = {"n": 0} - def fn(): + async def afn(): calls["n"] += 1 raise RuntimeError("400: bad request") with pytest.raises(RuntimeError): - _retry_sync(fn, RetryPolicy(max_retries=3)) + asyncio.run(_retry_async(afn, RetryPolicy(max_retries=3))) assert calls["n"] == 1 and slept == [] -def test_retry_sync_long_retry_after_escalates(monkeypatch): +def test_retry_async_long_retry_after_escalates(monkeypatch): slept: list[float] = [] - monkeypatch.setattr(_chunking, "_SLEEP", slept.append) + + def _record(delay): + slept.append(delay) + return _aiozero(delay) + + monkeypatch.setattr(_chunking, "_ASLEEP", _record) calls = {"n": 0} - def fn(): + async def afn(): calls["n"] += 1 raise RateLimited("429", retry_after=999.0) with pytest.raises(RateLimited): - _retry_sync(fn, RetryPolicy(max_retries=5, retry_after_cap=60.0)) + asyncio.run(_retry_async(afn, RetryPolicy(max_retries=5, retry_after_cap=60.0))) assert calls["n"] == 1 and slept == [] # no inline wait for a long window -# -- async driver ----------------------------------------------------------- +# -- async driver (sleep-patched original) ---------------------------------- def test_retry_async_transient_then_success(monkeypatch): @@ -1674,10 +1733,10 @@ def test_chunker_retries_transient_then_completes(monkeypatch): """A transient on one sub-request is retried transparently; the decorated call completes with no ChunkInterrupted.""" monkeypatch.setenv("API_USGS_RETRIES", "3") - monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) state = {"failed": False} - def fetch(args): + async def fetch(args): # Fail the first sub-request once, then succeed everywhere. if not state["failed"]: state["failed"] = True @@ -1694,10 +1753,10 @@ def test_chunker_exhausted_retries_still_resumable(monkeypatch): """When retries are exhausted the failure still surfaces as a resumable ChunkInterrupted — retries don't swallow the escape hatch.""" monkeypatch.setenv("API_USGS_RETRIES", "2") - monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) attempts = {"n": 0} - def fetch(args): + async def fetch(args): sites = list(args["sites"]) if "S1" * 10 in sites: attempts["n"] += 1 @@ -1783,7 +1842,7 @@ def finalize(frame, response): state = {"n": 0} @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): state["n"] += 1 if state["n"] == 2: raise ServiceUnavailable("503: simulated") diff --git a/tests/waterdata_filters_test.py b/tests/waterdata_filters_test.py index 32879318..c0968d98 100644 --- a/tests/waterdata_filters_test.py +++ b/tests/waterdata_filters_test.py @@ -157,7 +157,7 @@ def test_long_filter_fans_out_into_multiple_requests(): expr = _filter_chunking_clauses() sent_filters: list[str] = [] - def fake_walk_pages(*, geopd, req): + async def fake_walk_pages_async(*, geopd, req): idx = len(sent_filters) sent_filters.append(_query_params(req).get("filter", [None])[0]) return pd.DataFrame({"id": [f"chunk-{idx}"], "value": [idx]}), _fake_response() @@ -168,7 +168,8 @@ def fake_walk_pages(*, geopd, req): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages + "dataretrieval.waterdata.utils._walk_pages_async", + side_effect=fake_walk_pages_async, ), ): df, _ = get_continuous( @@ -195,7 +196,7 @@ def test_long_filter_deduplicates_cross_chunk_overlap(): expr = _filter_chunking_clauses() call_count = {"n": 0} - def fake_walk_pages(*_args, **_kwargs): + async def fake_walk_pages_async(*_args, **_kwargs): call_count["n"] += 1 return ( pd.DataFrame({"id": ["shared-feature"], "value": [1]}), @@ -208,7 +209,8 @@ def fake_walk_pages(*_args, **_kwargs): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages + "dataretrieval.waterdata.utils._walk_pages_async", + side_effect=fake_walk_pages_async, ), ): df, _ = get_continuous( @@ -237,7 +239,7 @@ def test_empty_chunks_do_not_downgrade_geodataframe(): expr = _filter_chunking_clauses() call_count = {"n": 0} - def fake_walk_pages(*_args, **_kwargs): + async def fake_walk_pages_async(*_args, **_kwargs): call_count["n"] += 1 if call_count["n"] == 2: return pd.DataFrame(), _fake_response() @@ -256,7 +258,8 @@ def fake_walk_pages(*_args, **_kwargs): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages + "dataretrieval.waterdata.utils._walk_pages_async", + side_effect=fake_walk_pages_async, ), ): df, _ = get_continuous( @@ -289,10 +292,12 @@ def fake_construct_api_requests(**kwargs): side_effect=fake_construct_api_requests, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages", - return_value=( - pd.DataFrame({"id": ["row-1"], "value": [1]}), - _fake_response(), + "dataretrieval.waterdata.utils._walk_pages_async", + new=mock.AsyncMock( + return_value=( + pd.DataFrame({"id": ["row-1"], "value": [1]}), + _fake_response(), + ) ), ), ): diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index 6efba212..68bdce07 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -6,6 +6,7 @@ reporter. """ +import asyncio import io import sys import types @@ -20,7 +21,19 @@ current, progress_context, ) -from dataretrieval.waterdata.utils import _walk_pages +from dataretrieval.waterdata.utils import _walk_pages_async + + +def _walk_pages(*, geopd, req, client): + """Drive the async ``_walk_pages_async`` to completion synchronously. + + The chunker core is async-only now, so these tests build an + ``AsyncMock(spec=httpx.AsyncClient)`` whose ``.send``/``.request`` are + awaitable and run the coroutine via ``asyncio.run``. The progress + reporter is bound on a contextvar, which the coroutine inherits when + ``asyncio.run`` copies the calling context. + """ + return asyncio.run(_walk_pages_async(geopd=geopd, req=req, client=client)) @pytest.fixture(autouse=True) @@ -355,7 +368,7 @@ def test_walk_pages_reports_pages_and_rate_limit(): ) resp2 = _resp([{"id": "2", "properties": {"v": "b"}}], rate_remaining="4998") - client = mock.MagicMock(spec=httpx.Client) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp1 client.request.return_value = resp2 @@ -380,7 +393,7 @@ def test_walk_pages_reports_pages_and_rate_limit(): def test_walk_pages_without_context_does_not_error(): # No active reporter: pagination must still work and stay silent. resp = _resp([{"id": "1", "properties": {"v": "a"}}]) - client = mock.MagicMock(spec=httpx.Client) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp req = mock.MagicMock(spec=httpx.Request) @@ -400,7 +413,7 @@ def test_broken_progress_stream_does_not_truncate_pagination(): [{"id": "1", "properties": {"v": "a"}}], next_url="https://example.com/p2" ) resp2 = _resp([{"id": "2", "properties": {"v": "b"}}]) - client = mock.MagicMock(spec=httpx.Client) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp1 client.request.return_value = resp2 @@ -484,10 +497,11 @@ async def run(): def test_fan_out_async_sets_chunks_on_active_reporter(monkeypatch): - """``_fan_out_async`` records ``plan.total`` on the active reporter - so the progress line knows how many sub-requests are in flight. - It deliberately does NOT call ``start_chunk`` (which would be - misleading under parallel fan-out — chunks fire concurrently).""" + """The async fan-out core (``ChunkedCall._run``) records + ``plan.total`` on the active reporter so the progress line knows how + many sub-requests are in flight, and ticks ``current_chunk`` via + ``start_chunk(len(completed))`` as each gathered sub-request finishes + — reaching ``plan.total`` in the all-success case.""" import asyncio import pandas as pd @@ -517,16 +531,14 @@ async def fetch_async(args): headers={"x-ratelimit-remaining": "999"}, ) - def fetch_once(args): # noqa: ARG001 — never invoked on the happy parallel path - raise AssertionError("sync fetch must not run in this test") - stream = io.StringIO() async def run(): + # Drive the async execution core directly (the same coroutine the + # sync ``resume()`` facade runs through the anyio portal); the + # decorated async fetcher is the only fetcher now. with progress_context(service="daily", stream=stream, enabled=True) as rep: - await ChunkedCall(plan, fetch_once).resume_async( - fetch_async, max_concurrent=4 - ) + await ChunkedCall(plan, fetch_async)._run(4) return rep.total_chunks, rep.current_chunk total_recorded, current_recorded = asyncio.run(run()) diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index 6aa2728b..b2b7eecb 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -1,3 +1,4 @@ +import asyncio import json import logging from unittest import mock @@ -16,12 +17,24 @@ _handle_stats_nesting, _parse_retry_after, _raise_for_non_200, - _walk_pages, + _walk_pages_async, ) _LOGGER_NAME = _utils_module.__name__ +def _walk_pages(*, geopd, req, client): + """Drive the async ``_walk_pages_async`` to completion synchronously. + + The chunker core is async-only now, so these tests build an + ``AsyncMock(spec=httpx.AsyncClient)`` whose ``.send``/``.request`` are + awaitable and run the coroutine via ``asyncio.run``. This thin shim + keeps the historical sync-shaped call sites terse while exercising the + real async pagination loop. + """ + return asyncio.run(_walk_pages_async(geopd=geopd, req=req, client=client)) + + def test_get_args_basic(): local_vars = { "monitoring_location_id": "USGS-123", @@ -74,7 +87,7 @@ def test_walk_pages_multiple_mocked(): resp2.status_code = 200 # Mock client (Session) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) # First call to send() returns resp1, then call to request() in loop returns resp2 mock_client.send.return_value = resp1 mock_client.request.return_value = resp2 @@ -112,7 +125,7 @@ def test_row_cap_truncates_and_stops_within_first_page(): resp1.status_code = 200 resp1.url = "https://example.com/page1" - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = resp1 mock_req = mock.MagicMock(spec=httpx.Request) @@ -145,7 +158,7 @@ def _page(idx, *, has_next): resp.url = f"https://example.com/page{idx}" return resp - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = _page(1, has_next=True) # page 2 still advertises a ``next`` (page 3) that must never be fetched. mock_client.request.return_value = _page(2, has_next=True) @@ -208,7 +221,7 @@ def _walk_pages_with_failure(failure_resp_or_exc): """Run _walk_pages where page 1 succeeds and page 2 fails as given.""" resp1 = _resp_ok([{"id": "1", "properties": {"val": "a"}}]) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = resp1 if isinstance(failure_resp_or_exc, BaseException): mock_client.request.side_effect = failure_resp_or_exc @@ -299,7 +312,7 @@ def test_walk_pages_wraps_initial_page_parse_error(): # Body is unparseable JSON (gateway HTML page, truncated stream). resp.json.side_effect = json.JSONDecodeError("Expecting value", "...", 0) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = resp mock_req = mock.MagicMock(spec=httpx.Request) @@ -394,7 +407,7 @@ def test_walk_pages_does_not_mutate_initial_response(): "links": [], } - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = page1 mock_client.request.return_value = page2 @@ -448,7 +461,7 @@ def _run_get_stats_data_with_failure(failure_resp_or_exc, monkeypatch): mock.MagicMock(return_value=pd.DataFrame()), ) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = _stats_initial_ok() if isinstance(failure_resp_or_exc, BaseException): mock_client.request.side_effect = failure_resp_or_exc From fd6f67fe161483308cecb3a1802ca80720af4c15 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 15:42:18 -0500 Subject: [PATCH 06/16] refactor(waterdata): drop _async name suffixes; prune async/sync doc framing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The module is async-only now, so the docs and names shouldn't imply a removed sync alternative. Docs: removed 'async-only', 'synchronous facade', 'serial sync path', 'no separate serial path', 'no silent serial degradation', and redundant 'async' qualifiers from the module docstring, the concurrency-env docs, the ContextVar/accessor comments, ChunkedCall, resume(), multi_value_chunked, and get_stats_data. resume() is documented by what it does (drive the call to completion) plus the one useful non-obvious property — it works inside a running event loop because it runs in a worker thread — not as a sync-vs-async contrast. Names: since the sync twins are gone, the async versions reclaim the bare names — _paginate_async -> _paginate, _walk_pages_async -> _walk_pages, _client_for_async -> _client_for, _retry_async -> _retry, get_active_async_client -> get_active_client, _chunked_async_client -> _chunked_client. The utils/progress test shim that drives the async _walk_pages synchronously is renamed _run_walk_pages to avoid the name clash, and a redundant initial-page-parse-error test (the former 'async sibling', now identical) is dropped. Docstrings/comments + mechanical rename only — no behavior change. Offline suite (264) + live getter suite (63) green. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 150 +++++++++++++--------------- dataretrieval/waterdata/utils.py | 61 ++++++----- tests/conftest.py | 2 +- tests/waterdata_chunking_test.py | 48 +++++---- tests/waterdata_filters_test.py | 20 ++-- tests/waterdata_progress_test.py | 22 ++-- tests/waterdata_utils_test.py | 51 ++-------- 7 files changed, 153 insertions(+), 201 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 48da4159..fe6ee6ab 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -9,20 +9,18 @@ sub-request URL fits. Requests that already fit get a trivial single-step plan — ``ChunkedCall`` has one code path either way. -Concurrency: the execution core is async-only. ``multi_value_chunked`` -fans every pending sub-request out under one ``asyncio.gather`` sharing -a single ``httpx.AsyncClient``; concurrency is bounded purely by the -client's connection pool (``httpx.Limits(max_connections=N, -max_keepalive_connections=N)``), so the pool — not a semaphore — -throttles. ``API_USGS_CONCURRENT`` resolves ``N``: an integer N > 1 -caps connections at N; ``1`` pins a single connection (effectively -serial); the literal ``unbounded`` removes the cap (``N=None``). The -default (16) is the server-friendly sweet spot; higher values can trip -USGS burst-protection 5xx in practice. The fan-out runs in a -short-lived worker thread (an ``anyio`` blocking portal), so the -synchronous public API drives it whether or not the caller is already -inside an event loop (Jupyter / IPython / async apps) — no nested-loop -error and no silent serial degradation. +Concurrency: ``multi_value_chunked`` fans every pending sub-request out +under one ``asyncio.gather`` sharing a single ``httpx.AsyncClient``; +concurrency is bounded purely by the client's connection pool +(``httpx.Limits(max_connections=N, max_keepalive_connections=N)``), so +the pool — not a semaphore — throttles. ``API_USGS_CONCURRENT`` resolves +``N``: an integer N > 1 caps connections at N; ``1`` pins a single +connection (one request at a time); the literal ``unbounded`` removes +the cap (``N=None``). The default (16) is the server-friendly sweet +spot; higher values can trip USGS burst-protection 5xx in practice. The +fan-out runs in a short-lived worker thread (an ``anyio`` blocking +portal), so it works whether or not the caller is already inside an +event loop (Jupyter / IPython / async apps). Retries: each sub-request is retried on a transient failure (429, 5xx, connect/read timeout) with exponential backoff + full jitter, @@ -38,8 +36,7 @@ ``ChunkedCall`` handle that owns the already-completed sub-request state (sparse-indexed, since gathered sub-requests complete out of order). Call ``.call.resume()`` once the underlying condition clears; -only the still-pending sub-requests are re-issued (``resume()`` is a -synchronous facade over the same async runner). ``Retry-After`` (when +only the still-pending sub-requests are re-issued. ``Retry-After`` (when the server sets it) is surfaced on the exception as ``.retry_after``. Dedup: list-axis chunks don't overlap; filter-axis chunks can, so @@ -121,13 +118,12 @@ # Response header USGS uses to advertise remaining hourly quota. _QUOTA_HEADER = "x-ratelimit-remaining" -# Environment variable that controls async fan-out concurrency. Read -# at call time (not import) so test patches via ``monkeypatch.setenv`` -# take effect. The default (16) is the server-friendly sweet spot: -# higher values trip the upstream into 5xx burst-protection in -# practice. Set to ``1`` to force the serial sync path, set to -# ``unbounded`` for no per-call cap (use sparingly — you own the -# upstream-burst risk). +# Environment variable that controls fan-out concurrency. Read at call +# time (not import) so test patches via ``monkeypatch.setenv`` take +# effect. The default (16) is the server-friendly sweet spot: higher +# values trip the upstream into 5xx burst-protection in practice. Set to +# ``1`` for a single connection, set to ``unbounded`` for no per-call cap +# (use sparingly — you own the upstream-burst risk). _CONCURRENCY_ENV = "API_USGS_CONCURRENT" _CONCURRENCY_DEFAULT = 16 _CONCURRENCY_UNBOUNDED = "unbounded" @@ -140,8 +136,8 @@ def _read_concurrency_env() -> int | None: Returns ------- int or None - ``1`` for the serial sync path; an integer >1 for bounded - parallelism; ``None`` to disable the per-call cap entirely + ``1`` for a single connection; an integer >1 for bounded + concurrency; ``None`` to disable the per-call cap entirely (``unbounded`` keyword). Unset → default of ``_CONCURRENCY_DEFAULT``. """ @@ -328,50 +324,50 @@ def backoff(self, attempt: int, retry_after: float | None) -> float: # The single shared ``httpx.AsyncClient`` of an in-flight chunked call, -# published (via :func:`_publish`) during ``ChunkedCall._run`` so async -# paginated-loop helpers downstream (``_walk_pages_async``) reuse one +# published (via :func:`_publish`) during ``ChunkedCall._run`` so the +# paginated-loop helpers downstream (``_walk_pages``) reuse one # connection pool across every gathered sub-request of the call. ``None`` # when not inside a chunked call — paginated helpers fall back to their # own short-lived client in that case. -_chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar( - "_chunked_async_client", default=None +_chunked_client: ContextVar[httpx.AsyncClient | None] = ContextVar( + "_chunked_client", default=None ) @contextmanager def _publish(client: httpx.AsyncClient) -> Iterator[None]: """ - Bind ``client`` to the ``_chunked_async_client`` ContextVar for the + Bind ``client`` to the ``_chunked_client`` ContextVar for the duration of the ``with`` block (wrapping the set/reset token dance), - so async paginated-loop helpers can borrow the chunker's shared - client via :func:`get_active_async_client`. + so the paginated-loop helpers can borrow the chunker's shared client + via :func:`get_active_client`. Parameters ---------- client : httpx.AsyncClient - The client to publish on ``_chunked_async_client``. + The client to publish on ``_chunked_client``. Yields ------ None Yields once, for the duration of the bind. """ - token = _chunked_async_client.set(client) + token = _chunked_client.set(client) try: yield finally: - _chunked_async_client.reset(token) + _chunked_client.reset(token) -def get_active_async_client() -> httpx.AsyncClient | None: +def get_active_client() -> httpx.AsyncClient | None: """ - Return the chunker's currently-published async client, or ``None``. + Return the chunker's currently-published client, or ``None``. - Public accessor for the ``_chunked_async_client`` ContextVar so + Public accessor for the ``_chunked_client`` ContextVar so sibling modules (notably - :func:`dataretrieval.waterdata.utils._client_for_async`) don't have - to reach into the private ContextVar directly. Used by async - paginated-loop helpers to reuse the per-call AsyncClient pool. + :func:`dataretrieval.waterdata.utils._client_for`) don't have + to reach into the private ContextVar directly. Used by the + paginated-loop helpers to reuse the per-call connection pool. Returns ------- @@ -379,7 +375,7 @@ def get_active_async_client() -> httpx.AsyncClient | None: The client published via :func:`_publish` if currently inside a :class:`ChunkedCall` run; ``None`` otherwise. """ - return _chunked_async_client.get() + return _chunked_client.get() # Separators the two axis kinds use to join their atoms back into @@ -388,9 +384,8 @@ def get_active_async_client() -> httpx.AsyncClient | None: _LIST_SEP = "," _OR_SEP = " OR " -# The chunker's execution core is async-only: the decorated fetcher, the -# ``ChunkedCall`` it drives, and the per-sub-request runner are all -# coroutines. ``_Fetch`` is an ``async def fetch(args) -> (df, response)``. +# ``_Fetch`` is the per-sub-request fetcher the decorator wraps and +# ``ChunkedCall`` drives: an ``async def fetch(args) -> (df, response)``. _Fetch = Callable[[dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]]] # Caller-supplied transform applied to the *combined* chunk result. It lets a @@ -1136,7 +1131,7 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: Inspects only the *top-level* exception, by design — and so is deliberately narrower than :func:`_classify_chunk_error`, which walks - the ``__cause__`` chain for resumability. ``_paginate_async`` raises an + the ``__cause__`` chain for resumability. ``_paginate`` raises an initial-request transient (429 / 5xx / :class:`httpx.TransportError` such as ``ConnectError`` / ``ReadTimeout``) *raw*, but re-wraps any mid-pagination failure as a ``RuntimeError``. Retrying only the raw, @@ -1204,7 +1199,7 @@ def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float return delay -async def _retry_async( +async def _retry( afn: Callable[[], Awaitable[tuple[pd.DataFrame, httpx.Response]]], policy: RetryPolicy, ) -> tuple[pd.DataFrame, httpx.Response]: @@ -1374,14 +1369,13 @@ class ChunkedCall: both for the first invocation (from :meth:`ChunkPlan.execute`) and for subsequent retries after a :class:`ChunkInterrupted`. - The execution core is the async :meth:`_run` (gather every pending - sub-request over one shared :class:`httpx.AsyncClient`, apply the - failure-precedence rules, combine); :meth:`resume` is a thin - synchronous facade that drives :meth:`_run` through an ``anyio`` - blocking portal so it works whether or not the caller is already - inside an event loop. There is no separate serial path: concurrency - is bounded purely by the client's connection pool, so a single - connection (``API_USGS_CONCURRENT=1``) is just a degenerate gather. + :meth:`_run` gathers every pending sub-request over one shared + :class:`httpx.AsyncClient`, applies the failure-precedence rules, and + combines; :meth:`resume` drives it through an ``anyio`` blocking + portal so it works whether or not the caller is already inside an + event loop. Concurrency is bounded purely by the client's connection + pool, so a single connection (``API_USGS_CONCURRENT=1``) is just a + degenerate gather. A ``ChunkedCall`` is created internally when a :class:`ChunkPlan` executes; callers reach it via :attr:`ChunkInterrupted.call` on @@ -1611,14 +1605,13 @@ def _pending(self) -> Iterator[tuple[int, dict[str, Any]]]: def resume(self) -> tuple[pd.DataFrame, Any]: """ - Drive the chunked call to completion. Synchronous facade. + Drive the chunked call to completion and return the combined result. - Runs the async core :meth:`_run` through an ``anyio`` blocking - portal (a short-lived worker thread), so the synchronous public - API works whether or not the caller is already inside an event - loop (Jupyter / IPython / async apps) — no nested-``asyncio.run`` - error. The portal copies the calling context, so the active - progress reporter still reaches the fan-out. + Runs :meth:`_run` through an ``anyio`` blocking portal (a + short-lived worker thread), so it works whether or not the caller + is already inside an event loop (Jupyter / IPython / async apps). + The portal copies the calling context, so the active progress + reporter still reaches the sub-requests. Idempotent: only sub-requests whose index isn't already in ``self._chunks`` are re-issued. Sub-args order matches @@ -1667,10 +1660,10 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: ``httpx.Limits(max_connections=N, max_keepalive_connections=N)`` where ``N = max_concurrent`` (``None`` for unbounded). There is no semaphore: the gather dispatches *every* pending sub-request and the - pool throttles, so ``N=1`` is just a single-connection gather - (effectively serial) and ``total <= 1`` is just a one-element gather. - The shared client is published on :data:`_chunked_async_client` so - async paginated-loop helpers reuse its connection pool. + pool throttles, so ``N=1`` is just a single-connection gather (one + request at a time) and ``total <= 1`` is just a one-element gather. + The shared client is published on :data:`_chunked_client` so + the paginated-loop helpers reuse its connection pool. Parameters ---------- @@ -1714,9 +1707,7 @@ async def track( index: int, args: dict[str, Any] ) -> tuple[pd.DataFrame, httpx.Response]: """One sub-request (with retry) + record + progress tick.""" - result = await _retry_async( - lambda: self.fetch(args), self.retry_policy - ) + result = await _retry(lambda: self.fetch(args), self.retry_policy) self.record(index, result) if reporter is not None: # Chunks finish out of order under gather, so tick the @@ -1775,18 +1766,15 @@ def multi_value_chunked( single-step plan, so the decorated function has one code path either way. - The decorated function is an ``async def fetch(args) -> (df, - response)``; the *returned wrapper is synchronous*. The wrapper builds - the :class:`ChunkPlan`, constructs a :class:`ChunkedCall` over the - async fetcher, and drives it to completion via - :meth:`ChunkedCall.resume` — which runs the async core in an ``anyio`` - worker-thread portal so it works whether or not the caller is already - inside an event loop (Jupyter / IPython / async apps). Every pending - sub-request is gathered under one :class:`httpx.AsyncClient`; - concurrency is bounded purely by the connection pool, sized from - ``API_USGS_CONCURRENT``. ``API_USGS_CONCURRENT=1`` is just a - single-connection gather and ``plan.total <= 1`` a one-element gather - — neither is a special-cased path. + Decorates an ``async def fetch(args) -> (df, response)`` and returns a + callable that builds the :class:`ChunkPlan`, constructs a + :class:`ChunkedCall` over the fetcher, and drives it to completion via + :meth:`ChunkedCall.resume` (an ``anyio`` worker-thread portal, so it + works whether or not the caller is already inside an event loop — + Jupyter / IPython / async apps). Every pending sub-request is gathered + under one :class:`httpx.AsyncClient`; concurrency is bounded purely by + the connection pool, sized from ``API_USGS_CONCURRENT`` (``1`` is a + single-connection gather, ``plan.total <= 1`` a one-element gather). Parameters ---------- diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 8fabca34..9a2be2c4 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -33,7 +33,7 @@ RateLimited, ServiceUnavailable, _safe_elapsed, - get_active_async_client, + get_active_client, ) from dataretrieval.waterdata.types import ( PROFILE_LOOKUP, @@ -530,7 +530,7 @@ def _paginated_failure_message(pages_collected: int, cause: BaseException) -> st ------- str A message suitable for the ``RuntimeError`` that - ``_walk_pages_async`` and ``get_stats_data`` raise from the + ``_walk_pages`` and ``get_stats_data`` raise from the original exception. """ cause_str = str(cause).removesuffix(".") @@ -774,7 +774,7 @@ def _get_resp_data( # ``features`` is a real schema-drift shape (mirrors the guard in # ``_handle_stats_nesting``). Treat as empty rather than crash with # ``KeyError`` — the wrapped failure would otherwise look like a - # transient transport error to ``_paginate_async``'s exception handler. + # transient transport error to ``_paginate``'s exception handler. features = body.get("features") or [] if not features: return gpd.GeoDataFrame() if geopd else pd.DataFrame() @@ -806,7 +806,7 @@ def _get_resp_data( @asynccontextmanager -async def _client_for_async( +async def _client_for( client: httpx.AsyncClient | None, ) -> AsyncIterator[httpx.AsyncClient]: """ @@ -818,7 +818,7 @@ async def _client_for_async( here — the caller owns its lifecycle). 2. The chunker's shared async client if we're inside a :class:`~dataretrieval.waterdata.chunking.ChunkedCall` run (per - :func:`chunking.get_active_async_client`). Borrowed; the chunker + :func:`chunking.get_active_client`). Borrowed; the chunker closes it on exit. 3. A fresh short-lived ``httpx.AsyncClient`` opened here and closed on context exit. @@ -837,7 +837,7 @@ async def _client_for_async( if client is not None: yield client return - shared = get_active_async_client() + shared = get_active_client() if shared is not None: yield shared return @@ -890,14 +890,14 @@ def _aggregate_paginated_response( # Optional cap on the total rows a single paginated call accumulates before it # stops following ``next`` links. ``None`` (the default the data getters use) # means "no cap — fetch the whole series". Set via :func:`_row_cap` so the deep -# ``_paginate_async`` loop can honor it without threading the value through the +# ``_paginate`` loop can honor it without threading the value through the # generic chunker; this mirrors the ``_progress`` ambient-reporter pattern. _row_cap_var: ContextVar[int | None] = ContextVar("waterdata_row_cap", default=None) @contextmanager def _row_cap(max_rows: int | None) -> Iterator[None]: - """Cap the rows any :func:`_paginate_async` under this context will + """Cap the rows any :func:`_paginate` under this context will accumulate (``None`` = uncapped). Used by :func:`get_reference_table` to preview large tables without downloading every page.""" token = _row_cap_var.set(max_rows) @@ -907,7 +907,7 @@ def _row_cap(max_rows: int | None) -> Iterator[None]: _row_cap_var.reset(token) -async def _paginate_async( +async def _paginate( initial_req: httpx.Request, *, parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], @@ -918,7 +918,7 @@ async def _paginate_async( Drive a paginated request to completion over an :class:`httpx.AsyncClient`. - The common shape behind :func:`_walk_pages_async` and + The common shape behind :func:`_walk_pages` and :func:`get_stats_data`: send the initial request, then loop calling ``follow_up`` until ``parse_response`` reports a ``None`` cursor, accumulating frames and elapsed time. Any mid-pagination failure @@ -974,7 +974,7 @@ async def _paginate_async( """ logger.debug("Requesting: %s", initial_req.url) reporter = _progress.current() - async with _client_for_async(client) as sess: + async with _client_for(client) as sess: resp = await sess.send(initial_req) _raise_for_non_200(resp) initial_response = resp @@ -1040,8 +1040,8 @@ def _ogc_parse_response( ) -> tuple[pd.DataFrame, str | None]: """Parse one OGC API page: extract the DataFrame and the next-page URL. - The parse strategy :func:`_walk_pages_async` hands to - :func:`_paginate_async`. Coerces falsy cursors (empty href, etc.) to + The parse strategy :func:`_walk_pages` hands to + :func:`_paginate`. Coerces falsy cursors (empty href, etc.) to ``None`` so the paginate loop's ``while cursor is not None`` terminates instead of spinning on a meaningless value. """ @@ -1052,7 +1052,7 @@ def _ogc_parse_response( ) -async def _walk_pages_async( +async def _walk_pages( geopd: bool, req: httpx.Request, client: httpx.AsyncClient | None = None, @@ -1061,7 +1061,7 @@ async def _walk_pages_async( Iterate paginated OGC API responses asynchronously and aggregate them into one DataFrame. - Thin wrapper that hands off to :func:`_paginate_async` with + Thin wrapper that hands off to :func:`_paginate` with OGC-specific strategies: pages are parsed via :func:`_get_resp_data` (through :func:`_ogc_parse_response`) and the next-page cursor is the URL from the response's ``links`` array (per :func:`_next_req_url`). @@ -1074,7 +1074,7 @@ async def _walk_pages_async( The initial HTTP request to send. client : httpx.AsyncClient, optional Caller-borrowed client; ``None`` defers client management to - :func:`_paginate_async`. + :func:`_paginate`. Returns ------- @@ -1088,9 +1088,9 @@ async def _walk_pages_async( Raises ------ RuntimeError - See :func:`_paginate_async`. + See :func:`_paginate`. httpx.HTTPError - See :func:`_paginate_async`. + See :func:`_paginate`. """ method = req.method # ``httpx.Request.method`` is already upper-cased. headers = req.headers @@ -1099,7 +1099,7 @@ async def _walk_pages_async( async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: return await sess.request(method, cursor, headers=headers, content=content) - return await _paginate_async( + return await _paginate( req, parse_response=functools.partial(_ogc_parse_response, geopd=geopd), follow_up=follow_up, @@ -1298,7 +1298,7 @@ def _finalize_ogc( ``max_rows`` is applied here (after dedup/sort, on the *combined* frame) rather than only per-sub-request, so a chunked call's total is bounded to exactly ``max_rows`` and a resumed call honors the cap too — the - per-``_paginate_async`` ``_row_cap`` is only an early-stop download bound. + per-``_paginate`` ``_row_cap`` is only an early-stop download bound. """ frame = _deal_with_empty(frame, properties, service) if convert_type: @@ -1415,7 +1415,7 @@ async def _fetch_once( ``(frame, response)``. """ req = _construct_api_requests(**args) - return await _walk_pages_async(geopd=GEOPANDAS, req=req) + return await _walk_pages(geopd=GEOPANDAS, req=req) def _handle_stats_nesting( @@ -1588,14 +1588,11 @@ def get_stats_data( handles pagination, processes results, and formats output according to the specified parameters. - Synchronous facade: the stats path doesn't go through - ``multi_value_chunked`` (its query shape has no chunkable list axes), - so it drives :func:`_paginate_async` directly through an ``anyio`` - blocking portal — the same async-only core the chunked getters use — - while keeping a synchronous signature and ``(df, BaseMetadata)`` - return. The portal runs the pagination loop in a short-lived worker - thread, so this works whether or not the caller is already inside an - event loop. + The stats path doesn't go through ``multi_value_chunked`` (its query + shape has no chunkable list axes), so it drives :func:`_paginate` + directly through an ``anyio`` blocking portal. The portal runs the + pagination loop in a short-lived worker thread, so this works whether + or not the caller is already inside an event loop. Parameters ---------- @@ -1633,7 +1630,7 @@ def get_stats_data( def parse_response(resp: httpx.Response) -> tuple[pd.DataFrame, str | None]: body = resp.json() - # Coerce falsy cursors ("", 0) to None so _paginate_async terminates. + # Coerce falsy cursors ("", 0) to None so _paginate terminates. # USGS uses "next": null at end-of-stream, but defensive coerce # protects against any "" sentinel a future schema might use. return _handle_stats_nesting(body, geopd=GEOPANDAS), body.get("next") or None @@ -1646,7 +1643,7 @@ async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: ) async def _run() -> tuple[pd.DataFrame, httpx.Response]: - return await _paginate_async( + return await _paginate( req, parse_response=parse_response, follow_up=follow_up, @@ -1654,7 +1651,7 @@ async def _run() -> tuple[pd.DataFrame, httpx.Response]: ) # The stats path opens its own progress context (it doesn't go through - # ``multi_value_chunked``); ``_paginate_async`` reports pages/rate-limit + # ``multi_value_chunked``); ``_paginate`` reports pages/rate-limit # into it. The portal copies the calling context, so the reporter still # reaches the worker thread. with _progress.progress_context(service=service): diff --git a/tests/conftest.py b/tests/conftest.py index 93baeb1a..0928d610 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,7 +51,7 @@ def _serial_chunker(monkeypatch): which keeps sub-request dispatch deterministic enough for the mocked suite. ``API_USGS_RETRIES=0`` makes a single transient surface immediately rather than be retried. The mocked tests drive the async - ``_walk_pages_async`` (via ``asyncio.run`` / an ``AsyncMock`` client), + ``_walk_pages`` (via ``asyncio.run`` / an ``AsyncMock`` client), not a sync sibling. Pinning both keeps the test surface focused on the planner / fetch contracts; async-fan-out and retry tests opt in by overriding the env inside their body. diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 6e616c1c..1f954f1e 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -42,9 +42,9 @@ RetryPolicy, ServiceInterrupted, ServiceUnavailable, - _chunked_async_client, + _chunked_client, _extract_axes, - _retry_async, + _retry, _retryable, multi_value_chunked, ) @@ -326,20 +326,20 @@ async def fetch(args): def test_chunked_session_shared_across_sub_requests(): """Every sub-request of one chunked call sees the same - ``httpx.AsyncClient`` on the ``_chunked_async_client`` ContextVar, so - downstream paginated helpers (``_walk_pages_async``) can reuse the + ``httpx.AsyncClient`` on the ``_chunked_client`` ContextVar, so + downstream paginated helpers (``_walk_pages``) can reuse the connection pool instead of handshaking fresh on each sub-request.""" sessions_seen = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) async def fetch(args): - sessions_seen.append(_chunked_async_client.get()) + sessions_seen.append(_chunked_client.get()) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} ) # Outside a chunked call: no session published (in this thread/context). - assert _chunked_async_client.get() is None + assert _chunked_client.get() is None fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) @@ -352,7 +352,7 @@ async def fetch(args): assert len({id(s) for s in sessions_seen}) == 1 # The portal's worker context is torn down on exit, so the calling # thread's ContextVar still reads its default. - assert _chunked_async_client.get() is None + assert _chunked_client.get() is None def test_chunked_session_isolated_per_resume(): @@ -365,7 +365,7 @@ def test_chunked_session_isolated_per_resume(): @multi_value_chunked(build_request=_fake_build, url_limit=240) async def fetch(args): - sessions_seen.append(_chunked_async_client.get()) + sessions_seen.append(_chunked_client.get()) i = state["i"] state["i"] += 1 if i == 1 and state["blow_up"]: @@ -383,14 +383,14 @@ async def fetch(args): # First run published a shared client to its sub-requests; the calling # thread's ContextVar is unaffected (reads its default). - assert _chunked_async_client.get() is None + assert _chunked_client.get() is None first_run_sessions = list(sessions_seen) assert first_run_sessions and all(s is not None for s in first_run_sessions) state["blow_up"] = False excinfo.value.call.resume() # Second run's ContextVar is also reset in the calling thread. - assert _chunked_async_client.get() is None + assert _chunked_client.get() is None # The resume opened a FRESH client, distinct from the first run's, so no # closed client leaks across runs. resume_sessions = sessions_seen[len(first_run_sessions) :] @@ -930,9 +930,7 @@ def test_paginate_terminates_on_empty_string_cursor(): req.content = b"" req.url = "https://example.com/items?limit=1" - df, final = asyncio.run( - _utils._walk_pages_async(geopd=False, req=req, client=client) - ) + df, final = asyncio.run(_utils._walk_pages(geopd=False, req=req, client=client)) # Single send + zero follow-ups: the loop terminated on the empty cursor. assert client.send.called @@ -1623,7 +1621,7 @@ def test_retryable_taxonomy(): def test_retryable_skips_wrapped_midpagination_transient(): # A transient surfaced mid-pagination is re-wrapped as RuntimeError by - # _paginate_async; it must NOT be auto-retried (re-walking from page 1 + # _paginate; it must NOT be auto-retried (re-walking from page 1 # would re-spend quota) — it escalates to the resumable handle instead. # Only the raw, top-level (initial-request) transient is retryable. assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (False, None) @@ -1633,14 +1631,14 @@ def test_retryable_skips_wrapped_midpagination_transient(): # -- async driver (the single retry driver; sync facade drives it) ---------- # # The chunker is async-only now, so the retry loop lives solely in -# ``_retry_async``. These tests pin the same behavioral contracts the old +# ``_retry``. These tests pin the same behavioral contracts the old # ``_retry_sync`` tests asserted (transient-then-success, exhausted-reraises, # non-retryable-not-retried, long-retry-after-escalates), re-expressed on the # async driver and run via ``asyncio.run``; the sleep is patched to a no-op so # backoff doesn't actually wait. -def test_retry_async_transient_then_recovers(monkeypatch): +def test_retry_transient_then_recovers(monkeypatch): monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) calls = {"n": 0} @@ -1650,12 +1648,12 @@ async def afn(): raise RateLimited("429") return "ok" - out = asyncio.run(_retry_async(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) + out = asyncio.run(_retry(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) assert out == "ok" assert calls["n"] == 3 # two failures + one success -def test_retry_async_exhausted_reraises(monkeypatch): +def test_retry_exhausted_reraises(monkeypatch): monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) calls = {"n": 0} @@ -1664,11 +1662,11 @@ async def afn(): raise ServiceUnavailable("503") with pytest.raises(ServiceUnavailable): - asyncio.run(_retry_async(afn, RetryPolicy(max_retries=2, base_backoff=0.0))) + asyncio.run(_retry(afn, RetryPolicy(max_retries=2, base_backoff=0.0))) assert calls["n"] == 3 # first attempt + 2 retries -def test_retry_async_non_retryable_not_retried(monkeypatch): +def test_retry_non_retryable_not_retried(monkeypatch): slept: list[float] = [] def _record(delay): @@ -1683,11 +1681,11 @@ async def afn(): raise RuntimeError("400: bad request") with pytest.raises(RuntimeError): - asyncio.run(_retry_async(afn, RetryPolicy(max_retries=3))) + asyncio.run(_retry(afn, RetryPolicy(max_retries=3))) assert calls["n"] == 1 and slept == [] -def test_retry_async_long_retry_after_escalates(monkeypatch): +def test_retry_long_retry_after_escalates(monkeypatch): slept: list[float] = [] def _record(delay): @@ -1702,14 +1700,14 @@ async def afn(): raise RateLimited("429", retry_after=999.0) with pytest.raises(RateLimited): - asyncio.run(_retry_async(afn, RetryPolicy(max_retries=5, retry_after_cap=60.0))) + asyncio.run(_retry(afn, RetryPolicy(max_retries=5, retry_after_cap=60.0))) assert calls["n"] == 1 and slept == [] # no inline wait for a long window # -- async driver (sleep-patched original) ---------------------------------- -def test_retry_async_transient_then_success(monkeypatch): +def test_retry_transient_then_success(monkeypatch): async def _noslept(_d): return None @@ -1722,7 +1720,7 @@ async def afn(): raise httpx.ReadTimeout("slow") return "ok" - out = asyncio.run(_retry_async(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) + out = asyncio.run(_retry(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) assert out == "ok" and calls["n"] == 2 diff --git a/tests/waterdata_filters_test.py b/tests/waterdata_filters_test.py index c0968d98..b57b0c53 100644 --- a/tests/waterdata_filters_test.py +++ b/tests/waterdata_filters_test.py @@ -157,7 +157,7 @@ def test_long_filter_fans_out_into_multiple_requests(): expr = _filter_chunking_clauses() sent_filters: list[str] = [] - async def fake_walk_pages_async(*, geopd, req): + async def fake_walk_pages(*, geopd, req): idx = len(sent_filters) sent_filters.append(_query_params(req).get("filter", [None])[0]) return pd.DataFrame({"id": [f"chunk-{idx}"], "value": [idx]}), _fake_response() @@ -168,8 +168,8 @@ async def fake_walk_pages_async(*, geopd, req): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages_async", - side_effect=fake_walk_pages_async, + "dataretrieval.waterdata.utils._walk_pages", + side_effect=fake_walk_pages, ), ): df, _ = get_continuous( @@ -196,7 +196,7 @@ def test_long_filter_deduplicates_cross_chunk_overlap(): expr = _filter_chunking_clauses() call_count = {"n": 0} - async def fake_walk_pages_async(*_args, **_kwargs): + async def fake_walk_pages(*_args, **_kwargs): call_count["n"] += 1 return ( pd.DataFrame({"id": ["shared-feature"], "value": [1]}), @@ -209,8 +209,8 @@ async def fake_walk_pages_async(*_args, **_kwargs): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages_async", - side_effect=fake_walk_pages_async, + "dataretrieval.waterdata.utils._walk_pages", + side_effect=fake_walk_pages, ), ): df, _ = get_continuous( @@ -239,7 +239,7 @@ def test_empty_chunks_do_not_downgrade_geodataframe(): expr = _filter_chunking_clauses() call_count = {"n": 0} - async def fake_walk_pages_async(*_args, **_kwargs): + async def fake_walk_pages(*_args, **_kwargs): call_count["n"] += 1 if call_count["n"] == 2: return pd.DataFrame(), _fake_response() @@ -258,8 +258,8 @@ async def fake_walk_pages_async(*_args, **_kwargs): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages_async", - side_effect=fake_walk_pages_async, + "dataretrieval.waterdata.utils._walk_pages", + side_effect=fake_walk_pages, ), ): df, _ = get_continuous( @@ -292,7 +292,7 @@ def fake_construct_api_requests(**kwargs): side_effect=fake_construct_api_requests, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages_async", + "dataretrieval.waterdata.utils._walk_pages", new=mock.AsyncMock( return_value=( pd.DataFrame({"id": ["row-1"], "value": [1]}), diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index 68bdce07..8e1176de 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -21,11 +21,11 @@ current, progress_context, ) -from dataretrieval.waterdata.utils import _walk_pages_async +from dataretrieval.waterdata.utils import _walk_pages -def _walk_pages(*, geopd, req, client): - """Drive the async ``_walk_pages_async`` to completion synchronously. +def _run_walk_pages(*, geopd, req, client): + """Drive the async ``_walk_pages`` to completion synchronously. The chunker core is async-only now, so these tests build an ``AsyncMock(spec=httpx.AsyncClient)`` whose ``.send``/``.request`` are @@ -33,7 +33,7 @@ def _walk_pages(*, geopd, req, client): reporter is bound on a contextvar, which the coroutine inherits when ``asyncio.run`` copies the calling context. """ - return asyncio.run(_walk_pages_async(geopd=geopd, req=req, client=client)) + return asyncio.run(_walk_pages(geopd=geopd, req=req, client=client)) @pytest.fixture(autouse=True) @@ -379,7 +379,7 @@ def test_walk_pages_reports_pages_and_rate_limit(): stream = io.StringIO() with progress_context(service="daily", stream=stream, enabled=True): - df, _ = _walk_pages(geopd=False, req=req, client=client) + df, _ = _run_walk_pages(geopd=False, req=req, client=client) assert len(df) == 2 out = stream.getvalue() @@ -401,7 +401,7 @@ def test_walk_pages_without_context_does_not_error(): req.headers = {} req.url = "https://example.com/p1" - df, _ = _walk_pages(geopd=False, req=req, client=client) + df, _ = _run_walk_pages(geopd=False, req=req, client=client) assert len(df) == 1 assert current() is None @@ -423,7 +423,7 @@ def test_broken_progress_stream_does_not_truncate_pagination(): req.url = "https://example.com/p1" with progress_context(stream=_RaisingStream(), enabled=True): - df, _ = _walk_pages(geopd=False, req=req, client=client) + df, _ = _run_walk_pages(geopd=False, req=req, client=client) assert len(df) == 2 # both pages returned despite the broken progress stream @@ -431,14 +431,14 @@ def test_broken_progress_stream_does_not_truncate_pagination(): # -- async path integration ---------------------------------------------------- -def test_paginate_async_reports_pages_through_active_reporter(monkeypatch): +def test_paginate_reports_pages_through_active_reporter(monkeypatch): """The async paginate path must drive the same progress reporter the sync path does. Pages and rate-limit updates from each completed page should land via the active ``ProgressReporter``, exactly as they would on ``_walk_pages``.""" import asyncio - from dataretrieval.waterdata.utils import _paginate_async + from dataretrieval.waterdata.utils import _paginate resp1 = _resp( [{"id": "1", "properties": {"v": "a"}}], @@ -454,7 +454,7 @@ async def parse_response(resp): ) return mock.MagicMock(empty=False, __len__=lambda self: 1), nxt - # _paginate_async expects parse_response to be sync, like the sync path. + # _paginate expects parse_response to be sync, like the sync path. def parse_sync(resp): body = resp.json() nxt = next( @@ -479,7 +479,7 @@ async def follow_up(cursor, sess): async def run(): with progress_context(service="continuous", stream=stream, enabled=True): - df, _ = await _paginate_async( + df, _ = await _paginate( req, parse_response=parse_sync, follow_up=follow_up, diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index b2b7eecb..ac48788e 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -17,14 +17,14 @@ _handle_stats_nesting, _parse_retry_after, _raise_for_non_200, - _walk_pages_async, + _walk_pages, ) _LOGGER_NAME = _utils_module.__name__ -def _walk_pages(*, geopd, req, client): - """Drive the async ``_walk_pages_async`` to completion synchronously. +def _run_walk_pages(*, geopd, req, client): + """Drive the async ``_walk_pages`` to completion synchronously. The chunker core is async-only now, so these tests build an ``AsyncMock(spec=httpx.AsyncClient)`` whose ``.send``/``.request`` are @@ -32,7 +32,7 @@ def _walk_pages(*, geopd, req, client): keeps the historical sync-shaped call sites terse while exercising the real async pagination loop. """ - return asyncio.run(_walk_pages_async(geopd=geopd, req=req, client=client)) + return asyncio.run(_walk_pages(geopd=geopd, req=req, client=client)) def test_get_args_basic(): @@ -99,7 +99,7 @@ def test_walk_pages_multiple_mocked(): mock_req.url = "https://example.com/page1" # Call _walk_pages - df, final_resp = _walk_pages(geopd=False, req=mock_req, client=mock_client) + df, final_resp = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) assert len(df) == 2 assert list(df["val"]) == ["a", "b"] @@ -134,7 +134,7 @@ def test_row_cap_truncates_and_stops_within_first_page(): mock_req.url = "https://example.com/page1" with _row_cap(2): - df, _ = _walk_pages(geopd=False, req=mock_req, client=mock_client) + df, _ = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) assert len(df) == 2 # truncated to the cap, not the page's 3 rows assert not mock_client.request.called # ``next`` link never followed @@ -169,7 +169,7 @@ def _page(idx, *, has_next): mock_req.url = "https://example.com/page1" with _row_cap(2): - df, _ = _walk_pages(geopd=False, req=mock_req, client=mock_client) + df, _ = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) assert len(df) == 2 assert mock_client.request.call_count == 1 # fetched page 2, stopped before 3 @@ -233,7 +233,7 @@ def _walk_pages_with_failure(failure_resp_or_exc): mock_req.headers = {} mock_req.url = "https://example.com/page1" - return _walk_pages(geopd=False, req=mock_req, client=mock_client) + return _run_walk_pages(geopd=False, req=mock_req, client=mock_client) def test_walk_pages_raises_on_connection_error_mid_pagination(): @@ -321,43 +321,12 @@ def test_walk_pages_wraps_initial_page_parse_error(): mock_req.url = "https://example.com/page1" with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: - _walk_pages(geopd=False, req=mock_req, client=mock_client) + _run_walk_pages(geopd=False, req=mock_req, client=mock_client) # The JSONDecodeError causing it is on __cause__ so callers can drill in. assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) -def test_walk_pages_async_wraps_initial_page_parse_error(): - """Async sibling of the above. ``_paginate_async`` must wrap an - initial-page parse failure with the same ``RuntimeError`` shape so - callers get a consistent diagnostic across sync and async paths.""" - import asyncio - - from dataretrieval.waterdata.utils import _walk_pages_async - - resp = mock.MagicMock() - resp.status_code = 200 - resp.url = "https://example.com/page1" - resp.json.side_effect = json.JSONDecodeError("Expecting value", "...", 0) - - mock_client = mock.AsyncMock(spec=httpx.AsyncClient) - mock_client.send.return_value = resp - - mock_req = mock.MagicMock(spec=httpx.Request) - mock_req.method = "GET" - mock_req.headers = {} - mock_req.content = b"" - mock_req.url = "https://example.com/page1" - - async def run(): - await _walk_pages_async(geopd=False, req=mock_req, client=mock_client) - - with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: - asyncio.run(run()) - - assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) - - def test_get_resp_data_handles_missing_features_key(): """Regression: a 200 with ``numberReturned > 0`` but no ``features`` key (real schema-drift shape) used to crash @@ -416,7 +385,7 @@ def test_walk_pages_does_not_mutate_initial_response(): mock_req.headers = {} mock_req.url = "https://example.com/page1" - df, final = _walk_pages(geopd=False, req=mock_req, client=mock_client) + df, final = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) assert len(df) == 2 # The original first-page response object must be unmutated: From 9fe3a5a489a2c32a83bdc64b774a66a4dded8b14 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 15:58:34 -0500 Subject: [PATCH 07/16] test(waterdata): tidy the test suite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleanup only — no test logic, assertions, or coverage changed: - Hoist ~40 in-body imports to module level across the waterdata test files (dedup against existing top-level imports; consolidate stray aliases — _httpx -> httpx, _dt -> datetime, _mock -> mock). The only imports left in-body are the importorskip-guarded geopandas/shapely ones. - Drop stale sync/async framing from docstrings/comments (the chunker is async now): no more 'serial path', 'sync sibling', 'async-only', 'requests-mock', or 'falls back to serial' archaeology. - Rename the autouse fixture _serial_chunker -> _pin_chunker_env (it pins API_USGS_CONCURRENT=1 / API_USGS_RETRIES=0; 'serial' implied a removed path) and tighten its docstring. Offline suite (264) + live getter suite (63) green; ruff clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/conftest.py | 37 +++------ tests/waterdata_chunking_test.py | 132 +++++++++++-------------------- tests/waterdata_filters_test.py | 11 +-- tests/waterdata_progress_test.py | 33 +++----- tests/waterdata_test.py | 5 +- tests/waterdata_utils_test.py | 39 +++------ 6 files changed, 85 insertions(+), 172 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0928d610..6958c480 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,10 @@ Test scaffolding for the dataretrieval test suite. * Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and - unmatched real requests don't fail the suite (matches the historical - ``requests-mock``-style permissiveness the test code was written - against, and keeps mocked-URL setup terse). + unmatched requests don't fail the suite (keeps mocked-URL setup terse). * Pins ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` for every - test by default. The chunker core is async-only, so - ``API_USGS_CONCURRENT=1`` now means a single pooled connection (a - one-connection ``asyncio.gather``) rather than a separate serial code - path — deterministic enough for the mocked suite while a single - transient surfaces immediately (no backoff). Async-fan-out and retry + test by default, so sub-request dispatch is deterministic and a single + transient surfaces immediately (no backoff). Concurrency and retry tests opt in by re-setting the env vars inside their body via ``monkeypatch.setenv``. """ @@ -21,9 +16,8 @@ def pytest_collection_modifyitems(config, items): - """Apply relaxed ``pytest-httpx`` strict-mode settings to every - test in the suite — matches the permissive defaults the historical - tests were written against.""" + """Apply relaxed ``pytest-httpx`` strict-mode settings to every test + so unconsumed mocks and unmatched requests don't fail the suite.""" marker = pytest.mark.httpx_mock( assert_all_responses_were_requested=False, assert_all_requests_were_expected=False, @@ -41,19 +35,14 @@ def non_mocked_hosts() -> list[str]: @pytest.fixture(autouse=True) -def _serial_chunker(monkeypatch): - """Default every test to the single-connection, no-retry chunker path. - - Production defaults ``API_USGS_CONCURRENT`` to 16 (a wide pooled - fan-out) and ``API_USGS_RETRIES`` to 4. The chunker core is async-only - now — there is no separate serial path — so ``API_USGS_CONCURRENT=1`` - means a single pooled connection (a one-connection ``asyncio.gather``), - which keeps sub-request dispatch deterministic enough for the mocked - suite. ``API_USGS_RETRIES=0`` makes a single transient surface - immediately rather than be retried. The mocked tests drive the async - ``_walk_pages`` (via ``asyncio.run`` / an ``AsyncMock`` client), - not a sync sibling. Pinning both keeps the test surface focused on the - planner / fetch contracts; async-fan-out and retry tests opt in by +def _pin_chunker_env(monkeypatch): + """Pin every test to one connection and no retries. + + Production defaults ``API_USGS_CONCURRENT`` to 16 and + ``API_USGS_RETRIES`` to 4. Pinning ``API_USGS_CONCURRENT=1`` keeps + sub-request dispatch deterministic for the mocked suite, and + ``API_USGS_RETRIES=0`` makes a single transient surface immediately + rather than be retried. Concurrency and retry tests opt in by overriding the env inside their body. """ monkeypatch.setenv("API_USGS_CONCURRENT", "1") diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 1f954f1e..ab23c075 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -19,10 +19,12 @@ import concurrent.futures import datetime import sys +import warnings from unittest import mock from urllib.parse import quote_plus import httpx +import numpy as np import pandas as pd import pytest @@ -30,8 +32,10 @@ pytest.skip("Skip entire module on Python < 3.10", allow_module_level=True) from dataretrieval.waterdata import chunking as _chunking +from dataretrieval.waterdata import utils as _utils from dataretrieval.waterdata.chunking import ( _LIST_SEP, + _NEVER_CHUNK, _OR_SEP, _QUOTA_HEADER, ChunkInterrupted, @@ -43,12 +47,16 @@ ServiceInterrupted, ServiceUnavailable, _chunked_client, + _combine_chunk_frames, + _combine_chunk_responses, _extract_axes, + _request_bytes, _retry, _retryable, + _safe_request_bytes, multi_value_chunked, ) -from dataretrieval.waterdata.utils import _construct_api_requests +from dataretrieval.waterdata.utils import _DATE_RANGE_PARAMS, _construct_api_requests def _aiozero(_d): @@ -103,9 +111,6 @@ def test_never_chunk_covers_all_date_range_params(): adding a new param to ``_DATE_RANGE_PARAMS`` without also adding it to ``_NEVER_CHUNK`` would silently let the chunker try to comma-join an interval string.""" - from dataretrieval.waterdata.chunking import _NEVER_CHUNK - from dataretrieval.waterdata.utils import _DATE_RANGE_PARAMS - missing = _DATE_RANGE_PARAMS - _NEVER_CHUNK assert not missing, ( f"_DATE_RANGE_PARAMS contains entries not in _NEVER_CHUNK: " @@ -694,15 +699,13 @@ def test_connection_error_wrapped_as_service_interrupted(): and the user would lose the resumable handle to ``.call.resume()``. Verify ``ChunkedCall`` wraps it as ``ServiceInterrupted`` so partial progress is preserved.""" - import httpx as _httpx - state = {"i": 0, "blow_up": True} async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: - raise _httpx.ConnectError("connection reset") + raise httpx.ConnectError("connection reset") return ( pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500), @@ -717,7 +720,7 @@ async def fetch(args): assert err.completed_chunks == 4 assert err.call is not None # The transport exception is on __cause__ so callers can drill in if needed. - assert isinstance(err.__cause__, _httpx.ConnectError) + assert isinstance(err.__cause__, httpx.ConnectError) # Resume after the upstream recovers. state["blow_up"] = False df, _ = err.call.resume() @@ -730,15 +733,13 @@ def test_invalid_url_wrapped_as_service_interrupted(): ``_classify_chunk_error`` an oversize follow-up URL escapes as raw ``InvalidURL`` and the user loses ``.call.resume()`` access to the partial state. Mirror the ConnectError test.""" - import httpx as _httpx - state = {"i": 0, "blow_up": True} async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: - raise _httpx.InvalidURL("URL is too long: 65536 bytes > 65000") + raise httpx.InvalidURL("URL is too long: 65536 bytes > 65000") return ( pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500), @@ -752,7 +753,7 @@ async def fetch(args): # Async fan-out: only the i==2 sub-request fails; the other four complete. assert err.completed_chunks == 4 assert err.call is not None - assert isinstance(err.__cause__, _httpx.InvalidURL) + assert isinstance(err.__cause__, httpx.InvalidURL) # The top-level message must surface the underlying cause text so # the user doesn't have to traverse ``__cause__`` to know what # actually failed (previously the message was generic "Service @@ -875,8 +876,6 @@ def test_combine_chunk_responses_returns_independent_headers(): hooks, metadata extensions) must not back-propagate into the underlying chunk response's headers, which still live on ``ChunkedCall._chunks``.""" - from dataretrieval.waterdata.chunking import _combine_chunk_responses - r0 = mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={"X-Foo": "0"}, url="u0" ) @@ -899,13 +898,6 @@ def test_paginate_terminates_on_empty_string_cursor(): coerce falsy non-None values to None so an empty-string next- cursor (a real-but-unusual end-of-stream sentinel some pagination APIs use) doesn't trap us in an infinite ``follow_up('')`` loop.""" - import datetime as _dt - from unittest import mock as _mock - - import httpx as _httpx - - from dataretrieval.waterdata import utils as _utils - # Synthesize an OGC response with numberReturned > 0 and a "next" # link whose href is an empty string — simulating a server-side # sentinel that ``_next_req_url`` reads as ``""``. @@ -914,17 +906,17 @@ def test_paginate_terminates_on_empty_string_cursor(): "features": [{"id": "1", "properties": {"val": "a"}}], "links": [{"rel": "next", "href": ""}], } - resp = _mock.MagicMock(spec=_httpx.Response) + resp = mock.MagicMock(spec=httpx.Response) resp.status_code = 200 resp.url = "https://example.com/items?limit=1" - resp.elapsed = _dt.timedelta(seconds=0.1) + resp.elapsed = datetime.timedelta(seconds=0.1) resp.headers = {} resp.json.return_value = body_with_empty_next - client = _mock.AsyncMock(spec=_httpx.AsyncClient) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp - req = _mock.MagicMock(spec=_httpx.Request) + req = mock.MagicMock(spec=httpx.Request) req.method = "GET" req.headers = {} req.content = b"" @@ -943,10 +935,6 @@ def test_combine_chunk_frames_does_not_collapse_none_ids(): so a blanket dedup would collapse every id-less row into one — silent data loss. The function must dedupe only the id-bearing rows and preserve id-less rows verbatim.""" - import numpy as np - - from dataretrieval.waterdata.chunking import _combine_chunk_frames - # Frame A has real ids; frame B has feature-IDs of None for two # different rows that must both survive. df_a = pd.DataFrame({"id": ["x", "y"], "val": [1, 2]}) @@ -962,8 +950,6 @@ def test_combine_chunk_frames_still_dedupes_overlapping_ids(): """The original dedup contract — overlapping OR-clause partitions that produce duplicate-id rows across chunks must still collapse to one row — has to keep working when ids ARE present.""" - from dataretrieval.waterdata.chunking import _combine_chunk_frames - df_a = pd.DataFrame({"id": ["x", "y"], "val": [1, 2]}) df_b = pd.DataFrame({"id": ["y", "z"], "val": [2, 3]}) combined = _combine_chunk_frames([df_a, df_b]) @@ -1016,10 +1002,6 @@ def test_request_bytes_sums_url_and_content(): the chunker just needs to size that single attribute alongside the URL. """ - import httpx - - from dataretrieval.waterdata.chunking import _request_bytes - # GET request with no body req = httpx.Request("GET", "https://x.example/ab") assert _request_bytes(req) == len("https://x.example/ab") @@ -1037,9 +1019,6 @@ def test_safe_request_bytes_treats_invalid_url_as_overflow(): contract is that ``_safe_request_bytes`` returns ``url_limit + 1`` (a value strictly greater than the limit) when ``build_request`` raises ``InvalidURL``.""" - import httpx - - from dataretrieval.waterdata.chunking import _safe_request_bytes def build_request(**kwargs): raise httpx.InvalidURL("URL too long") @@ -1054,8 +1033,6 @@ def test_chunk_plan_handles_initial_url_overflow(): crash ``ChunkPlan.__init__``; the planner falls back to a worst-case sub-request URL for ``canonical_url`` and proceeds to halve the over-limit axes normally.""" - import httpx - real_build = _fake_build def overflowing_build(**kwargs): @@ -1208,8 +1185,6 @@ def test_combine_chunk_frames_all_empty_preserves_geo_type(): pytest.importorskip("geopandas") import geopandas as gpd - from dataretrieval.waterdata.chunking import _combine_chunk_frames - empty_gdfs = [gpd.GeoDataFrame() for _ in range(3)] combined = _combine_chunk_frames(empty_gdfs) assert isinstance(combined, gpd.GeoDataFrame), ( @@ -1222,8 +1197,6 @@ def test_combine_chunk_frames_single_frame_is_safe_to_mutate(): input on the single-chunk fast path — a caller mutating ``call.partial_frame`` (a live view) must not back-propagate into the underlying ``_chunks[0][0]`` frame.""" - from dataretrieval.waterdata.chunking import _combine_chunk_frames - chunk = pd.DataFrame({"id": ["A", "B"], "value": [1, 2]}) returned = _combine_chunk_frames([chunk]) returned["new_col"] = "x" @@ -1245,16 +1218,15 @@ def test_iter_sub_args_passthrough_yields_a_copy(): # --- async fan-out path ---------------------------------------------------- # -# The chunker is async-only: every sub-request is gathered over one -# ``httpx.AsyncClient`` and concurrency is bounded purely by that client's -# connection pool, sized from ``API_USGS_CONCURRENT``. The conftest's -# ``_serial_chunker`` autouse pins ``API_USGS_CONCURRENT=1`` (a single -# connection) for the whole suite; each test below raises it so the gather -# can dispatch sub-requests under a wider pool. The decorated async fetcher -# is the SAME one used on both first-run and resume — there is no sync -# sibling anymore. No real ``httpx.AsyncClient`` round-trip occurs (the -# fakes return mock data), even though :meth:`ChunkedCall._run` opens one -# for pool management. +# Every sub-request is gathered over one ``httpx.AsyncClient`` and +# concurrency is bounded purely by that client's connection pool, sized +# from ``API_USGS_CONCURRENT``. The conftest's ``_pin_chunker_env`` +# autouse pins ``API_USGS_CONCURRENT=1`` (a single connection) for the +# whole suite; each test below raises it so the gather can dispatch +# sub-requests under a wider pool. The decorated async fetcher is the +# SAME one used on both first-run and resume. No real ``httpx.AsyncClient`` +# round-trip occurs (the fakes return mock data), even though +# :meth:`ChunkedCall._run` opens one for pool management. def _async_chunked_fetch(monkeypatch, fetch_async, *, max_concurrent=16): @@ -1275,8 +1247,8 @@ def _ok_response(remaining=None): def test_async_fan_out_emits_one_call_per_sub_request(monkeypatch): - """Parallel mode hits every sub-args once — same coverage as the - sync ``ChunkedCall`` path, just dispatched concurrently.""" + """The fan-out hits every sub-args exactly once, dispatched + concurrently.""" seen_args = [] async def fetch_async(args): @@ -1403,14 +1375,9 @@ async def fetch_async(args): def test_wide_concurrency_uses_async_fetcher_with_no_warning(monkeypatch): - """Re-expresses the old "falls back to serial when no async sibling - wired" intent for the async-only core: there is no async sibling to - wire and no serial fallback, so a wide ``API_USGS_CONCURRENT`` is - honored directly by the single async fetcher — every sub-request runs - on it and NO ``UserWarning`` is emitted (the env var is never silently - no-op'd, the previous regression that warning guarded).""" - import warnings - + """A wide ``API_USGS_CONCURRENT`` is honored directly by the single + async fetcher: every sub-request fans out across it and NO + ``UserWarning`` is emitted.""" calls = [] monkeypatch.setenv("API_USGS_CONCURRENT", "16") @@ -1430,20 +1397,20 @@ async def fetch(args): def test_async_fan_out_runs_inside_running_event_loop(monkeypatch): """The parallel fan-out works even when the caller is already inside a running event loop (Jupyter / async apps): the anyio blocking portal - runs it in a worker thread, so it neither raises a nested - ``asyncio.run`` error nor silently degrades to the serial path.""" + runs it in a worker thread, so it does not raise a nested + ``asyncio.run`` error.""" monkeypatch.setenv("API_USGS_CONCURRENT", "16") async_calls = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - async def fetch(args): # the single async fetcher — there is no sync sibling + async def fetch(args): # the single async fetcher async_calls.append(tuple(args["sites"])) return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() async def driver(): # call the sync getter from within a running loop # The sync wrapper drives the async core through the anyio portal in - # a worker thread, so it works even inside a running event loop — - # neither a nested-``asyncio.run`` error nor a silent serial fallback. + # a worker thread, so it works even inside a running event loop + # without raising a nested-``asyncio.run`` error. return fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) df, _ = asyncio.run(driver()) @@ -1503,14 +1470,10 @@ def test_combine_chunk_responses_does_not_mutate_input_urls(): 'input responses are not mutated' invariant. The fix is to swap in a fresh ``httpx.Request`` rather than mutate the existing one. """ - import httpx as _httpx - - from dataretrieval.waterdata.chunking import _combine_chunk_responses - - req1 = _httpx.Request("GET", "https://example.com/chunk0") - req2 = _httpx.Request("GET", "https://example.com/chunk1") - r1 = _httpx.Response(200, request=req1) - r2 = _httpx.Response(200, request=req2) + req1 = httpx.Request("GET", "https://example.com/chunk0") + req2 = httpx.Request("GET", "https://example.com/chunk1") + r1 = httpx.Response(200, request=req1) + r2 = httpx.Response(200, request=req2) out = _combine_chunk_responses( [r1, r2], canonical_url="https://canonical.example/full" @@ -1630,12 +1593,11 @@ def test_retryable_skips_wrapped_midpagination_transient(): # -- async driver (the single retry driver; sync facade drives it) ---------- # -# The chunker is async-only now, so the retry loop lives solely in -# ``_retry``. These tests pin the same behavioral contracts the old -# ``_retry_sync`` tests asserted (transient-then-success, exhausted-reraises, -# non-retryable-not-retried, long-retry-after-escalates), re-expressed on the -# async driver and run via ``asyncio.run``; the sleep is patched to a no-op so -# backoff doesn't actually wait. +# The retry loop lives in ``_retry``. These tests pin its behavioral +# contracts (transient-then-success, exhausted-reraises, +# non-retryable-not-retried, long-retry-after-escalates), run via +# ``asyncio.run``; the sleep is patched to a no-op so backoff doesn't +# actually wait. def test_retry_transient_then_recovers(monkeypatch): @@ -1835,8 +1797,8 @@ def finalize(frame, response): return frame.assign(finalized=True), ("METADATA", response) # Fail the 2nd issued sub-request once (the 1st completes, so partial - # state is non-empty), then succeed on resume. Conftest pins the serial, - # no-retry path, so the failure surfaces immediately. + # state is non-empty), then succeed on resume. Conftest pins a single + # connection and no retries, so the failure surfaces immediately. state = {"n": 0} @multi_value_chunked(build_request=_fake_build, url_limit=240) diff --git a/tests/waterdata_filters_test.py b/tests/waterdata_filters_test.py index b57b0c53..a447cada 100644 --- a/tests/waterdata_filters_test.py +++ b/tests/waterdata_filters_test.py @@ -6,6 +6,7 @@ import pandas as pd import pytest +from dataretrieval.waterdata import get_continuous from dataretrieval.waterdata.filters import ( _check_numeric_filter_pitfall, _split_top_level_or, @@ -152,8 +153,6 @@ def test_long_filter_fans_out_into_multiple_requests(): sub-requests via the joint planner; every original clause is preserved across sub-requests; results concatenate to one row per sub-request given the one-row-per-chunk mock.""" - from dataretrieval.waterdata import get_continuous - expr = _filter_chunking_clauses() sent_filters: list[str] = [] @@ -191,8 +190,6 @@ async def fake_walk_pages(*, geopd, req): def test_long_filter_deduplicates_cross_chunk_overlap(): """Features returned by multiple sub-requests with the same ``id`` are deduplicated in the concatenated result.""" - from dataretrieval.waterdata import get_continuous - expr = _filter_chunking_clauses() call_count = {"n": 0} @@ -234,8 +231,6 @@ def test_empty_chunks_do_not_downgrade_geodataframe(): import geopandas as gpd from shapely.geometry import Point - from dataretrieval.waterdata import get_continuous - expr = _filter_chunking_clauses() call_count = {"n": 0} @@ -276,8 +271,6 @@ async def fake_walk_pages(*_args, **_kwargs): def test_cql_json_filter_is_not_chunked(): """Chunking applies only to cql-text; cql-json is passed through unchanged.""" - from dataretrieval.waterdata import get_continuous - clause = "(time >= '2023-01-01T00:00:00Z' AND time <= '2023-01-01T00:30:00Z')" expr = " OR ".join([clause] * 300) sent_filters = [] @@ -433,8 +426,6 @@ def test_get_continuous_surfaces_pitfall_to_caller(): """End-to-end: the check runs at the ``get_continuous`` boundary, not as a deep internal-only protection, so callers see the error before any HTTP traffic.""" - from dataretrieval.waterdata import get_continuous - with mock.patch("dataretrieval.waterdata.utils._construct_api_requests") as build: with pytest.raises(ValueError, match="lexicographic"): get_continuous( diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index 8e1176de..08f6ca26 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -7,12 +7,14 @@ """ import asyncio +import datetime import io import sys import types from unittest import mock import httpx +import pandas as pd import pytest from dataretrieval.waterdata import _progress @@ -21,7 +23,8 @@ current, progress_context, ) -from dataretrieval.waterdata.utils import _walk_pages +from dataretrieval.waterdata.chunking import ChunkedCall, ChunkPlan +from dataretrieval.waterdata.utils import _paginate, _walk_pages def _run_walk_pages(*, geopd, req, client): @@ -428,18 +431,14 @@ def test_broken_progress_stream_does_not_truncate_pagination(): assert len(df) == 2 # both pages returned despite the broken progress stream -# -- async path integration ---------------------------------------------------- +# -- pagination integration ---------------------------------------------------- def test_paginate_reports_pages_through_active_reporter(monkeypatch): - """The async paginate path must drive the same progress reporter the - sync path does. Pages and rate-limit updates from each completed - page should land via the active ``ProgressReporter``, exactly as - they would on ``_walk_pages``.""" - import asyncio - - from dataretrieval.waterdata.utils import _paginate - + """The async paginate path must drive the same progress reporter. + Pages and rate-limit updates from each completed page should land + via the active ``ProgressReporter``, exactly as they would on + ``_walk_pages``.""" resp1 = _resp( [{"id": "1", "properties": {"v": "a"}}], next_url="https://example.com/p2", @@ -454,14 +453,12 @@ async def parse_response(resp): ) return mock.MagicMock(empty=False, __len__=lambda self: 1), nxt - # _paginate expects parse_response to be sync, like the sync path. + # parse_response is sync (like the page parsers). def parse_sync(resp): body = resp.json() nxt = next( (link["href"] for link in body["links"] if link["rel"] == "next"), None ) - import pandas as pd - return pd.DataFrame(body["features"]), nxt async def follow_up(cursor, sess): @@ -502,11 +499,6 @@ def test_fan_out_async_sets_chunks_on_active_reporter(monkeypatch): many sub-requests are in flight, and ticks ``current_chunk`` via ``start_chunk(len(completed))`` as each gathered sub-request finishes — reaching ``plan.total`` in the all-success case.""" - import asyncio - - import pandas as pd - - from dataretrieval.waterdata.chunking import ChunkedCall, ChunkPlan # Fake build_request whose URL length scales with the sites list, # mirroring the planner's _request_bytes contract. _FakeReq has the @@ -527,7 +519,7 @@ def build(*, sites): async def fetch_async(args): return pd.DataFrame({"id": [",".join(args["sites"])]}), mock.Mock( - elapsed=__import__("datetime").timedelta(seconds=0.01), + elapsed=datetime.timedelta(seconds=0.01), headers={"x-ratelimit-remaining": "999"}, ) @@ -535,8 +527,7 @@ async def fetch_async(args): async def run(): # Drive the async execution core directly (the same coroutine the - # sync ``resume()`` facade runs through the anyio portal); the - # decorated async fetcher is the only fetcher now. + # sync ``resume()`` facade runs through the anyio portal). with progress_context(service="daily", stream=stream, enabled=True) as rep: await ChunkedCall(plan, fetch_async)._run(4) return rep.total_chunks, rep.current_chunk diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index ea91ae9e..9e0d4c70 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -3,6 +3,7 @@ import sys from unittest import mock +import numpy as np import pandas as pd import pytest from pandas import DataFrame @@ -564,8 +565,6 @@ def test_get_reference_table_accepts_numpy_int_max_rows(): # numpy integers are valid caps: isinstance(np.int64, int) is False, so the # validation must accept numbers.Integral (not just int) — otherwise a cap # derived from a numpy/pandas computation is wrongly rejected. - import numpy as np - df, _ = get_reference_table("agency-codes", max_rows=np.int64(2)) assert len(df) == 2 @@ -712,8 +711,6 @@ def test_pandas_series_normalizes_to_list(self): assert isinstance(result, list) def test_numpy_array_normalizes_to_list(self): - import numpy as np - result = _normalize_str_iterable(np.array(["00060", "00010"]), "p") assert result == ["00060", "00010"] assert isinstance(result, list) diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index ac48788e..7063b767 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -1,4 +1,5 @@ import asyncio +import datetime import json import logging from unittest import mock @@ -10,14 +11,21 @@ import dataretrieval.waterdata.utils as _utils_module from dataretrieval.waterdata.chunking import RateLimited, ServiceUnavailable from dataretrieval.waterdata.utils import ( + OGC_API_URL, _arrange_cols, + _check_ogc_requests, _error_body, + _finalize_ogc, _format_api_dates, _get_args, + _get_resp_data, _handle_stats_nesting, + _next_req_url, _parse_retry_after, _raise_for_non_200, + _row_cap, _walk_pages, + get_stats_data, ) _LOGGER_NAME = _utils_module.__name__ @@ -113,8 +121,6 @@ def test_row_cap_truncates_and_stops_within_first_page(): # Regression for BUG 2: ``_row_cap`` bounds the TOTAL rows. A first page # already over the cap is truncated to exactly ``max_rows`` and the # ``next`` link is never followed. - from dataretrieval.waterdata.utils import _row_cap - resp1 = mock.MagicMock() resp1.json.return_value = { "numberReturned": 3, @@ -143,8 +149,6 @@ def test_row_cap_truncates_and_stops_within_first_page(): def test_row_cap_stops_across_pages(): # The cap accumulates across pages: page 1 (1 row) is under the cap so # page 2 is fetched; once the cap (2) is met the third page is NOT. - from dataretrieval.waterdata.utils import _row_cap - def _page(idx, *, has_next): resp = mock.MagicMock() nxt = f"https://example.com/page{idx + 1}" @@ -179,10 +183,6 @@ def test_finalize_ogc_truncates_combined_to_max_rows(): # max_rows is enforced on the *combined* frame in _finalize_ogc (after # dedup/sort), so it bounds the total exactly even when a chunked call's # per-sub-request pages overshoot the per-_paginate early-stop. - import datetime - - from dataretrieval.waterdata.utils import _finalize_ogc - frame = pd.DataFrame({"id": [str(i) for i in range(10)]}) resp = mock.MagicMock() resp.url = "https://example.com/q" @@ -334,8 +334,6 @@ def test_get_resp_data_handles_missing_features_key(): ``_paginate`` as a generic transport error. ``_handle_stats_nesting`` was already hardened against this; ``_get_resp_data`` now mirrors that defensiveness and returns an empty frame instead.""" - from dataretrieval.waterdata.utils import _get_resp_data - resp = mock.Mock() resp.json.return_value = {"numberReturned": 1, "links": []} df = _get_resp_data(resp, geopd=False) @@ -350,12 +348,10 @@ def test_walk_pages_does_not_mutate_initial_response(): ``.elapsed`` before pagination completed (a Session response hook, a logging middleware) must continue to see the original first-page values — NOT the rewritten cumulative values.""" - import datetime as _dt - page1 = mock.MagicMock() page1.status_code = 200 page1.url = "https://example.com/page1" - page1.elapsed = _dt.timedelta(seconds=1) + page1.elapsed = datetime.timedelta(seconds=1) page1.headers = {"x-ratelimit-remaining": "999"} page1.json.return_value = { "numberReturned": 1, @@ -368,7 +364,7 @@ def test_walk_pages_does_not_mutate_initial_response(): page2 = mock.MagicMock() page2.status_code = 200 page2.url = "https://example.com/page2" - page2.elapsed = _dt.timedelta(seconds=2) + page2.elapsed = datetime.timedelta(seconds=2) page2.headers = {"x-ratelimit-remaining": "998"} page2.json.return_value = { "numberReturned": 1, @@ -396,7 +392,7 @@ def test_walk_pages_does_not_mutate_initial_response(): # The returned aggregate carries page-2 headers + cumulative elapsed. assert final.headers["x-ratelimit-remaining"] == "998" - assert final.elapsed == _dt.timedelta(seconds=3) + assert final.elapsed == datetime.timedelta(seconds=3) # And mutating the aggregate's headers doesn't leak into either page. final.headers["X-Trace-Id"] = "abc" assert "X-Trace-Id" not in page1.headers @@ -422,8 +418,6 @@ def _run_get_stats_data_with_failure(failure_resp_or_exc, monkeypatch): `monkeypatch` stubs ``_handle_stats_nesting`` so the synthetic minimal response body doesn't need to parse — these tests only assert on the pagination loop's error surfacing.""" - from dataretrieval.waterdata.utils import get_stats_data - monkeypatch.setattr( _utils_module, "_handle_stats_nesting", @@ -547,8 +541,6 @@ def test_get_resp_data_empty_preserves_geopd_type(): ``GeoDataFrame`` (not a plain ``DataFrame``) when geopd is True, so paginating across a sparse intermediate page doesn't downgrade the final concat result.""" - from dataretrieval.waterdata.utils import _get_resp_data - fake_gpd = mock.MagicMock() class _Sentinel: @@ -578,8 +570,6 @@ def test_get_resp_data_always_materializes_id_column(): ``_arrange_cols`` rename to the service-specific output_id (``daily_id``, ``channel_measurements_id``, etc.) isn't a silent no-op.""" - from dataretrieval.waterdata.utils import _get_resp_data - resp = mock.MagicMock() resp.json.return_value = { "numberReturned": 2, @@ -704,8 +694,6 @@ def test_format_api_dates_rejects_mapping(): """`time={"2024-01-01": "x"}` would silently materialize as the keys list, accepting input the user clearly didn't intend. """ - import pytest - with pytest.raises(TypeError, match="date input must be a string or sequence"): _format_api_dates({"2024-01-01": "ignored"}) @@ -848,8 +836,6 @@ def test_next_req_url_rejects_cross_host(): auth-like artifacts) were minted for the original host; following a server-supplied cross-host URL would leak them — and the URL itself could be sensitive.""" - from dataretrieval.waterdata.utils import _next_req_url - resp = mock.MagicMock() resp.url = httpx.URL("https://api.waterdata.usgs.gov/page1") body = { @@ -867,9 +853,6 @@ def test_check_ogc_requests_raises_typed_on_5xx(httpx_mock): ``_raise_for_non_200`` so callers see ``ServiceUnavailable`` / ``RateLimited`` / ``RuntimeError`` — the same typed contract as the main data path.""" - from dataretrieval.waterdata.chunking import ServiceUnavailable - from dataretrieval.waterdata.utils import OGC_API_URL, _check_ogc_requests - httpx_mock.add_response( method="GET", url=f"{OGC_API_URL}/collections/daily/schema", From c507d620c8f47324f12c6335bc48c85091c0a9eb Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 16:18:24 -0500 Subject: [PATCH 08/16] docs(waterdata): trim over-documentation in chunking.py Cut accreted redundancy (no behavior change): the concurrency-env constants comment now points to the module docstring / _read_concurrency_env instead of triplicating the knob semantics; the _NEVER_CHUNK exclusion taxonomy is compressed 16->7 lines (reasons kept); completed_chunks loses a Returns block that restated its one-line summary; the ChunkInterrupted snapshot comment drops the .copy() archaeology; and multi_value_chunked's two overlapping paragraphs collapse to one, deferring the concurrency model to the module docstring. Net -30 lines. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 74 +++++++++-------------------- 1 file changed, 22 insertions(+), 52 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index fe6ee6ab..2233b6e5 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -81,22 +81,13 @@ # leaves ~200 bytes for request-line framing and proxy variance. _WATERDATA_URL_BYTE_LIMIT = 8000 -# Default rule: any list-shaped kwarg with >1 element is chunked across -# sub-requests — each chunk becomes a comma-joined sub-list in the URL. -# The OGC getters expose ~90 such list-shaped params (IDs, codes, -# statuses, ...), all chunkable, so it's shorter to enumerate the -# exceptions than to maintain an allowlist that grows with the API. -# Exceptions, by reason: -# - response shape: ``properties`` defines the columns; sharding -# would yield different schemas per chunk. -# - structured: ``bbox`` is a fixed 4-element coord tuple. -# - intervals: date/time ranges are not enumerable sets. -# - handled elsewhere: ``filter`` becomes its own axis in -# ``_extract_axes`` (joiner ``" OR "``); -# comma-joining CQL clauses would emit -# malformed expressions. -# - scalar by contract: ``limit``, ``skip_geometry``, ``filter_lang`` -# — a list value would be a type-erasure smuggle. +# Any list-shaped kwarg with >1 element is chunked (comma-joined per +# sub-list in the URL); ~90 OGC params qualify, so we denylist the few +# exceptions rather than maintain a growing allowlist. Excluded because: +# ``properties`` defines the column schema; ``bbox`` is a fixed coord +# tuple; date/time params are intervals, not enumerable sets; ``filter`` +# is handled as its own OR-axis in ``_extract_axes``; and ``limit`` / +# ``skip_geometry`` / ``filter_lang`` are scalar by contract. _NEVER_CHUNK = frozenset( { "properties", @@ -118,12 +109,9 @@ # Response header USGS uses to advertise remaining hourly quota. _QUOTA_HEADER = "x-ratelimit-remaining" -# Environment variable that controls fan-out concurrency. Read at call -# time (not import) so test patches via ``monkeypatch.setenv`` take -# effect. The default (16) is the server-friendly sweet spot: higher -# values trip the upstream into 5xx burst-protection in practice. Set to -# ``1`` for a single connection, set to ``unbounded`` for no per-call cap -# (use sparingly — you own the upstream-burst risk). +# Fan-out concurrency cap, read at call time (not import) so test +# ``monkeypatch.setenv`` applies. Value grammar in :func:`_read_concurrency_env`; +# the concurrency model is in the module docstring. _CONCURRENCY_ENV = "API_USGS_CONCURRENT" _CONCURRENCY_DEFAULT = 16 _CONCURRENCY_UNBOUNDED = "unbounded" @@ -564,13 +552,10 @@ def __init__( self.total_chunks = total_chunks self.call = call self.retry_after = retry_after - # Snapshot partial state at raise time so the exception's view - # stays stable across later ``call.resume()`` advances; the - # live view lives on ``call.partial_frame``/``.partial_response``. - # ``partial_frame`` gets a defensive ``.copy()`` because - # ``_combine_chunk_frames`` may return a chunk frame verbatim - # in the single-completed-chunk fast path; ``partial_response`` - # already comes via ``copy.copy`` from ``_combine_chunk_responses``. + # Snapshot partial state at raise time so the exception's view stays + # stable across later ``call.resume()`` advances (the live view is on + # ``call.partial_frame`` / ``.partial_response``). ``.copy()`` guards + # the single-chunk fast path, where the frame may be returned verbatim. if call is None: self.partial_frame: pd.DataFrame = pd.DataFrame() self.partial_response: httpx.Response | None = None @@ -1485,14 +1470,7 @@ def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: @property def completed_chunks(self) -> int: - """ - Number of sub-requests completed so far. - - Returns - ------- - int - The count of completed sub-requests. - """ + """Number of sub-requests completed so far.""" return len(self._chunks) def _combine_raw(self) -> tuple[pd.DataFrame, httpx.Response]: @@ -1760,21 +1738,13 @@ def multi_value_chunked( """ Decorate an async fetcher to transparently chunk over-budget requests. - Splits multi-value list params and cql-text filters across - sub-requests so each fits the URL byte limit. Builds a - :class:`ChunkPlan` and runs it: passthrough requests are a trivial - single-step plan, so the decorated function has one code path - either way. - - Decorates an ``async def fetch(args) -> (df, response)`` and returns a - callable that builds the :class:`ChunkPlan`, constructs a - :class:`ChunkedCall` over the fetcher, and drives it to completion via - :meth:`ChunkedCall.resume` (an ``anyio`` worker-thread portal, so it - works whether or not the caller is already inside an event loop — - Jupyter / IPython / async apps). Every pending sub-request is gathered - under one :class:`httpx.AsyncClient`; concurrency is bounded purely by - the connection pool, sized from ``API_USGS_CONCURRENT`` (``1`` is a - single-connection gather, ``plan.total <= 1`` a one-element gather). + Returns a callable that builds a :class:`ChunkPlan` from ``args``, + constructs a :class:`ChunkedCall` over the decorated + ``async def fetch(args) -> (df, response)``, and drives it to + completion via :meth:`ChunkedCall.resume`. The plan splits multi-value + list params and the cql-text filter so each sub-request URL fits the + byte limit; an already-fitting request is a one-step plan. See the + module docstring for the concurrency model. Parameters ---------- From 79a90172b68c2ead482019ea4269f4fd50cc2931 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 16:21:43 -0500 Subject: [PATCH 09/16] docs(waterdata): editorial pass on chunking.py docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Readability + accuracy: - Module docstring: 'ChunkedCall iterates the joint cartesian product so every sub-request URL fits' attributed the fit guarantee to ChunkedCall, but that's ChunkPlan's job — reworded so ChunkPlan keeps each URL under budget and ChunkedCall fetches the resulting product. - Dropped two duplicated explanations: the sparse-completion [0,2,5] example (kept on the class docstring, trimmed from __init__) and the 'no semaphore' note (kept in _run's docstring, trimmed from its inline comment). Verified the docs carry no stale references after the async-only refactor + renames: every :meth:/:func:/:class:/:attr: cross-ref resolves, the retry defaults (4 / 0.5s / 30s / 60s) match the constants, and the only 'semaphore' mentions are correct negations (pool throttles, not a semaphore). Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 2233b6e5..e45d37b1 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -4,9 +4,9 @@ parameter (sites, parameter codes, …) plus the cql-text ``filter``, which splits along its top-level OR clauses. Any of them can fan the URL past the server's ~8 KB byte limit. ``ChunkPlan`` picks a fan-out -for each axis that minimizes total sub-requests under the URL budget; -``ChunkedCall`` iterates the joint cartesian product so every -sub-request URL fits. Requests that already fit get a trivial +for each axis that minimizes total sub-requests while keeping every +sub-request URL under the budget; ``ChunkedCall`` fetches the resulting +cartesian product of chunks. Requests that already fit get a trivial single-step plan — ``ChunkedCall`` has one code path either way. Concurrency: ``multi_value_chunked`` fans every pending sub-request out @@ -1412,10 +1412,8 @@ def __init__( self.fetch = fetch self.retry_policy = retry_policy self.finalize = finalize - # Completed (frame, response) pairs keyed by sub-args index. - # Sparse so the gather can record scattered completions (e.g. - # indices [0, 2, 5] when 1/3/4 failed) and a subsequent - # ``resume()`` only re-issues the missing indices. + # Completed (frame, response) pairs keyed by sub-args index; sparse + # (gathered sub-requests complete out of order — see class docstring). self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: @@ -1669,8 +1667,7 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: # ``httpx.Limits()`` defaults to ``max_connections=100`` — at higher # concurrency the pool would silently bottleneck the fan-out behind # that cap. Set it to the resolved concurrency so the pool *is* the - # throttle (``None`` for truly unbounded). No semaphore: we gather - # every pending sub-request and let the pool serialize. + # throttle (``None`` for truly unbounded). limits = httpx.Limits( max_connections=max_concurrent, max_keepalive_connections=max_concurrent ) From bbabf5058be99984310b1d168b6b65c3239472b4 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Wed, 27 May 2026 16:56:03 -0500 Subject: [PATCH 10/16] docs(waterdata): address PR 285 review comments on chunking.py docs - Drop the 'not a semaphore' clarification (module docstring + _run). - Omit the 'All four are ... power users' sentence from the retry-defaults comment. - Remove the '(is this worth retrying at all?)' note-to-self in RetryPolicy. - Copy-edit two dense passages for readability (the _Finalize comment and the _retryable docstring). - Drop the 'The async execution core' lead-in from _run's docstring. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 66 +++++++++++++---------------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index e45d37b1..004ce10b 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -13,7 +13,7 @@ under one ``asyncio.gather`` sharing a single ``httpx.AsyncClient``; concurrency is bounded purely by the client's connection pool (``httpx.Limits(max_connections=N, max_keepalive_connections=N)``), so -the pool — not a semaphore — throttles. ``API_USGS_CONCURRENT`` resolves +the pool throttles. ``API_USGS_CONCURRENT`` resolves ``N``: an integer N > 1 caps connections at N; ``1`` pins a single connection (one request at a time); the literal ``unbounded`` removes the cap (``N=None``). The default (16) is the server-friendly sweet @@ -153,13 +153,10 @@ def _read_concurrency_env() -> int | None: # Retry-with-backoff defaults for transient sub-request failures (429 / -# 5xx / connect-read timeouts). All four are resolved at call time by -# ``RetryPolicy.from_env`` (the env var via ``monkeypatch.setenv``, the -# timing constants via ``monkeypatch.setattr`` on this module), so both -# are overridable in tests and by power users. Defaults: 4 retries, 0.5s -# base doubling under full jitter up to a 30s per-attempt ceiling, and -# honor a server ``Retry-After`` up to 60s before escalating to a -# resumable interruption instead. +# 5xx / connect-read timeouts): 4 retries, 0.5s base doubling under full +# jitter up to a 30s per-attempt ceiling, and honor a server +# ``Retry-After`` up to 60s before escalating to a resumable interruption +# instead. _RETRIES_ENV = "API_USGS_RETRIES" _RETRIES_DEFAULT = 4 _RETRY_BASE_BACKOFF = 0.5 @@ -196,7 +193,7 @@ class RetryPolicy: """Bounded retry-with-backoff config for transient sub-request failures. An immutable value object that owns the *timing* decisions; the - exception taxonomy ("is this worth retrying at all?") lives in + exception taxonomy (which failures are retryable) lives in :func:`_retryable`. Backoff is exponential with **full jitter** (:func:`random.uniform` over ``[0, ceiling]``) so the concurrent fan-out's retries don't re-burst in lockstep. A server ``Retry-After`` @@ -376,14 +373,12 @@ def get_active_client() -> httpx.AsyncClient | None: # ``ChunkedCall`` drives: an ``async def fetch(args) -> (df, response)``. _Fetch = Callable[[dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]]] -# Caller-supplied transform applied to the *combined* chunk result. It lets a -# resumed call (:meth:`ChunkedCall.resume` / :attr:`~ChunkedCall.partial_frame` -# / :attr:`~ChunkedCall.partial_response`) return the same shape as the -# un-interrupted call instead of the chunker's raw ``(frame, httpx.Response)``. -# The chunker stays generic — it only knows "post-process the assembled -# result"; the OGC getters inject the actual type-coercion / column-arrangement -# / ``BaseMetadata`` pipeline (see ``utils._finalize_ogc``). The default is -# identity, so direct ``ChunkedCall`` use and the tests are unaffected. +# Caller-supplied transform applied to the combined chunk result, so a +# resumed call returns the same shape as an un-interrupted one rather than +# the chunker's raw ``(frame, httpx.Response)``. This keeps the chunker +# generic: the OGC getters inject their post-processing (type coercion, +# column arrangement, ``BaseMetadata``) through ``utils._finalize_ogc``. +# The default is identity, so direct ``ChunkedCall`` use is unaffected. _Finalize = Callable[[pd.DataFrame, httpx.Response], tuple[pd.DataFrame, Any]] @@ -1114,18 +1109,16 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: """ Decide whether ``exc`` is a transient worth an automatic retry. - Inspects only the *top-level* exception, by design — and so is - deliberately narrower than :func:`_classify_chunk_error`, which walks - the ``__cause__`` chain for resumability. ``_paginate`` raises an - initial-request transient (429 / 5xx / :class:`httpx.TransportError` - such as ``ConnectError`` / ``ReadTimeout``) *raw*, but re-wraps any - mid-pagination failure as a ``RuntimeError``. Retrying only the raw, - top-level transient means we re-issue a sub-request that made no - progress (cheap), while a failure after partial pagination escalates - to the resumable :class:`ChunkInterrupted` instead of being re-walked - from page 1 — which would re-spend the very quota that was exhausted. - ``httpx.InvalidURL`` is excluded (a too-long cursor won't fix on - retry), and it only ever arises on a follow-up page anyway. + Only the *top-level* exception is inspected — unlike + :func:`_classify_chunk_error`, which walks the ``__cause__`` chain. + The distinction matters because ``_paginate`` raises an + initial-request transient (429 / 5xx / :class:`httpx.TransportError`) + *raw*, but wraps a mid-pagination failure as a ``RuntimeError``. So a + raw transient means a sub-request that made no progress and is cheap to + re-issue, whereas a mid-pagination failure is left to escalate to a + resumable :class:`ChunkInterrupted` rather than re-walked from page 1 + (which would re-spend the quota just exhausted). ``httpx.InvalidURL`` + is never retried — a too-long cursor won't fix on a retry. Returns ------- @@ -1620,9 +1613,8 @@ def resume(self) -> tuple[pd.DataFrame, Any]: async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: """ - The async execution core: gather every pending sub-request over - one shared :class:`httpx.AsyncClient` and return the combined, - finalized result. + Gather every pending sub-request over one shared + :class:`httpx.AsyncClient` and return the combined, finalized result. Pending sub-requests (:meth:`_pending`) fan out under ``asyncio.gather`` with ``return_exceptions=True`` so completed @@ -1632,12 +1624,12 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: ``.call``; ``exc.call.resume()`` then re-issues only the unfinished indices through this same runner. - Concurrency is bounded purely by the client's connection pool — + Concurrency is bounded by the client's connection pool — ``httpx.Limits(max_connections=N, max_keepalive_connections=N)`` - where ``N = max_concurrent`` (``None`` for unbounded). There is no - semaphore: the gather dispatches *every* pending sub-request and the - pool throttles, so ``N=1`` is just a single-connection gather (one - request at a time) and ``total <= 1`` is just a one-element gather. + where ``N = max_concurrent`` (``None`` for unbounded). The gather + dispatches *every* pending sub-request and the pool throttles, so + ``N=1`` is just a single-connection gather (one request at a time) + and ``total <= 1`` is just a one-element gather. The shared client is published on :data:`_chunked_client` so the paginated-loop helpers reuse its connection pool. From ead1c090965816fa3a7bba9867400d6445556ebd Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Thu, 28 May 2026 07:56:58 -0500 Subject: [PATCH 11/16] docs(waterdata): final editorial pass on chunking.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Retry-defaults comment: drop the hardcoded numbers (4 / 0.5s / 30s / 60s) — they live in the constants below and were drifting from them. - _publish docstring: drop the 'set/reset token dance' mechanism leak; state the contract. - get_active_client docstring: drop the 'public accessor / private ContextVar' justification paragraph; keep the one-liner + the paginated-loop usage cue. - combined(): drop the 'terminal success result' paragraph that duplicated the Returns section; move the finalize return-shape detail into Returns where it belongs. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 41 +++++++++++++---------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 004ce10b..3811719f 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -153,10 +153,9 @@ def _read_concurrency_env() -> int | None: # Retry-with-backoff defaults for transient sub-request failures (429 / -# 5xx / connect-read timeouts): 4 retries, 0.5s base doubling under full -# jitter up to a 30s per-attempt ceiling, and honor a server -# ``Retry-After`` up to 60s before escalating to a resumable interruption -# instead. +# 5xx / connect-read timeouts): exponential backoff with full jitter, and +# honor a server ``Retry-After`` up to the cap below before escalating +# to a resumable interruption instead. _RETRIES_ENV = "API_USGS_RETRIES" _RETRIES_DEFAULT = 4 _RETRY_BASE_BACKOFF = 0.5 @@ -322,15 +321,14 @@ def backoff(self, attempt: int, retry_after: float | None) -> float: @contextmanager def _publish(client: httpx.AsyncClient) -> Iterator[None]: """ - Bind ``client`` to the ``_chunked_client`` ContextVar for the - duration of the ``with`` block (wrapping the set/reset token dance), - so the paginated-loop helpers can borrow the chunker's shared client - via :func:`get_active_client`. + Publish ``client`` on the ``_chunked_client`` ContextVar so the + paginated-loop helpers can borrow it via :func:`get_active_client` + for the duration of the ``with`` block. Parameters ---------- client : httpx.AsyncClient - The client to publish on ``_chunked_client``. + The client to publish. Yields ------ @@ -348,11 +346,9 @@ def get_active_client() -> httpx.AsyncClient | None: """ Return the chunker's currently-published client, or ``None``. - Public accessor for the ``_chunked_client`` ContextVar so - sibling modules (notably - :func:`dataretrieval.waterdata.utils._client_for`) don't have - to reach into the private ContextVar directly. Used by the - paginated-loop helpers to reuse the per-call connection pool. + Used by the paginated-loop helpers (e.g. + :func:`dataretrieval.waterdata.utils._client_for`) to reuse the + per-call connection pool. Returns ------- @@ -1493,19 +1489,18 @@ def combined(self) -> tuple[pd.DataFrame, Any]: """ Combine every recorded sub-request and apply :attr:`finalize`. - The terminal *success* result: :meth:`_run` returns this, so a - completed call (first run or resume) yields the same shape - ``finalize`` produces — a raw ``(frame, httpx.Response)`` by - default, or the OGC getters' type-coerced / column-arranged frame - plus ``BaseMetadata``. The ``partial_*`` accessors deliberately do - NOT go through here — they return the raw :meth:`_combine_raw` - snapshot to stay cheap and side-effect-free. + Returned by :meth:`_run` on a completed call (first run or + resume). The ``partial_*`` accessors deliberately do NOT route + through here — they return the raw :meth:`_combine_raw` snapshot + to stay cheap and side-effect-free. Returns ------- tuple of (pandas.DataFrame, finalized response) - The combined frame and the finalized aggregate response / - metadata that :attr:`finalize` produces. + The combined frame and whatever :attr:`finalize` produces — + a raw :class:`httpx.Response` by default, or the OGC + getters' type-coerced / column-arranged frame plus + ``BaseMetadata``. """ return self.finalize(*self._combine_raw()) From 01c734b95e353fa0eddc8f7ce1b672788f2d8b94 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Thu, 28 May 2026 08:17:08 -0500 Subject: [PATCH 12/16] docs(waterdata): correct interruption claims to include transport errors Module docstring, ChunkedCall.resume() Raises, and ChunkedCall._run all listed only 429/5xx as the failures that raise ChunkInterrupted, but _classify_chunk_error also wraps bare httpx.HTTPError (ConnectError, TimeoutException, RemoteProtocolError, ...) and httpx.InvalidURL as ServiceInterrupted (chunking.py:1098). So callers who only caught the 429/5xx case per the docs could miss the transport-error path. Fix: list transport errors alongside 429/5xx in all three docstrings, and name QuotaExhausted vs ServiceInterrupted by which case maps where. Surfaced by a docs-vs-code audit; no functional change. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 3811719f..0a674634 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -30,9 +30,10 @@ the resumable interruption below so a multi-minute quota-window reset doesn't block the call. -Interruption: any mid-stream transient failure (429, 5xx) surfaces -as a ``ChunkInterrupted`` subclass — ``QuotaExhausted`` for 429, -``ServiceInterrupted`` for 5xx. The exception carries ``.call``, a +Interruption: any mid-stream transient failure — 429, 5xx, or a bare +transport error (connect/read timeout, oversize follow-up URL) — surfaces +as a ``ChunkInterrupted`` subclass: ``QuotaExhausted`` for 429, +``ServiceInterrupted`` for the rest. The exception carries ``.call``, a ``ChunkedCall`` handle that owns the already-completed sub-request state (sparse-indexed, since gathered sub-requests complete out of order). Call ``.call.resume()`` once the underlying condition clears; @@ -1596,11 +1597,11 @@ def resume(self) -> tuple[pd.DataFrame, Any]: Raises ------ ChunkInterrupted - On a mid-stream transient failure - (:class:`QuotaExhausted` for 429, - :class:`ServiceInterrupted` for 5xx). The resumable handle - is on ``exc.call`` — wait for the underlying condition to - clear and call ``exc.call.resume()`` again. + On a mid-stream transient failure — 429, 5xx, or a bare + transport error: :class:`QuotaExhausted` for 429, + :class:`ServiceInterrupted` for the rest. The resumable + handle is on ``exc.call`` — wait for the underlying + condition to clear and call ``exc.call.resume()`` again. """ concurrency = _read_concurrency_env() with start_blocking_portal() as portal: @@ -1613,8 +1614,9 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: Pending sub-requests (:meth:`_pending`) fan out under ``asyncio.gather`` with ``return_exceptions=True`` so completed - sub-requests survive a sibling's transient failure. On a recognized - transient (:class:`RateLimited`, :class:`ServiceUnavailable`) a + sub-requests survive a sibling's transient failure. On a + recognized transient (:class:`RateLimited`, :class:`ServiceUnavailable`, + or a bare ``httpx.HTTPError`` / ``httpx.InvalidURL``) a :class:`ChunkInterrupted` subclass is raised carrying ``self`` on ``.call``; ``exc.call.resume()`` then re-issues only the unfinished indices through this same runner. From ad1208ea1473a5e2f23a08124e1f99891644a4c2 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Thu, 28 May 2026 08:32:30 -0500 Subject: [PATCH 13/16] docs(waterdata): editorial pass for readability on chunking.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tighten seven prose-heavy passages where the wording made the reader back up to re-parse. No semantic changes. - RetryPolicy.__post_init__ comment: lead with intent ('catch invalid knobs at construction'); keep the why-not (asyncio.sleep silent vs time.sleep loud) as a clarifying parenthetical. - RetryPolicy.from_env docstring: split the comma-and-semicolon chain into one cleaner sentence; lead with the verb ('Reads...'). - _chunked_client comment: drop the 'across every gathered sub-request of the call' tail and the 'in that case' coda. - _set_response_url docstring: lead with the control flow (try direct first; on real responses, swap the bound request) rather than the read-only-vs-writable mechanism. - _retry_delay docstring: drop the stale 'sync and async drivers share it' line left over from before the async-only collapse; format the three None cases as an em-dash list. - ChunkedCall class docstring: split the long opener at the natural sentence boundary instead of trailing it with 'and ... and ... — used both for...'. - _pending docstring: replace the awkward 'The single source of the "walk ..., skip ..." rule' construction with a direct two-sentence description. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 67 ++++++++++++++--------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 0a674634..f91f1a0b 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -222,10 +222,10 @@ class RetryPolicy: retry_after_cap: float = _RETRY_AFTER_CAP def __post_init__(self) -> None: - # Guard the value object's own invariants so a misconfiguration - # fails loudly at construction rather than as a downstream - # ``time.sleep`` ValueError (negative delay) or a silent - # asyncio.sleep-treats-negative-as-zero divergence. + # Catch invalid timing knobs here so a misconfiguration fails at + # construction, not deep in a later ``time.sleep`` (ValueError on + # a negative delay) or silently in ``asyncio.sleep`` (which + # treats negative as zero). if self.max_retries < 0: raise ValueError(f"max_retries must be >= 0 (got {self.max_retries}).") if self.base_backoff < 0 or self.max_backoff < 0 or self.retry_after_cap < 0: @@ -236,10 +236,10 @@ def from_env(cls) -> RetryPolicy: """ Build a policy from the module-level defaults, resolved now. - ``max_retries`` comes from ``API_USGS_RETRIES``; the timing knobs - are read from the ``_RETRY_*`` module constants at call time (not - the dataclass field defaults, which freeze at class definition) so - ``monkeypatch.setattr`` on those constants takes effect. + Reads ``max_retries`` from ``API_USGS_RETRIES`` and the timing + knobs from the ``_RETRY_*`` module constants at call time — not + the dataclass field defaults (which freeze at class definition) + — so test ``monkeypatch.setattr`` on the constants takes effect. Returns ------- @@ -308,12 +308,11 @@ def backoff(self, attempt: int, retry_after: float | None) -> float: _NO_RETRY = RetryPolicy(max_retries=0) -# The single shared ``httpx.AsyncClient`` of an in-flight chunked call, -# published (via :func:`_publish`) during ``ChunkedCall._run`` so the -# paginated-loop helpers downstream (``_walk_pages``) reuse one -# connection pool across every gathered sub-request of the call. ``None`` -# when not inside a chunked call — paginated helpers fall back to their -# own short-lived client in that case. +# Shared per-call ``httpx.AsyncClient``, published via :func:`_publish` +# during ``ChunkedCall._run`` so paginated-loop helpers (``_walk_pages``) +# reuse the same connection pool across every sub-request. ``None`` +# outside a chunked call — paginated helpers then open their own +# short-lived client. _chunked_client: ContextVar[httpx.AsyncClient | None] = ContextVar( "_chunked_client", default=None ) @@ -676,13 +675,12 @@ def _set_response_url(response: httpx.Response, url: str | httpx.URL) -> None: Overwrite the URL surfaced by a response without back-propagating the change into any aliased original. - On real ``httpx.Response`` instances ``.url`` is a read-only - property that resolves through the bound request; rather than - mutate the existing request's URL (which would be visible through - any shallow copy that shares the same ``.request``), we replace - the response's request with a fresh :class:`httpx.Request` carrying - the new URL. On lightweight test mocks ``.url`` is a plain - writable attribute — that path is tried first. + Try the direct assignment first: on lightweight test mocks ``.url`` + is a plain writable attribute. On real ``httpx.Response`` it's + read-only (it resolves through the bound request), so swap in a + fresh :class:`httpx.Request` carrying the new URL — mutating the + existing one would leak through any shallow copy that shares the + same ``.request``. """ try: response.url = url # type: ignore[misc] @@ -1141,12 +1139,11 @@ def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float Decide the backoff for a just-failed ``attempt`` (1-based), or ``None`` to give up and re-raise. - Returns ``None`` when the error isn't a retryable transient, the policy - is exhausted, or the server's ``Retry-After`` is too long to absorb - inline (so it escalates to a resumable :class:`ChunkInterrupted`). - Otherwise returns the seconds to wait and emits the progress-bar retry - note. This is the whole retry *decision* — the sync and async drivers - share it and differ only in how they call the fetch and how they sleep. + Returns ``None`` in three cases — the error isn't a retryable + transient, the policy is exhausted, or the server's ``Retry-After`` + exceeds the cap (escalates to a resumable :class:`ChunkInterrupted` + instead). Otherwise returns the seconds to wait and emits the + progress-bar retry note. Parameters ---------- @@ -1339,10 +1336,10 @@ class ChunkedCall: Stateful handle for a chunked call. Holds the in-flight state (per-sub-request frames and responses) - and the async fetcher, and exposes a single :meth:`resume` entry - point that drives the call from wherever it is to completion — used - both for the first invocation (from :meth:`ChunkPlan.execute`) and - for subsequent retries after a :class:`ChunkInterrupted`. + and the async fetcher. A single :meth:`resume` entry point drives + the call from wherever it is to completion — used both for the + first invocation (from :meth:`ChunkPlan.execute`) and for subsequent + retries after a :class:`ChunkInterrupted`. :meth:`_run` gathers every pending sub-request over one shared :class:`httpx.AsyncClient`, applies the failure-precedence rules, and @@ -1553,10 +1550,10 @@ def _pending(self) -> Iterator[tuple[int, dict[str, Any]]]: """ Yield ``(index, sub_args)`` for sub-requests not yet completed. - The single source of the "walk :meth:`ChunkPlan.iter_sub_args` in - deterministic order, skip any index already in ``self._chunks``" - rule that :meth:`_run` uses to decide *which* sub-requests it - still owes (first run and every resume alike). + Walks :meth:`ChunkPlan.iter_sub_args` in deterministic order + and skips any index already in ``self._chunks``. :meth:`_run` + uses this to pick up exactly the sub-requests it still owes — + first run and every resume alike. Yields ------ From c63da1b835577328d32b8aefb4d53294deff31c1 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Thu, 28 May 2026 10:11:47 -0400 Subject: [PATCH 14/16] refactor(waterdata): address PR 285 review (compact retry doc, drop _ASLEEP, inline trivial methods) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Module docstring (L29-31): apply suggested wording — drop the "isn't slept off inline" / "doesn't block the call" rationale; the one-line escalation statement is enough. - Drop the ``_ASLEEP = asyncio.sleep`` module-level test hook in favor of a direct ``await asyncio.sleep(delay)``; tests now patch ``chunking.asyncio.sleep`` (still scoped to the chunking module's asyncio binding, no extra indirection in production). - Inline ``ChunkedCall.record(index, pair)`` into the one call site in ``_run.track``; the "single writer of ``_chunks``" invariant moves to a comment on ``self._chunks`` initialization. - Inline ``ChunkedCall.combined()`` into ``_run``'s return; the ``partial_*`` bypass note moves to a comment at the return site, where it's more useful than buried in a removed helper's docstring. No behavior change; 296 offline tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 58 +++++------------------------ tests/waterdata_chunking_test.py | 24 ++++++------ 2 files changed, 21 insertions(+), 61 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index f91f1a0b..34a8565d 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -26,9 +26,7 @@ 5xx, connect/read timeout) with exponential backoff + full jitter, honoring a server ``Retry-After`` when present. ``API_USGS_RETRIES`` sets the cap (default 4; ``0`` disables). A ``Retry-After`` longer -than the per-call ceiling isn't slept off inline — it escalates to -the resumable interruption below so a multi-minute quota-window -reset doesn't block the call. +than the per-call ceiling escalates to a resumable interruption. Interruption: any mid-stream transient failure — 429, 5xx, or a bare transport error (connect/read timeout, oversize follow-up URL) — surfaces @@ -1128,12 +1126,6 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: return False, None -# Sleep hook, indirected through a module global so tests can -# ``monkeypatch.setattr`` it to a no-op instead of waiting for real -# backoff. Production uses the stdlib call. -_ASLEEP = asyncio.sleep - - def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float | None: """ Decide the backoff for a just-failed ``attempt`` (1-based), or ``None`` @@ -1206,7 +1198,7 @@ async def _retry( delay = _retry_delay(exc, attempt, policy) if delay is None: raise - await _ASLEEP(delay) + await asyncio.sleep(delay) def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: @@ -1401,26 +1393,10 @@ def __init__( self.finalize = finalize # Completed (frame, response) pairs keyed by sub-args index; sparse # (gathered sub-requests complete out of order — see class docstring). + # ``_run``'s ``track`` closure is the only writer, so ``dict`` insertion + # order is completion order (relied on by :meth:`_combine_raw`). self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} - def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: - """ - Record a completed sub-request's ``(frame, response)`` pair under - its sub-args index. - - The single writer of ``self._chunks`` — used by the gather in - :meth:`_run` — so ``dict`` insertion order is completion order - (see :meth:`_combine_raw`). - - Parameters - ---------- - index : int - The sub-args index this completed pair belongs to. - pair : tuple of (pandas.DataFrame, httpx.Response) - The completed sub-request's ``(frame, response)`` pair. - """ - self._chunks[index] = pair - def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: """ Build the matching :class:`ChunkInterrupted` carrying this @@ -1483,25 +1459,6 @@ def _combine_raw(self) -> tuple[pd.DataFrame, httpx.Response]: _combine_chunk_responses(responses, self.plan.canonical_url), ) - def combined(self) -> tuple[pd.DataFrame, Any]: - """ - Combine every recorded sub-request and apply :attr:`finalize`. - - Returned by :meth:`_run` on a completed call (first run or - resume). The ``partial_*`` accessors deliberately do NOT route - through here — they return the raw :meth:`_combine_raw` snapshot - to stay cheap and side-effect-free. - - Returns - ------- - tuple of (pandas.DataFrame, finalized response) - The combined frame and whatever :attr:`finalize` produces — - a raw :class:`httpx.Response` by default, or the OGC - getters' type-coerced / column-arranged frame plus - ``BaseMetadata``. - """ - return self.finalize(*self._combine_raw()) - @property def partial_frame(self) -> pd.DataFrame: """ @@ -1669,7 +1626,7 @@ async def track( ) -> tuple[pd.DataFrame, httpx.Response]: """One sub-request (with retry) + record + progress tick.""" result = await _retry(lambda: self.fetch(args), self.retry_policy) - self.record(index, result) + self._chunks[index] = result if reporter is not None: # Chunks finish out of order under gather, so tick the # completed *count* rather than a positional index. @@ -1710,7 +1667,10 @@ async def track( interrupted, exc = first_transient raise interrupted from exc - return self.combined() + # Apply the injected ``finalize`` to the raw combined result. + # ``partial_frame`` / ``partial_response`` deliberately bypass + # ``finalize`` to stay cheap and side-effect-free. + return self.finalize(*self._combine_raw()) def multi_value_chunked( diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index ab23c075..0d1f615c 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -60,8 +60,8 @@ def _aiozero(_d): - """An async no-op sleep — monkeypatched over ``chunking._ASLEEP`` so - retry backoff doesn't actually wait in tests.""" + """An async no-op sleep — monkeypatched over the ``chunking`` module's + ``asyncio.sleep`` so retry backoff doesn't actually wait in tests.""" async def _noop(): return None @@ -1489,7 +1489,7 @@ def test_combine_chunk_responses_does_not_mutate_input_urls(): # --------------------------------------------------------------------------- # Retry-with-backoff: RetryPolicy + _retryable + drivers + decorator wiring. # Conftest pins API_USGS_RETRIES=0, so these tests opt in explicitly and -# patch chunking._SLEEP / chunking._ASLEEP to no-ops (no real backoff). +# patch the chunking module's ``asyncio.sleep`` to a no-op (no real backoff). # --------------------------------------------------------------------------- @@ -1601,7 +1601,7 @@ def test_retryable_skips_wrapped_midpagination_transient(): def test_retry_transient_then_recovers(monkeypatch): - monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) calls = {"n": 0} async def afn(): @@ -1616,7 +1616,7 @@ async def afn(): def test_retry_exhausted_reraises(monkeypatch): - monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) calls = {"n": 0} async def afn(): @@ -1635,7 +1635,7 @@ def _record(delay): slept.append(delay) return _aiozero(delay) - monkeypatch.setattr(_chunking, "_ASLEEP", _record) + monkeypatch.setattr(_chunking.asyncio, "sleep", _record) calls = {"n": 0} async def afn(): @@ -1654,7 +1654,7 @@ def _record(delay): slept.append(delay) return _aiozero(delay) - monkeypatch.setattr(_chunking, "_ASLEEP", _record) + monkeypatch.setattr(_chunking.asyncio, "sleep", _record) calls = {"n": 0} async def afn(): @@ -1673,7 +1673,7 @@ def test_retry_transient_then_success(monkeypatch): async def _noslept(_d): return None - monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + monkeypatch.setattr(_chunking.asyncio, "sleep", _noslept) calls = {"n": 0} async def afn(): @@ -1693,7 +1693,7 @@ def test_chunker_retries_transient_then_completes(monkeypatch): """A transient on one sub-request is retried transparently; the decorated call completes with no ChunkInterrupted.""" monkeypatch.setenv("API_USGS_RETRIES", "3") - monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) state = {"failed": False} async def fetch(args): @@ -1713,7 +1713,7 @@ def test_chunker_exhausted_retries_still_resumable(monkeypatch): """When retries are exhausted the failure still surfaces as a resumable ChunkInterrupted — retries don't swallow the escape hatch.""" monkeypatch.setenv("API_USGS_RETRIES", "2") - monkeypatch.setattr(_chunking, "_ASLEEP", _aiozero) + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) attempts = {"n": 0} async def fetch(args): @@ -1737,7 +1737,7 @@ def test_async_fan_out_retries_transient_then_completes(monkeypatch): async def _noslept(_d): return None - monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + monkeypatch.setattr(_chunking.asyncio, "sleep", _noslept) state = {"failed": False} async def fetch_async(args): @@ -1759,7 +1759,7 @@ def test_async_fan_out_surfaces_fatal_over_transient(monkeypatch): async def _noslept(_d): return None - monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + monkeypatch.setattr(_chunking.asyncio, "sleep", _noslept) async def fetch_async(args): # One chunk carries a deterministic programmer error; the rest are From d9580051284d9088ad45d45e2c5ee5e9f8c85314 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Thu, 28 May 2026 11:16:13 -0400 Subject: [PATCH 15/16] =?UTF-8?q?docs+tests(waterdata):=20apply=20/simplif?= =?UTF-8?q?y=20sweep=20=E2=80=94=20close=20stale=20refs=20and=20consolidat?= =?UTF-8?q?e=20sleep=20helpers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Doc-vs-code drift left by the recent inlining + the 01c734b transport-error fix: - multi_value_chunked Raises (chunking.py:1716): now lists transport errors alongside 429/5xx — the surface the 01c734b sweep missed. - _combine_raw docstring (chunking.py:1444): "record" → "the track closure in _run" since record() was inlined in c63da1b. - track closure docstring (chunking.py:1627): "+ record +" → "+ result-store +" for the same reason. - _aiozero test helper docstring: tightened to "asyncio.sleep (via the chunking module's binding)" — chunking.asyncio IS the asyncio module. - Test section banner: "drivers" (plural) → "driver" (only one remaining after the async-only collapse). Simplifications: - Drop the redundant 3-line comment above the inlined `return self.finalize(*self._combine_raw())` — partial_frame's docstring and the class Attributes already say the same thing twice. - Test sleep-faker variants consolidated: replaced 3 inline `async def _noslept(_d): return None` blocks with the existing module-level `_aiozero`; replaced 2 inline `_record` closures with a new module-level `_recording_sleep(slept)` factory. Net −8 lines. 296 offline tests pass; ruff clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/chunking.py | 20 ++++++------- tests/waterdata_chunking_test.py | 44 +++++++++++++---------------- 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 34a8565d..ab079070 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -1441,10 +1441,11 @@ def _combine_raw(self) -> tuple[pd.DataFrame, httpx.Response]: Frames concatenate in sub-args *index* order (``sorted`` keys — deterministic, independent of parallel completion order). The aggregated response takes its headers from the most-recently- - *completed* sub-request: ``record`` is the only writer of - ``self._chunks`` and ``dict`` preserves insertion order, so the - chunks' natural order is completion order and the last one carries - the freshest ``x-ratelimit-remaining``. + *completed* sub-request: the ``track`` closure in :meth:`_run` + is the only writer of ``self._chunks`` and ``dict`` preserves + insertion order, so the chunks' natural order is completion + order and the last one carries the freshest + ``x-ratelimit-remaining``. Returns ------- @@ -1624,7 +1625,7 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: async def track( index: int, args: dict[str, Any] ) -> tuple[pd.DataFrame, httpx.Response]: - """One sub-request (with retry) + record + progress tick.""" + """One sub-request (with retry) + result-store + progress tick.""" result = await _retry(lambda: self.fetch(args), self.retry_policy) self._chunks[index] = result if reporter is not None: @@ -1667,9 +1668,6 @@ async def track( interrupted, exc = first_transient raise interrupted from exc - # Apply the injected ``finalize`` to the raw combined result. - # ``partial_frame`` / ``partial_response`` deliberately bypass - # ``finalize`` to stay cheap and side-effect-free. return self.finalize(*self._combine_raw()) @@ -1713,9 +1711,9 @@ def multi_value_chunked( RequestTooLarge If no plan can fit ``url_limit``. ChunkInterrupted - On a mid-execution 429 (:class:`QuotaExhausted`) or 5xx - (:class:`ServiceInterrupted`). See :class:`ChunkedCall` for - the resume semantics. + On a mid-execution transient — 429, 5xx, or a bare transport + error: :class:`QuotaExhausted` for 429, :class:`ServiceInterrupted` + for the rest. See :class:`ChunkedCall` for the resume semantics. See Also -------- diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 0d1f615c..5dbacfbf 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -60,8 +60,8 @@ def _aiozero(_d): - """An async no-op sleep — monkeypatched over the ``chunking`` module's - ``asyncio.sleep`` so retry backoff doesn't actually wait in tests.""" + """An async no-op sleep — monkeypatched over ``asyncio.sleep`` (via + the chunking module's binding) so retry backoff doesn't wait in tests.""" async def _noop(): return None @@ -69,6 +69,17 @@ async def _noop(): return _noop() +def _recording_sleep(slept): + """An ``_aiozero`` variant that appends each requested delay to ``slept`` + before resolving — for tests that need to assert what would have been waited.""" + + def _record(delay): + slept.append(delay) + return _aiozero(delay) + + return _record + + class _FakeReq: """Stand-in for ``httpx.Request`` whose ``_request_bytes`` shape is ``len(str(url)) + len(content)``.""" @@ -1487,7 +1498,7 @@ def test_combine_chunk_responses_does_not_mutate_input_urls(): # --------------------------------------------------------------------------- -# Retry-with-backoff: RetryPolicy + _retryable + drivers + decorator wiring. +# Retry-with-backoff: RetryPolicy + _retryable + driver + decorator wiring. # Conftest pins API_USGS_RETRIES=0, so these tests opt in explicitly and # patch the chunking module's ``asyncio.sleep`` to a no-op (no real backoff). # --------------------------------------------------------------------------- @@ -1631,11 +1642,7 @@ async def afn(): def test_retry_non_retryable_not_retried(monkeypatch): slept: list[float] = [] - def _record(delay): - slept.append(delay) - return _aiozero(delay) - - monkeypatch.setattr(_chunking.asyncio, "sleep", _record) + monkeypatch.setattr(_chunking.asyncio, "sleep", _recording_sleep(slept)) calls = {"n": 0} async def afn(): @@ -1650,11 +1657,7 @@ async def afn(): def test_retry_long_retry_after_escalates(monkeypatch): slept: list[float] = [] - def _record(delay): - slept.append(delay) - return _aiozero(delay) - - monkeypatch.setattr(_chunking.asyncio, "sleep", _record) + monkeypatch.setattr(_chunking.asyncio, "sleep", _recording_sleep(slept)) calls = {"n": 0} async def afn(): @@ -1670,10 +1673,7 @@ async def afn(): def test_retry_transient_then_success(monkeypatch): - async def _noslept(_d): - return None - - monkeypatch.setattr(_chunking.asyncio, "sleep", _noslept) + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) calls = {"n": 0} async def afn(): @@ -1734,10 +1734,7 @@ def test_async_fan_out_retries_transient_then_completes(monkeypatch): """The parallel path retries a transient sub-request and completes.""" monkeypatch.setenv("API_USGS_RETRIES", "3") - async def _noslept(_d): - return None - - monkeypatch.setattr(_chunking.asyncio, "sleep", _noslept) + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) state = {"failed": False} async def fetch_async(args): @@ -1756,10 +1753,7 @@ def test_async_fan_out_surfaces_fatal_over_transient(monkeypatch): being masked behind a resumable interruption from a transient sibling.""" monkeypatch.setenv("API_USGS_RETRIES", "2") - async def _noslept(_d): - return None - - monkeypatch.setattr(_chunking.asyncio, "sleep", _noslept) + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) async def fetch_async(args): # One chunk carries a deterministic programmer error; the rest are From 6da534db4dd632362350d2823b32e9446a089ac7 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Thu, 28 May 2026 11:40:26 -0400 Subject: [PATCH 16/16] feat(waterdata): layered config with WaterDataConfig + RetryPolicy + ConcurrencyPolicy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Centralizes the four runtime knobs the Water Data getters consult (retry budget, concurrency cap, API token, progress mode) behind one config object with the conventional precedence rule used by git / npm / pip / cargo — closer to the action wins: 1. Defaults (dataclass field defaults) 2. User config file (~/.config/dataretrieval/config.cfg) 3. Local config file (./dataretrieval.cfg) 4. Environment variables (API_USGS_*) 5. Python override (ContextVar-scoped `override()` context manager) New `dataretrieval/waterdata/_config.py`: - `RetryPolicy` (moved from chunking.py). - `ConcurrencyPolicy` (new — mirrors RetryPolicy for the concurrency cap; replaces `_read_concurrency_env`). - `WaterDataConfig` composes them + api_token + progress. - `WaterDataConfig.load()` runs the precedence chain. - `current()` returns the active override (if any) or freshly loads. - `override(WaterDataConfig)` / `set_config(...)` for Python-side runtime overrides. - Stdlib `configparser` for the INI files — no new deps. (If TOML + pyproject.toml integration is later wanted, `tomli` is a small conditional dep follow-up.) Wired through: - chunking.py: imports `RetryPolicy` / `ConcurrencyPolicy` from `_config` (re-export keeps the `from chunking import RetryPolicy` path working). `_read_concurrency_env` and `_read_retries_env` removed. - utils._default_headers reads `api_token` via `_config.current()`. - _progress reads progress mode + api-key hint via `_config.current()`. Backward compatibility: every existing env var keeps working unchanged (env layer reads them at call time). Existing tests pass with no changes except (a) the brittle `monkeypatch.setattr(_chunking, "_RETRY_BASE_BACKOFF", 0.0)` test replaced with one that uses the new `override()` mechanism, (b) one `_chunking._RETRIES_DEFAULT` reference re-pointed at `_config._RETRIES_DEFAULT`. 15 new tests covering each precedence layer, the override context manager (including nesting), parsing edge cases (case-insensitive "unbounded", invalid values rejected, progress parser truthy/falsy/auto), and unknown-key tolerance. Total offline suite: 311 passed (+15). Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/_config.py | 447 +++++++++++++++++++++++++++ dataretrieval/waterdata/_progress.py | 21 +- dataretrieval/waterdata/chunking.py | 201 +----------- dataretrieval/waterdata/utils.py | 11 +- tests/waterdata_chunking_test.py | 25 +- tests/waterdata_config_test.py | 281 +++++++++++++++++ 6 files changed, 771 insertions(+), 215 deletions(-) create mode 100644 dataretrieval/waterdata/_config.py create mode 100644 tests/waterdata_config_test.py diff --git a/dataretrieval/waterdata/_config.py b/dataretrieval/waterdata/_config.py new file mode 100644 index 00000000..a75e0b5b --- /dev/null +++ b/dataretrieval/waterdata/_config.py @@ -0,0 +1,447 @@ +"""Layered configuration for the Water Data getters. + +The Water Data module has a few runtime knobs — retry budget, concurrency +cap, API token, progress bar mode. Historically each was read directly from +its env var at the call site. This module gathers them behind one config +object with the conventional precedence rule used by ``git`` / ``npm`` / +``pip`` / ``cargo`` (closer to the action wins): + +1. **Defaults** — the dataclass field defaults below. +2. **User config file** — ``$XDG_CONFIG_HOME/dataretrieval/config.cfg`` + (default ``~/.config/dataretrieval/config.cfg``) on Linux/macOS, + ``%APPDATA%\\dataretrieval\\config.cfg`` on Windows. +3. **Local config file** — ``./dataretrieval.cfg`` in the current working + directory. +4. **Environment variables** — ``API_USGS_RETRIES``, + ``API_USGS_CONCURRENT``, ``API_USGS_PAT``, ``API_USGS_PROGRESS``. +5. **Python override** — :func:`override` (a ContextVar-scoped context + manager) or :func:`set_config` (process-wide). + +Config-file schema (stdlib INI via :mod:`configparser`):: + + [default] + api_token = ... + progress = auto # on | off | auto + + [retry] + max_retries = 4 + base_backoff = 0.5 + max_backoff = 30.0 + retry_after_cap = 60.0 + + [concurrency] + max_connections = 16 # int >= 1, or the string "unbounded" + +Stdlib-only — no TOML / YAML deps. If/when pyproject.toml integration +becomes valuable (e.g. ``[tool.dataretrieval]``), adding ``tomli`` as a +conditional dep is a small mechanical follow-up. + +Backward compatibility: every existing env var keeps working unchanged, +and the legacy ``RetryPolicy.from_env()`` / ``ConcurrencyPolicy.from_env()`` +factories still build per-call from the layered loader. +""" + +from __future__ import annotations + +import configparser +import os +import random +import sys +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field, fields +from pathlib import Path +from typing import Any + +# ---- env-var names (single source of truth) ------------------------------- + +ENV_RETRIES = "API_USGS_RETRIES" +ENV_CONCURRENT = "API_USGS_CONCURRENT" +ENV_PAT = "API_USGS_PAT" +ENV_PROGRESS = "API_USGS_PROGRESS" + +# Sentinel string in ``API_USGS_CONCURRENT`` or ``concurrency.max_connections`` +# meaning "no cap on simultaneous connections". +CONCURRENCY_UNBOUNDED = "unbounded" + + +# ---- defaults the dataclass fields ship with ----------------------------- +# Values mirror the legacy ``_RETRY_*`` module constants so callers porting +# from the old shape see the same numbers. + +_RETRIES_DEFAULT = 4 +_RETRY_BASE_BACKOFF = 0.5 +_RETRY_MAX_BACKOFF = 30.0 +_RETRY_AFTER_CAP = 60.0 +_CONCURRENCY_DEFAULT = 16 + + +# ---- file locations ------------------------------------------------------- + +_CONFIG_BASENAME = "config.cfg" +_USER_CONFIG_SUBDIR = "dataretrieval" +_LOCAL_CONFIG_NAME = "dataretrieval.cfg" + +# Section names in the INI files. ``DEFAULTS_SECTION`` carries the +# top-level scalars (api_token, progress); the policy sub-tables get their +# own sections. +DEFAULTS_SECTION = "default" +RETRY_SECTION = "retry" +CONCURRENCY_SECTION = "concurrency" + + +def _user_config_path() -> Path: + """Cross-platform user config path. Honors ``XDG_CONFIG_HOME``.""" + if sys.platform == "win32": # pragma: no cover - exercised on Windows CI + base = os.environ.get("APPDATA") or str(Path.home() / "AppData/Roaming") + return Path(base) / _USER_CONFIG_SUBDIR / _CONFIG_BASENAME + xdg = os.environ.get("XDG_CONFIG_HOME") + base = Path(xdg) if xdg else Path.home() / ".config" + return base / _USER_CONFIG_SUBDIR / _CONFIG_BASENAME + + +def _local_config_path() -> Path: + """The local config file in the current working directory.""" + return Path.cwd() / _LOCAL_CONFIG_NAME + + +# ---- dataclasses ---------------------------------------------------------- + + +@dataclass(frozen=True) +class RetryPolicy: + """Retry-with-backoff timing knobs for one sub-request. + + Frozen value object. Construct directly to override per-call, or use + :meth:`from_env` to build from the layered config (defaults → user file + → local file → env vars → Python override). + + Attributes + ---------- + max_retries : int + Maximum retries per sub-request (``0`` disables retry entirely). + base_backoff : float + First-retry delay in seconds; subsequent retries double under full + jitter up to :attr:`max_backoff`. + max_backoff : float + Per-attempt ceiling on the slept delay (seconds). + retry_after_cap : float + If the server sends ``Retry-After`` greater than this (seconds), + the failure escalates to a resumable :class:`ChunkInterrupted` + instead of blocking the gather inline. + """ + + max_retries: int = _RETRIES_DEFAULT + base_backoff: float = _RETRY_BASE_BACKOFF + max_backoff: float = _RETRY_MAX_BACKOFF + retry_after_cap: float = _RETRY_AFTER_CAP + + def __post_init__(self) -> None: + # Catch invalid timing knobs here so a misconfiguration fails at + # construction, not deep in a later ``time.sleep`` (ValueError on + # a negative delay) or silently in ``asyncio.sleep`` (which + # treats negative as zero). + if self.max_retries < 0: + raise ValueError(f"max_retries must be >= 0 (got {self.max_retries}).") + if self.base_backoff < 0 or self.max_backoff < 0 or self.retry_after_cap < 0: + raise ValueError("retry backoff settings must be non-negative.") + + @classmethod + def from_env(cls) -> RetryPolicy: + """Build from the layered config (defaults → user file → local + file → env vars → Python override).""" + return current().retry + + def should_retry(self, attempt: int, retry_after: float | None) -> bool: + """Whether ``attempt`` should be retried under this policy. + + Returns ``False`` if the policy is exhausted or if the server's + ``Retry-After`` (seconds) exceeds :attr:`retry_after_cap`. + """ + if attempt > self.max_retries: + return False + return retry_after is None or retry_after <= self.retry_after_cap + + def backoff(self, attempt: int, retry_after: float | None) -> float: + """Seconds to wait before the next retry of ``attempt``. + + Honor server ``Retry-After`` when present (already filtered by + :meth:`should_retry` against :attr:`retry_after_cap`). Otherwise: + exponential ``base_backoff * 2 ** (attempt - 1)`` capped at + :attr:`max_backoff`, then full-jitter randomized in ``[0, capped]``. + """ + if retry_after is not None: + return retry_after + ceiling = min(self.max_backoff, self.base_backoff * 2 ** (attempt - 1)) + return random.uniform(0.0, ceiling) + + +@dataclass(frozen=True) +class ConcurrencyPolicy: + """Simultaneous-connection cap for chunked sub-request gather. + + ``max_connections=None`` means uncapped — the gather dispatches every + pending sub-request and lets ``httpx`` open as many connections as + needed. + """ + + max_connections: int | None = _CONCURRENCY_DEFAULT + + def __post_init__(self) -> None: + if self.max_connections is not None and self.max_connections < 1: + raise ValueError( + f"max_connections must be >= 1 or None (got {self.max_connections})." + ) + + @classmethod + def from_env(cls) -> ConcurrencyPolicy: + """Build from the layered config — convenience for callers that + only need the concurrency piece.""" + return current().concurrency + + +@dataclass(frozen=True) +class WaterDataConfig: + """Top-level config composing every runtime knob the Water Data + getters consult. + + Construct directly, or use :meth:`load` to build via the precedence + layering. :func:`current` returns the in-effect config (the active + override if one is set, else a freshly :meth:`load`-ed one). + """ + + retry: RetryPolicy = field(default_factory=RetryPolicy) + concurrency: ConcurrencyPolicy = field(default_factory=ConcurrencyPolicy) + api_token: str | None = None + # ``None`` = auto-detect (TTY-driven); ``True`` / ``False`` are explicit + # overrides. Mirrors the legacy ``API_USGS_PROGRESS`` semantics: + # ``"on"`` / ``"true"`` / ``"1"`` → True, ``"off"`` / ``"false"`` / + # ``"0"`` → False, anything else → ``None`` (auto). + progress: bool | None = None + + @classmethod + def load(cls) -> WaterDataConfig: + """Build from the precedence chain — defaults → user file → local + file → env vars (does NOT consult :func:`override`).""" + merged: dict[str, Any] = {} + for layer in (_load_user_file(), _load_local_file(), _load_env()): + _deep_update(merged, layer) + return _from_mapping(merged) + + +# ---- ContextVar-based runtime override ----------------------------------- + +_active: ContextVar[WaterDataConfig | None] = ContextVar( + "waterdata_active_config", default=None +) + + +def current() -> WaterDataConfig: + """Return the active :class:`WaterDataConfig` — the override set via + :func:`override` / :func:`set_config` if any, else freshly loaded + from the file + env precedence chain.""" + active = _active.get() + if active is not None: + return active + return WaterDataConfig.load() + + +def set_config(config: WaterDataConfig | None) -> None: + """Pin a config process-wide (or ``None`` to clear the pin and fall + back to the layered loader). Test/notebook convenience — prefer + :func:`override` for scoped use.""" + _active.set(config) + + +@contextmanager +def override(config: WaterDataConfig) -> Iterator[None]: + """Pin ``config`` as the active config for the duration of the + ``with`` block (thread-safe and async-safe via ``ContextVar``):: + + with override(WaterDataConfig(retry=RetryPolicy(max_retries=10))): + ... # every call inside sees the override + """ + token = _active.set(config) + try: + yield + finally: + _active.reset(token) + + +# ---- file readers --------------------------------------------------------- + + +def _read_ini(path: Path) -> dict[str, dict[str, str]]: + """Read an INI file; return ``{section: {key: value}}``. Missing file + returns ``{}``. A malformed file raises :class:`configparser.Error`.""" + if not path.exists(): + return {} + parser = configparser.ConfigParser( + # Don't promote the ``[DEFAULT]`` section into every other section + # (the configparser default behavior). Each section stands alone. + default_section="__never_used__", + ) + parser.read(path) + return {section: dict(parser[section]) for section in parser.sections()} + + +def _ini_to_mapping(sections: dict[str, dict[str, str]]) -> dict[str, Any]: + """Convert raw INI string values into the shape ``_from_mapping`` expects. + + ``[default]`` keys become top-level entries (``api_token``, ``progress``); + ``[retry]`` and ``[concurrency]`` become sub-tables. String values are + parsed into the appropriate types (int / float / bool / None). + """ + out: dict[str, Any] = {} + # Top-level scalars from [default]. + defaults = sections.get(DEFAULTS_SECTION, {}) + if "api_token" in defaults: + out["api_token"] = defaults["api_token"] + if "progress" in defaults: + parsed = _parse_progress(defaults["progress"]) + if parsed is not None or defaults["progress"].strip().lower() in ("auto", ""): + out["progress"] = parsed + # Retry sub-table. + retry_raw = sections.get(RETRY_SECTION, {}) + if retry_raw: + retry: dict[str, Any] = {} + if "max_retries" in retry_raw: + retry["max_retries"] = _coerce_int(retry_raw["max_retries"], "max_retries") + for k in ("base_backoff", "max_backoff", "retry_after_cap"): + if k in retry_raw: + retry[k] = _coerce_float(retry_raw[k], k) + if retry: + out["retry"] = retry + # Concurrency sub-table. + conc_raw = sections.get(CONCURRENCY_SECTION, {}) + if "max_connections" in conc_raw: + out["concurrency"] = { + "max_connections": _parse_concurrency(conc_raw["max_connections"]) + } + return out + + +def _load_user_file() -> dict[str, Any]: + """Layer 2 — ``~/.config/dataretrieval/config.cfg`` (or platform + equivalent).""" + return _ini_to_mapping(_read_ini(_user_config_path())) + + +def _load_local_file() -> dict[str, Any]: + """Layer 3 — ``./dataretrieval.cfg`` in the current working directory.""" + return _ini_to_mapping(_read_ini(_local_config_path())) + + +def _load_env() -> dict[str, Any]: + """Layer 4 — the four ``API_USGS_*`` env vars, mapped onto the + dataclass shape.""" + out: dict[str, Any] = {} + if (raw := os.environ.get(ENV_RETRIES)) is not None: + out["retry"] = {"max_retries": _coerce_int(raw, ENV_RETRIES)} + if (raw := os.environ.get(ENV_CONCURRENT)) is not None: + out["concurrency"] = {"max_connections": _parse_concurrency(raw)} + if (raw := os.environ.get(ENV_PAT)) is not None: + out["api_token"] = raw + if (raw := os.environ.get(ENV_PROGRESS)) is not None: + parsed = _parse_progress(raw) + if parsed is not None or raw.strip().lower() in ("auto", ""): + out["progress"] = parsed + return out + + +# ---- mapping → dataclass -------------------------------------------------- + + +def _from_mapping(mapping: dict[str, Any]) -> WaterDataConfig: + """Construct a :class:`WaterDataConfig` from a merged mapping, + silently dropping unknown keys so a stray entry doesn't crash + construction. Sub-tables (``retry``, ``concurrency``) feed their + own dataclass constructors via :func:`_filter_kwargs`.""" + retry_kw = _filter_kwargs(mapping.get("retry", {}) or {}, RetryPolicy) + conc_kw = _filter_kwargs(mapping.get("concurrency", {}) or {}, ConcurrencyPolicy) + return WaterDataConfig( + retry=RetryPolicy(**retry_kw), + concurrency=ConcurrencyPolicy(**conc_kw), + api_token=mapping.get("api_token"), + progress=mapping.get("progress"), + ) + + +def _filter_kwargs(mapping: dict[str, Any], cls: type) -> dict[str, Any]: + """Keep only keys that match dataclass fields of ``cls``.""" + known = {f.name for f in fields(cls)} + return {k: v for k, v in mapping.items() if k in known} + + +# ---- helpers -------------------------------------------------------------- + + +def _deep_update(dst: dict[str, Any], src: dict[str, Any]) -> None: + """Recursive dict merge: nested dicts merge; scalars overwrite. Used + to layer config sources without dropping unset sub-table keys.""" + for k, v in src.items(): + if isinstance(v, dict) and isinstance(dst.get(k), dict): + _deep_update(dst[k], v) + else: + dst[k] = v + + +def _coerce_int(raw: str | int, name: str) -> int: + if isinstance(raw, int): + return raw + try: + return int(raw) + except ValueError as e: + raise ValueError(f"{name} must be an integer (got {raw!r}).") from e + + +def _coerce_float(raw: str | float, name: str) -> float: + if isinstance(raw, (int, float)): + return float(raw) + try: + return float(raw) + except ValueError as e: + raise ValueError(f"{name} must be a number (got {raw!r}).") from e + + +def _parse_concurrency(raw: str | int) -> int | None: + """Parse a concurrency value. The literal ``"unbounded"`` + (case-insensitive) → ``None``; anything else must parse as an int >= 1.""" + if isinstance(raw, str) and raw.strip().lower() == CONCURRENCY_UNBOUNDED: + return None + n = _coerce_int(raw, ENV_CONCURRENT) + if n < 1: + raise ValueError( + f"{ENV_CONCURRENT} must be >= 1 or '{CONCURRENCY_UNBOUNDED}' (got {raw!r})." + ) + return n + + +def _parse_progress(raw: str) -> bool | None: + """Parse a progress preference. ``"on"`` / ``"true"`` / ``"1"`` → + True; ``"off"`` / ``"false"`` / ``"0"`` → False; ``"auto"`` / + ``""`` / anything else → ``None`` (auto).""" + s = raw.strip().lower() + if s in ("on", "true", "1", "yes"): + return True + if s in ("off", "false", "0", "no"): + return False + return None + + +# ---- public re-exports ---------------------------------------------------- + +__all__ = [ + "RetryPolicy", + "ConcurrencyPolicy", + "WaterDataConfig", + "current", + "override", + "set_config", + "ENV_RETRIES", + "ENV_CONCURRENT", + "ENV_PAT", + "ENV_PROGRESS", + "CONCURRENCY_UNBOUNDED", +] diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py index ce94effb..5122f52e 100644 --- a/dataretrieval/waterdata/_progress.py +++ b/dataretrieval/waterdata/_progress.py @@ -25,7 +25,6 @@ from __future__ import annotations import contextvars -import os import sys from collections.abc import Iterator from contextlib import contextmanager @@ -80,13 +79,17 @@ def _in_jupyter_kernel() -> bool: def _enabled_default(stream: TextIO) -> bool: """Whether to draw the line by default. - ``API_USGS_PROGRESS`` wins when set. Otherwise show it for interactive use — - a TTY or a Jupyter/IPython kernel — and stay quiet for redirected output, - logs, and CI. + An explicit on/off from the layered config wins (set via + ``$API_USGS_PROGRESS``, a config file's ``[default] progress``, or a + Python override — see :mod:`._config`). Otherwise show it for + interactive use — a TTY or a Jupyter/IPython kernel — and stay quiet + for redirected output, logs, and CI. """ - override = os.getenv("API_USGS_PROGRESS") - if override is not None: - return override.strip().lower() not in {"", "0", "false", "no", "off"} + from . import _config # local import to avoid an import cycle at module load + + explicit = _config.current().progress + if explicit is not None: + return explicit if _in_jupyter_kernel(): return True return hasattr(stream, "isatty") and stream.isatty() @@ -252,8 +255,10 @@ def close(self) -> None: self.enabled = False def _maybe_hint_api_key(self) -> None: + from . import _config # local import to avoid an import cycle at module load + global _api_key_hint_shown - if _api_key_hint_shown or os.getenv("API_USGS_PAT"): + if _api_key_hint_shown or _config.current().api_token: return # Set the once-per-process latch only after a successful write, so a # failed write (broken pipe) doesn't silently burn the hint for every diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index ab079070..f0aec9e8 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -52,8 +52,6 @@ import functools import itertools import math -import os -import random from collections.abc import Awaitable, Callable, Iterator from contextlib import contextmanager, suppress from contextvars import ContextVar @@ -108,197 +106,12 @@ # Response header USGS uses to advertise remaining hourly quota. _QUOTA_HEADER = "x-ratelimit-remaining" -# Fan-out concurrency cap, read at call time (not import) so test -# ``monkeypatch.setenv`` applies. Value grammar in :func:`_read_concurrency_env`; -# the concurrency model is in the module docstring. -_CONCURRENCY_ENV = "API_USGS_CONCURRENT" -_CONCURRENCY_DEFAULT = 16 -_CONCURRENCY_UNBOUNDED = "unbounded" - - -def _read_concurrency_env() -> int | None: - """ - Resolve the ``API_USGS_CONCURRENT`` env var to a parallelism cap. - - Returns - ------- - int or None - ``1`` for a single connection; an integer >1 for bounded - concurrency; ``None`` to disable the per-call cap entirely - (``unbounded`` keyword). Unset → default of - ``_CONCURRENCY_DEFAULT``. - """ - raw = os.environ.get(_CONCURRENCY_ENV) - if raw is None: - return _CONCURRENCY_DEFAULT - raw = raw.strip() - if raw == "": - return _CONCURRENCY_DEFAULT - if raw.lower() == _CONCURRENCY_UNBOUNDED: - return None - try: - value = int(raw) - except ValueError as exc: - raise ValueError( - f"{_CONCURRENCY_ENV} must be a positive integer or " - f"'{_CONCURRENCY_UNBOUNDED}'; got {raw!r}." - ) from exc - if value < 1: - raise ValueError( - f"{_CONCURRENCY_ENV} must be >= 1 (got {value}); use " - f"'{_CONCURRENCY_UNBOUNDED}' to disable the cap." - ) - return value - - -# Retry-with-backoff defaults for transient sub-request failures (429 / -# 5xx / connect-read timeouts): exponential backoff with full jitter, and -# honor a server ``Retry-After`` up to the cap below before escalating -# to a resumable interruption instead. -_RETRIES_ENV = "API_USGS_RETRIES" -_RETRIES_DEFAULT = 4 -_RETRY_BASE_BACKOFF = 0.5 -_RETRY_MAX_BACKOFF = 30.0 -_RETRY_AFTER_CAP = 60.0 - - -def _read_retries_env() -> int: - """ - Resolve the ``API_USGS_RETRIES`` env var to a max-retry count. - - Returns - ------- - int - Number of retries after the first attempt; ``0`` disables - retrying. Unset/blank → ``_RETRIES_DEFAULT``. - """ - raw = os.environ.get(_RETRIES_ENV) - if raw is None or raw.strip() == "": - return _RETRIES_DEFAULT - try: - value = int(raw.strip()) - except ValueError as exc: - raise ValueError( - f"{_RETRIES_ENV} must be a non-negative integer (got {raw!r})." - ) from exc - if value < 0: - raise ValueError(f"{_RETRIES_ENV} must be >= 0 (got {value}).") - return value - - -@dataclass(frozen=True) -class RetryPolicy: - """Bounded retry-with-backoff config for transient sub-request failures. - - An immutable value object that owns the *timing* decisions; the - exception taxonomy (which failures are retryable) lives in - :func:`_retryable`. Backoff is exponential with **full jitter** - (:func:`random.uniform` over ``[0, ceiling]``) so the concurrent - fan-out's retries don't re-burst in lockstep. A server ``Retry-After`` - hint, when present, overrides the computed backoff — unless it exceeds - :attr:`retry_after_cap`, in which case retrying stops and the failure - surfaces as a resumable :class:`ChunkInterrupted` (a multi-minute - quota-window reset shouldn't block the call inline). - - Attributes - ---------- - max_retries : int - Retries attempted after the first try; ``0`` disables retrying. - base_backoff : float - Seconds; the jitter ceiling for the first retry, doubled each - subsequent attempt. - max_backoff : float - Upper bound on any single attempt's backoff ceiling. - retry_after_cap : float - Largest ``Retry-After`` (seconds) honored inline; longer hints - escalate to a resumable interruption. - """ - - max_retries: int = _RETRIES_DEFAULT - base_backoff: float = _RETRY_BASE_BACKOFF - max_backoff: float = _RETRY_MAX_BACKOFF - retry_after_cap: float = _RETRY_AFTER_CAP - - def __post_init__(self) -> None: - # Catch invalid timing knobs here so a misconfiguration fails at - # construction, not deep in a later ``time.sleep`` (ValueError on - # a negative delay) or silently in ``asyncio.sleep`` (which - # treats negative as zero). - if self.max_retries < 0: - raise ValueError(f"max_retries must be >= 0 (got {self.max_retries}).") - if self.base_backoff < 0 or self.max_backoff < 0 or self.retry_after_cap < 0: - raise ValueError("retry backoff settings must be non-negative.") - - @classmethod - def from_env(cls) -> RetryPolicy: - """ - Build a policy from the module-level defaults, resolved now. - - Reads ``max_retries`` from ``API_USGS_RETRIES`` and the timing - knobs from the ``_RETRY_*`` module constants at call time — not - the dataclass field defaults (which freeze at class definition) - — so test ``monkeypatch.setattr`` on the constants takes effect. - - Returns - ------- - RetryPolicy - A policy built from the module-level defaults resolved at - call time. - """ - return cls( - max_retries=_read_retries_env(), - base_backoff=_RETRY_BASE_BACKOFF, - max_backoff=_RETRY_MAX_BACKOFF, - retry_after_cap=_RETRY_AFTER_CAP, - ) - - def should_retry(self, attempt: int, retry_after: float | None) -> bool: - """ - Whether a just-failed ``attempt`` (1-based) warrants another try. - - A ``Retry-After`` longer than ``retry_after_cap`` is *not* slept - off inline — it returns ``False`` so the failure escalates to a - resumable interruption instead of blocking the call for minutes. - - Parameters - ---------- - attempt : int - The just-failed attempt number (1-based). - retry_after : float or None - Seconds the server suggested waiting (``Retry-After`` hint), - or ``None`` when no hint was given. - - Returns - ------- - bool - ``True`` if another try is warranted, ``False`` otherwise. - """ - if attempt > self.max_retries: - return False - return retry_after is None or retry_after <= self.retry_after_cap - - def backoff(self, attempt: int, retry_after: float | None) -> float: - """ - Seconds to wait before retry ``attempt`` (1-based). - - Parameters - ---------- - attempt : int - The retry attempt number (1-based). - retry_after : float or None - Seconds the server suggested waiting (``Retry-After`` hint), - or ``None`` to use the computed exponential backoff instead. - - Returns - ------- - float - Seconds to wait before the retry. - """ - if retry_after is not None: - return retry_after - ceiling = min(self.max_backoff, self.base_backoff * 2 ** (attempt - 1)) - return random.uniform(0.0, ceiling) - +# ``RetryPolicy`` and ``ConcurrencyPolicy`` are the value-object knobs the +# chunker reads at call time. They live in :mod:`._config`, which layers +# defaults → user file → local file → env vars → Python override (see that +# module's docstring). Re-exported here so legacy callers / tests keep +# ``from dataretrieval.waterdata.chunking import RetryPolicy`` working. +from ._config import ConcurrencyPolicy, RetryPolicy # noqa: E402 # Default for direct ``ChunkedCall`` / ``ChunkPlan.execute`` construction # (and tests): no retrying. The production decorator path explicitly passes @@ -1558,7 +1371,7 @@ def resume(self) -> tuple[pd.DataFrame, Any]: handle is on ``exc.call`` — wait for the underlying condition to clear and call ``exc.call.resume()`` again. """ - concurrency = _read_concurrency_env() + concurrency = ConcurrencyPolicy.from_env().max_connections with start_blocking_portal() as portal: return portal.call(functools.partial(self._run, concurrency)) diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 9a2be2c4..80adbba5 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -5,7 +5,6 @@ import json import logging import numbers -import os import re from collections.abc import ( AsyncIterator, @@ -27,7 +26,7 @@ from dataretrieval import __version__ from dataretrieval.utils import HTTPX_DEFAULTS, BaseMetadata -from dataretrieval.waterdata import _progress, chunking +from dataretrieval.waterdata import _config, _progress, chunking from dataretrieval.waterdata.chunking import ( _QUOTA_HEADER, RateLimited, @@ -340,8 +339,10 @@ def _default_headers(): ------- dict A dictionary containing default headers including 'Accept-Encoding', - 'Accept', 'User-Agent', and 'lang'. If the environment variable - 'API_USGS_PAT' is set, its value is included as the 'X-Api-Key' header. + 'Accept', 'User-Agent', and 'lang'. If an API token is configured + (via ``$API_USGS_PAT`` or any other config layer — see + :mod:`dataretrieval.waterdata._config`), it's included as the + ``X-Api-Key`` header. """ headers = { "Accept-Encoding": "compress, gzip", @@ -349,7 +350,7 @@ def _default_headers(): "User-Agent": f"python-dataretrieval/{__version__}", "lang": "en-US", } - token = os.getenv("API_USGS_PAT") + token = _config.current().api_token if token: headers["X-Api-Key"] = token return headers diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 5dbacfbf..1e2d242f 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -31,6 +31,7 @@ if sys.version_info < (3, 10): pytest.skip("Skip entire module on Python < 3.10", allow_module_level=True) +from dataretrieval.waterdata import _config from dataretrieval.waterdata import chunking as _chunking from dataretrieval.waterdata import utils as _utils from dataretrieval.waterdata.chunking import ( @@ -1552,7 +1553,7 @@ def test_retry_policy_from_env(monkeypatch): monkeypatch.setenv("API_USGS_RETRIES", "0") assert RetryPolicy.from_env().max_retries == 0 monkeypatch.delenv("API_USGS_RETRIES", raising=False) - assert RetryPolicy.from_env().max_retries == _chunking._RETRIES_DEFAULT + assert RetryPolicy.from_env().max_retries == _config._RETRIES_DEFAULT monkeypatch.setenv("API_USGS_RETRIES", "-1") with pytest.raises(ValueError): RetryPolicy.from_env() @@ -1570,13 +1571,21 @@ def test_retry_policy_rejects_invalid_settings(): RetryPolicy(max_backoff=-1.0) -def test_retry_policy_from_env_honors_monkeypatched_constants(monkeypatch): - # The timing knobs are read from the module constants at call time, so - # monkeypatching them (as the module comment promises) takes effect. - monkeypatch.setattr(_chunking, "_RETRY_MAX_BACKOFF", 0.0) - monkeypatch.setattr(_chunking, "_RETRY_BASE_BACKOFF", 0.0) - policy = RetryPolicy.from_env() - assert policy.max_backoff == 0.0 and policy.base_backoff == 0.0 +def test_retry_policy_override_takes_effect(): + """A Python-side override is the highest-precedence config layer and + flows through ``RetryPolicy.from_env()`` for the duration of the block.""" + from dataretrieval.waterdata._config import WaterDataConfig, override + + custom = WaterDataConfig( + retry=RetryPolicy(base_backoff=0.0, max_backoff=0.0, max_retries=2) + ) + with override(custom): + policy = RetryPolicy.from_env() + assert policy.base_backoff == 0.0 + assert policy.max_backoff == 0.0 + assert policy.max_retries == 2 + # And the override unwinds at block exit. + assert RetryPolicy.from_env().base_backoff == _config._RETRY_BASE_BACKOFF # -- _retryable taxonomy ---------------------------------------------------- diff --git a/tests/waterdata_config_test.py b/tests/waterdata_config_test.py new file mode 100644 index 00000000..bd98f413 --- /dev/null +++ b/tests/waterdata_config_test.py @@ -0,0 +1,281 @@ +"""Tests for the layered config loader. + +Verifies the precedence chain (defaults → user file → local file → env vars +→ Python override) and the round-trip through the INI / configparser layer. +""" + +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest + +from dataretrieval.waterdata import _config +from dataretrieval.waterdata._config import ( + CONCURRENCY_UNBOUNDED, + ConcurrencyPolicy, + RetryPolicy, + WaterDataConfig, + override, +) + +# ---- isolate from real config files --------------------------------------- + + +@pytest.fixture(autouse=True) +def _isolate(monkeypatch, tmp_path): + """Point user-config + local-config + cwd at empty tmp dirs so no real + file or shell env leaks in. Tests that want a layer install it explicitly + via :func:`_write` / :func:`monkeypatch.setenv`. + + Also clears the four ``API_USGS_*`` env vars (the autouse conftest + fixture sets two, but per-layer tests need a blank slate).""" + # Redirect user-config path. + user_dir = tmp_path / "user_config" + user_dir.mkdir() + monkeypatch.setattr(_config, "_user_config_path", lambda: user_dir / "config.cfg") + # Make Path.cwd() return the tmp dir so the local-config file lookup + # lands here too. + cwd = tmp_path / "cwd" + cwd.mkdir() + monkeypatch.chdir(cwd) + # Clear the env vars set by other autouse fixtures. + for name in ( + _config.ENV_RETRIES, + _config.ENV_CONCURRENT, + _config.ENV_PAT, + _config.ENV_PROGRESS, + ): + monkeypatch.delenv(name, raising=False) + + +def _write(path: Path, body: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(textwrap.dedent(body).lstrip()) + + +# ---- layer 1: defaults ---------------------------------------------------- + + +def test_defaults_when_nothing_configured(): + cfg = WaterDataConfig.load() + assert cfg.retry == RetryPolicy() + assert cfg.concurrency == ConcurrencyPolicy() + assert cfg.api_token is None + assert cfg.progress is None + + +# ---- layer 2: user file --------------------------------------------------- + + +def test_user_file_overrides_defaults(): + _write( + _config._user_config_path(), + """ + [default] + api_token = user-token + progress = on + + [retry] + max_retries = 7 + base_backoff = 0.25 + + [concurrency] + max_connections = 32 + """, + ) + cfg = WaterDataConfig.load() + assert cfg.api_token == "user-token" + assert cfg.progress is True + assert cfg.retry.max_retries == 7 + assert cfg.retry.base_backoff == 0.25 + # Unset fields still take the dataclass defaults. + assert cfg.retry.max_backoff == _config._RETRY_MAX_BACKOFF + assert cfg.concurrency.max_connections == 32 + + +# ---- layer 3: local file overrides user file ------------------------------ + + +def test_local_file_overrides_user_file(): + _write( + _config._user_config_path(), + """ + [default] + api_token = user-token + + [retry] + max_retries = 7 + """, + ) + _write( + Path.cwd() / _config._LOCAL_CONFIG_NAME, + """ + [default] + api_token = local-token + + [retry] + max_retries = 99 + base_backoff = 1.5 + """, + ) + cfg = WaterDataConfig.load() + assert cfg.api_token == "local-token" + assert cfg.retry.max_retries == 99 + assert cfg.retry.base_backoff == 1.5 + + +# ---- layer 4: env overrides files ----------------------------------------- + + +def test_env_overrides_files(monkeypatch): + _write( + _config._user_config_path(), + """ + [default] + api_token = file-token + + [retry] + max_retries = 5 + + [concurrency] + max_connections = 8 + """, + ) + monkeypatch.setenv(_config.ENV_RETRIES, "2") + monkeypatch.setenv(_config.ENV_CONCURRENT, CONCURRENCY_UNBOUNDED) + monkeypatch.setenv(_config.ENV_PAT, "env-token") + monkeypatch.setenv(_config.ENV_PROGRESS, "off") + + cfg = WaterDataConfig.load() + assert cfg.api_token == "env-token" + assert cfg.progress is False + assert cfg.retry.max_retries == 2 # env wins over file's 5 + assert cfg.concurrency.max_connections is None # "unbounded" → None + + +def test_env_only_sets_what_it_provides(monkeypatch): + """An env var sets only its own field; other file-set fields are + preserved (the deep-update keeps sibling keys).""" + _write( + _config._user_config_path(), + """ + [retry] + max_retries = 5 + base_backoff = 1.0 + max_backoff = 10.0 + """, + ) + monkeypatch.setenv(_config.ENV_RETRIES, "2") + cfg = WaterDataConfig.load() + assert cfg.retry.max_retries == 2 # env overrides + assert cfg.retry.base_backoff == 1.0 # file-set, preserved + assert cfg.retry.max_backoff == 10.0 # file-set, preserved + + +# ---- layer 5: Python override wins above all ------------------------------ + + +def test_python_override_wins(monkeypatch): + monkeypatch.setenv(_config.ENV_RETRIES, "2") + custom = WaterDataConfig( + retry=RetryPolicy(max_retries=99, base_backoff=0.0, max_backoff=0.0) + ) + with override(custom): + # current() short-circuits to the active override (no file/env load). + assert _config.current() is custom + # And the legacy from_env() factories pick it up too. + assert RetryPolicy.from_env().max_retries == 99 + assert RetryPolicy.from_env().base_backoff == 0.0 + # On exit, the layered loader resumes. + assert RetryPolicy.from_env().max_retries == 2 # back to env + + +def test_override_is_contextvar_scoped(): + """Nested ``override`` blocks pop correctly; the outer override is + restored at inner exit.""" + outer = WaterDataConfig(retry=RetryPolicy(max_retries=1)) + inner = WaterDataConfig(retry=RetryPolicy(max_retries=2)) + with override(outer): + assert _config.current() is outer + with override(inner): + assert _config.current() is inner + assert _config.current() is outer + + +# ---- parsing / validation ------------------------------------------------- + + +def test_env_concurrency_unbounded_keyword(monkeypatch): + monkeypatch.setenv(_config.ENV_CONCURRENT, "UNBOUNDED") # case-insensitive + assert WaterDataConfig.load().concurrency.max_connections is None + + +def test_env_concurrency_invalid_value(monkeypatch): + monkeypatch.setenv(_config.ENV_CONCURRENT, "abc") + with pytest.raises(ValueError, match=_config.ENV_CONCURRENT): + WaterDataConfig.load() + + +def test_env_retries_negative_is_rejected(monkeypatch): + monkeypatch.setenv(_config.ENV_RETRIES, "-1") + with pytest.raises(ValueError): + WaterDataConfig.load() + + +def test_progress_parser_recognizes_truthy_falsy_and_auto(): + assert _config._parse_progress("on") is True + assert _config._parse_progress("true") is True + assert _config._parse_progress("1") is True + assert _config._parse_progress("off") is False + assert _config._parse_progress("FALSE") is False # case-insensitive + assert _config._parse_progress("auto") is None + assert _config._parse_progress("") is None + assert _config._parse_progress("nonsense") is None + + +def test_missing_files_are_silent(): + """No user or local file → no error, just falls through to defaults.""" + assert not _config._user_config_path().exists() + assert not (Path.cwd() / _config._LOCAL_CONFIG_NAME).exists() + cfg = WaterDataConfig.load() + assert cfg == WaterDataConfig() # all defaults + + +def test_unknown_keys_in_file_are_ignored(): + """A stray key in a sub-table shouldn't crash construction.""" + _write( + _config._user_config_path(), + """ + [retry] + max_retries = 3 + not_a_field = 42 + + [concurrency] + max_connections = 4 + also_not_a_field = "x" + """, + ) + cfg = WaterDataConfig.load() + assert cfg.retry.max_retries == 3 + assert cfg.concurrency.max_connections == 4 + + +# ---- direct dataclass validation ----------------------------------------- + + +def test_retry_policy_rejects_negative_settings(): + with pytest.raises(ValueError): + RetryPolicy(max_retries=-1) + with pytest.raises(ValueError): + RetryPolicy(base_backoff=-0.5) + + +def test_concurrency_policy_rejects_zero_or_negative(): + with pytest.raises(ValueError): + ConcurrencyPolicy(max_connections=0) + with pytest.raises(ValueError): + ConcurrencyPolicy(max_connections=-1) + # ``None`` is fine — explicit "unbounded". + ConcurrencyPolicy(max_connections=None)