diff --git a/dataretrieval/waterdata/_config.py b/dataretrieval/waterdata/_config.py new file mode 100644 index 00000000..a75e0b5b --- /dev/null +++ b/dataretrieval/waterdata/_config.py @@ -0,0 +1,447 @@ +"""Layered configuration for the Water Data getters. + +The Water Data module has a few runtime knobs — retry budget, concurrency +cap, API token, progress bar mode. Historically each was read directly from +its env var at the call site. This module gathers them behind one config +object with the conventional precedence rule used by ``git`` / ``npm`` / +``pip`` / ``cargo`` (closer to the action wins): + +1. **Defaults** — the dataclass field defaults below. +2. **User config file** — ``$XDG_CONFIG_HOME/dataretrieval/config.cfg`` + (default ``~/.config/dataretrieval/config.cfg``) on Linux/macOS, + ``%APPDATA%\\dataretrieval\\config.cfg`` on Windows. +3. **Local config file** — ``./dataretrieval.cfg`` in the current working + directory. +4. **Environment variables** — ``API_USGS_RETRIES``, + ``API_USGS_CONCURRENT``, ``API_USGS_PAT``, ``API_USGS_PROGRESS``. +5. **Python override** — :func:`override` (a ContextVar-scoped context + manager) or :func:`set_config` (process-wide). + +Config-file schema (stdlib INI via :mod:`configparser`):: + + [default] + api_token = ... + progress = auto # on | off | auto + + [retry] + max_retries = 4 + base_backoff = 0.5 + max_backoff = 30.0 + retry_after_cap = 60.0 + + [concurrency] + max_connections = 16 # int >= 1, or the string "unbounded" + +Stdlib-only — no TOML / YAML deps. If/when pyproject.toml integration +becomes valuable (e.g. ``[tool.dataretrieval]``), adding ``tomli`` as a +conditional dep is a small mechanical follow-up. + +Backward compatibility: every existing env var keeps working unchanged, +and the legacy ``RetryPolicy.from_env()`` / ``ConcurrencyPolicy.from_env()`` +factories still build per-call from the layered loader. +""" + +from __future__ import annotations + +import configparser +import os +import random +import sys +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field, fields +from pathlib import Path +from typing import Any + +# ---- env-var names (single source of truth) ------------------------------- + +ENV_RETRIES = "API_USGS_RETRIES" +ENV_CONCURRENT = "API_USGS_CONCURRENT" +ENV_PAT = "API_USGS_PAT" +ENV_PROGRESS = "API_USGS_PROGRESS" + +# Sentinel string in ``API_USGS_CONCURRENT`` or ``concurrency.max_connections`` +# meaning "no cap on simultaneous connections". +CONCURRENCY_UNBOUNDED = "unbounded" + + +# ---- defaults the dataclass fields ship with ----------------------------- +# Values mirror the legacy ``_RETRY_*`` module constants so callers porting +# from the old shape see the same numbers. + +_RETRIES_DEFAULT = 4 +_RETRY_BASE_BACKOFF = 0.5 +_RETRY_MAX_BACKOFF = 30.0 +_RETRY_AFTER_CAP = 60.0 +_CONCURRENCY_DEFAULT = 16 + + +# ---- file locations ------------------------------------------------------- + +_CONFIG_BASENAME = "config.cfg" +_USER_CONFIG_SUBDIR = "dataretrieval" +_LOCAL_CONFIG_NAME = "dataretrieval.cfg" + +# Section names in the INI files. ``DEFAULTS_SECTION`` carries the +# top-level scalars (api_token, progress); the policy sub-tables get their +# own sections. +DEFAULTS_SECTION = "default" +RETRY_SECTION = "retry" +CONCURRENCY_SECTION = "concurrency" + + +def _user_config_path() -> Path: + """Cross-platform user config path. Honors ``XDG_CONFIG_HOME``.""" + if sys.platform == "win32": # pragma: no cover - exercised on Windows CI + base = os.environ.get("APPDATA") or str(Path.home() / "AppData/Roaming") + return Path(base) / _USER_CONFIG_SUBDIR / _CONFIG_BASENAME + xdg = os.environ.get("XDG_CONFIG_HOME") + base = Path(xdg) if xdg else Path.home() / ".config" + return base / _USER_CONFIG_SUBDIR / _CONFIG_BASENAME + + +def _local_config_path() -> Path: + """The local config file in the current working directory.""" + return Path.cwd() / _LOCAL_CONFIG_NAME + + +# ---- dataclasses ---------------------------------------------------------- + + +@dataclass(frozen=True) +class RetryPolicy: + """Retry-with-backoff timing knobs for one sub-request. + + Frozen value object. Construct directly to override per-call, or use + :meth:`from_env` to build from the layered config (defaults → user file + → local file → env vars → Python override). + + Attributes + ---------- + max_retries : int + Maximum retries per sub-request (``0`` disables retry entirely). + base_backoff : float + First-retry delay in seconds; subsequent retries double under full + jitter up to :attr:`max_backoff`. + max_backoff : float + Per-attempt ceiling on the slept delay (seconds). + retry_after_cap : float + If the server sends ``Retry-After`` greater than this (seconds), + the failure escalates to a resumable :class:`ChunkInterrupted` + instead of blocking the gather inline. + """ + + max_retries: int = _RETRIES_DEFAULT + base_backoff: float = _RETRY_BASE_BACKOFF + max_backoff: float = _RETRY_MAX_BACKOFF + retry_after_cap: float = _RETRY_AFTER_CAP + + def __post_init__(self) -> None: + # Catch invalid timing knobs here so a misconfiguration fails at + # construction, not deep in a later ``time.sleep`` (ValueError on + # a negative delay) or silently in ``asyncio.sleep`` (which + # treats negative as zero). + if self.max_retries < 0: + raise ValueError(f"max_retries must be >= 0 (got {self.max_retries}).") + if self.base_backoff < 0 or self.max_backoff < 0 or self.retry_after_cap < 0: + raise ValueError("retry backoff settings must be non-negative.") + + @classmethod + def from_env(cls) -> RetryPolicy: + """Build from the layered config (defaults → user file → local + file → env vars → Python override).""" + return current().retry + + def should_retry(self, attempt: int, retry_after: float | None) -> bool: + """Whether ``attempt`` should be retried under this policy. + + Returns ``False`` if the policy is exhausted or if the server's + ``Retry-After`` (seconds) exceeds :attr:`retry_after_cap`. + """ + if attempt > self.max_retries: + return False + return retry_after is None or retry_after <= self.retry_after_cap + + def backoff(self, attempt: int, retry_after: float | None) -> float: + """Seconds to wait before the next retry of ``attempt``. + + Honor server ``Retry-After`` when present (already filtered by + :meth:`should_retry` against :attr:`retry_after_cap`). Otherwise: + exponential ``base_backoff * 2 ** (attempt - 1)`` capped at + :attr:`max_backoff`, then full-jitter randomized in ``[0, capped]``. + """ + if retry_after is not None: + return retry_after + ceiling = min(self.max_backoff, self.base_backoff * 2 ** (attempt - 1)) + return random.uniform(0.0, ceiling) + + +@dataclass(frozen=True) +class ConcurrencyPolicy: + """Simultaneous-connection cap for chunked sub-request gather. + + ``max_connections=None`` means uncapped — the gather dispatches every + pending sub-request and lets ``httpx`` open as many connections as + needed. + """ + + max_connections: int | None = _CONCURRENCY_DEFAULT + + def __post_init__(self) -> None: + if self.max_connections is not None and self.max_connections < 1: + raise ValueError( + f"max_connections must be >= 1 or None (got {self.max_connections})." + ) + + @classmethod + def from_env(cls) -> ConcurrencyPolicy: + """Build from the layered config — convenience for callers that + only need the concurrency piece.""" + return current().concurrency + + +@dataclass(frozen=True) +class WaterDataConfig: + """Top-level config composing every runtime knob the Water Data + getters consult. + + Construct directly, or use :meth:`load` to build via the precedence + layering. :func:`current` returns the in-effect config (the active + override if one is set, else a freshly :meth:`load`-ed one). + """ + + retry: RetryPolicy = field(default_factory=RetryPolicy) + concurrency: ConcurrencyPolicy = field(default_factory=ConcurrencyPolicy) + api_token: str | None = None + # ``None`` = auto-detect (TTY-driven); ``True`` / ``False`` are explicit + # overrides. Mirrors the legacy ``API_USGS_PROGRESS`` semantics: + # ``"on"`` / ``"true"`` / ``"1"`` → True, ``"off"`` / ``"false"`` / + # ``"0"`` → False, anything else → ``None`` (auto). + progress: bool | None = None + + @classmethod + def load(cls) -> WaterDataConfig: + """Build from the precedence chain — defaults → user file → local + file → env vars (does NOT consult :func:`override`).""" + merged: dict[str, Any] = {} + for layer in (_load_user_file(), _load_local_file(), _load_env()): + _deep_update(merged, layer) + return _from_mapping(merged) + + +# ---- ContextVar-based runtime override ----------------------------------- + +_active: ContextVar[WaterDataConfig | None] = ContextVar( + "waterdata_active_config", default=None +) + + +def current() -> WaterDataConfig: + """Return the active :class:`WaterDataConfig` — the override set via + :func:`override` / :func:`set_config` if any, else freshly loaded + from the file + env precedence chain.""" + active = _active.get() + if active is not None: + return active + return WaterDataConfig.load() + + +def set_config(config: WaterDataConfig | None) -> None: + """Pin a config process-wide (or ``None`` to clear the pin and fall + back to the layered loader). Test/notebook convenience — prefer + :func:`override` for scoped use.""" + _active.set(config) + + +@contextmanager +def override(config: WaterDataConfig) -> Iterator[None]: + """Pin ``config`` as the active config for the duration of the + ``with`` block (thread-safe and async-safe via ``ContextVar``):: + + with override(WaterDataConfig(retry=RetryPolicy(max_retries=10))): + ... # every call inside sees the override + """ + token = _active.set(config) + try: + yield + finally: + _active.reset(token) + + +# ---- file readers --------------------------------------------------------- + + +def _read_ini(path: Path) -> dict[str, dict[str, str]]: + """Read an INI file; return ``{section: {key: value}}``. Missing file + returns ``{}``. A malformed file raises :class:`configparser.Error`.""" + if not path.exists(): + return {} + parser = configparser.ConfigParser( + # Don't promote the ``[DEFAULT]`` section into every other section + # (the configparser default behavior). Each section stands alone. + default_section="__never_used__", + ) + parser.read(path) + return {section: dict(parser[section]) for section in parser.sections()} + + +def _ini_to_mapping(sections: dict[str, dict[str, str]]) -> dict[str, Any]: + """Convert raw INI string values into the shape ``_from_mapping`` expects. + + ``[default]`` keys become top-level entries (``api_token``, ``progress``); + ``[retry]`` and ``[concurrency]`` become sub-tables. String values are + parsed into the appropriate types (int / float / bool / None). + """ + out: dict[str, Any] = {} + # Top-level scalars from [default]. + defaults = sections.get(DEFAULTS_SECTION, {}) + if "api_token" in defaults: + out["api_token"] = defaults["api_token"] + if "progress" in defaults: + parsed = _parse_progress(defaults["progress"]) + if parsed is not None or defaults["progress"].strip().lower() in ("auto", ""): + out["progress"] = parsed + # Retry sub-table. + retry_raw = sections.get(RETRY_SECTION, {}) + if retry_raw: + retry: dict[str, Any] = {} + if "max_retries" in retry_raw: + retry["max_retries"] = _coerce_int(retry_raw["max_retries"], "max_retries") + for k in ("base_backoff", "max_backoff", "retry_after_cap"): + if k in retry_raw: + retry[k] = _coerce_float(retry_raw[k], k) + if retry: + out["retry"] = retry + # Concurrency sub-table. + conc_raw = sections.get(CONCURRENCY_SECTION, {}) + if "max_connections" in conc_raw: + out["concurrency"] = { + "max_connections": _parse_concurrency(conc_raw["max_connections"]) + } + return out + + +def _load_user_file() -> dict[str, Any]: + """Layer 2 — ``~/.config/dataretrieval/config.cfg`` (or platform + equivalent).""" + return _ini_to_mapping(_read_ini(_user_config_path())) + + +def _load_local_file() -> dict[str, Any]: + """Layer 3 — ``./dataretrieval.cfg`` in the current working directory.""" + return _ini_to_mapping(_read_ini(_local_config_path())) + + +def _load_env() -> dict[str, Any]: + """Layer 4 — the four ``API_USGS_*`` env vars, mapped onto the + dataclass shape.""" + out: dict[str, Any] = {} + if (raw := os.environ.get(ENV_RETRIES)) is not None: + out["retry"] = {"max_retries": _coerce_int(raw, ENV_RETRIES)} + if (raw := os.environ.get(ENV_CONCURRENT)) is not None: + out["concurrency"] = {"max_connections": _parse_concurrency(raw)} + if (raw := os.environ.get(ENV_PAT)) is not None: + out["api_token"] = raw + if (raw := os.environ.get(ENV_PROGRESS)) is not None: + parsed = _parse_progress(raw) + if parsed is not None or raw.strip().lower() in ("auto", ""): + out["progress"] = parsed + return out + + +# ---- mapping → dataclass -------------------------------------------------- + + +def _from_mapping(mapping: dict[str, Any]) -> WaterDataConfig: + """Construct a :class:`WaterDataConfig` from a merged mapping, + silently dropping unknown keys so a stray entry doesn't crash + construction. Sub-tables (``retry``, ``concurrency``) feed their + own dataclass constructors via :func:`_filter_kwargs`.""" + retry_kw = _filter_kwargs(mapping.get("retry", {}) or {}, RetryPolicy) + conc_kw = _filter_kwargs(mapping.get("concurrency", {}) or {}, ConcurrencyPolicy) + return WaterDataConfig( + retry=RetryPolicy(**retry_kw), + concurrency=ConcurrencyPolicy(**conc_kw), + api_token=mapping.get("api_token"), + progress=mapping.get("progress"), + ) + + +def _filter_kwargs(mapping: dict[str, Any], cls: type) -> dict[str, Any]: + """Keep only keys that match dataclass fields of ``cls``.""" + known = {f.name for f in fields(cls)} + return {k: v for k, v in mapping.items() if k in known} + + +# ---- helpers -------------------------------------------------------------- + + +def _deep_update(dst: dict[str, Any], src: dict[str, Any]) -> None: + """Recursive dict merge: nested dicts merge; scalars overwrite. Used + to layer config sources without dropping unset sub-table keys.""" + for k, v in src.items(): + if isinstance(v, dict) and isinstance(dst.get(k), dict): + _deep_update(dst[k], v) + else: + dst[k] = v + + +def _coerce_int(raw: str | int, name: str) -> int: + if isinstance(raw, int): + return raw + try: + return int(raw) + except ValueError as e: + raise ValueError(f"{name} must be an integer (got {raw!r}).") from e + + +def _coerce_float(raw: str | float, name: str) -> float: + if isinstance(raw, (int, float)): + return float(raw) + try: + return float(raw) + except ValueError as e: + raise ValueError(f"{name} must be a number (got {raw!r}).") from e + + +def _parse_concurrency(raw: str | int) -> int | None: + """Parse a concurrency value. The literal ``"unbounded"`` + (case-insensitive) → ``None``; anything else must parse as an int >= 1.""" + if isinstance(raw, str) and raw.strip().lower() == CONCURRENCY_UNBOUNDED: + return None + n = _coerce_int(raw, ENV_CONCURRENT) + if n < 1: + raise ValueError( + f"{ENV_CONCURRENT} must be >= 1 or '{CONCURRENCY_UNBOUNDED}' (got {raw!r})." + ) + return n + + +def _parse_progress(raw: str) -> bool | None: + """Parse a progress preference. ``"on"`` / ``"true"`` / ``"1"`` → + True; ``"off"`` / ``"false"`` / ``"0"`` → False; ``"auto"`` / + ``""`` / anything else → ``None`` (auto).""" + s = raw.strip().lower() + if s in ("on", "true", "1", "yes"): + return True + if s in ("off", "false", "0", "no"): + return False + return None + + +# ---- public re-exports ---------------------------------------------------- + +__all__ = [ + "RetryPolicy", + "ConcurrencyPolicy", + "WaterDataConfig", + "current", + "override", + "set_config", + "ENV_RETRIES", + "ENV_CONCURRENT", + "ENV_PAT", + "ENV_PROGRESS", + "CONCURRENCY_UNBOUNDED", +] diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py index 7263d555..5122f52e 100644 --- a/dataretrieval/waterdata/_progress.py +++ b/dataretrieval/waterdata/_progress.py @@ -25,7 +25,6 @@ from __future__ import annotations import contextvars -import os import sys from collections.abc import Iterator from contextlib import contextmanager @@ -80,13 +79,17 @@ def _in_jupyter_kernel() -> bool: def _enabled_default(stream: TextIO) -> bool: """Whether to draw the line by default. - ``API_USGS_PROGRESS`` wins when set. Otherwise show it for interactive use — - a TTY or a Jupyter/IPython kernel — and stay quiet for redirected output, - logs, and CI. + An explicit on/off from the layered config wins (set via + ``$API_USGS_PROGRESS``, a config file's ``[default] progress``, or a + Python override — see :mod:`._config`). Otherwise show it for + interactive use — a TTY or a Jupyter/IPython kernel — and stay quiet + for redirected output, logs, and CI. """ - override = os.getenv("API_USGS_PROGRESS") - if override is not None: - return override.strip().lower() not in {"", "0", "false", "no", "off"} + from . import _config # local import to avoid an import cycle at module load + + explicit = _config.current().progress + if explicit is not None: + return explicit if _in_jupyter_kernel(): return True return hasattr(stream, "isatty") and stream.isatty() @@ -121,6 +124,9 @@ def __init__( # The hourly request quota (``x-ratelimit-limit``), shown as the # denominator when the server reports it. self.rate_limit: str | None = None + # Transient note shown while a sub-request backs off before a + # retry; cleared by the next page/chunk so it doesn't linger. + self.retry_note: str | None = None self._last_len = 0 # Whether anything was actually written to the stream — drives whether # close() needs a terminating newline. (``current_chunk`` is a poor @@ -140,6 +146,7 @@ def start_chunk(self, index: int) -> None: avoids a premature "0 pages" frame before the first page arrives. """ self.current_chunk = index + self.retry_note = None if self.total_chunks > 1: self._render() @@ -147,6 +154,25 @@ def add_page(self, rows: int = 0) -> None: """Record one fetched page carrying ``rows`` rows and redraw.""" self.pages += 1 self.rows += int(rows) + self.retry_note = None + self._render() + + def note_retry(self, *, attempt: int, wait: float) -> None: + """Show that a sub-request is backing off before retry ``attempt``. + + Cleared by the next :meth:`add_page` / :meth:`start_chunk` (or by + :meth:`close`) so the line returns to normal once the retry resolves. + """ + # Keep sub-second waits explicit (avoid misleading ``0s``) while + # rendering whole-second waits without unnecessary ``.0`` noise. + # ``float()`` to support Python 3.9-3.11: ``round(int, 1)`` returns an + # int and ``int.is_integer()`` (used below) only exists on 3.12+. + wait_1dp = round(float(wait), 1) + if wait_1dp < 1 or not wait_1dp.is_integer(): + secs = f"{wait_1dp:.1f}s" + else: + secs = f"{wait_1dp:.0f}s" + self.retry_note = f"retrying (attempt {attempt}, waiting {secs})" self._render() def set_rate_remaining( @@ -179,6 +205,8 @@ def _format(self) -> str: else: segment = f"{remaining} requests remaining" parts.append(segment) + if self.retry_note is not None: + parts.append(self.retry_note) if self.service: return f"Retrieving: {self.service} · " + " · ".join(parts) return "Progress: " + " · ".join(parts) @@ -209,6 +237,13 @@ def close(self) -> None: """ if self._closed: return + # A retry note set during the final backoff would otherwise freeze as + # the persisted last line of a call that has since completed or given + # up; clear it and redraw (while still un-closed, so ``_render`` runs) + # so the final state isn't a stale "retrying". + if self.enabled and self._rendered and self.retry_note is not None: + self.retry_note = None + self._render() self._closed = True if not (self.enabled and self._rendered): return @@ -220,8 +255,10 @@ def close(self) -> None: self.enabled = False def _maybe_hint_api_key(self) -> None: + from . import _config # local import to avoid an import cycle at module load + global _api_key_hint_shown - if _api_key_hint_shown or os.getenv("API_USGS_PAT"): + if _api_key_hint_shown or _config.current().api_token: return # Set the once-per-process latch only after a successful write, so a # failed write (broken pipe) doesn't silently burn the hint for every diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index 57fffc88..44550375 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -2022,6 +2022,7 @@ def get_reference_table( collection: str, limit: int | None = None, query: dict | None = None, + max_rows: int | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """Get metadata reference tables for the USGS Water Data API. @@ -2046,6 +2047,12 @@ def get_reference_table( query: dictionary, optional The optional args parameter can be used to pass a dictionary of query parameters to the collection API call. + max_rows : int, optional + Cap the total number of rows returned, stopping pagination early + instead of downloading the whole table. Useful for cheaply + previewing large tables (e.g. ``hydrologic-unit-codes`` has ~125k + rows). Unlike ``limit`` (the per-page size), this bounds the total + result. The default (None) downloads every page. Returns ------- @@ -2092,7 +2099,9 @@ def get_reference_table( query_args = dict(query) if query else {} if limit is not None: query_args["limit"] = limit - return get_ogc_data(args=query_args, output_id=output_id, service=collection) + return get_ogc_data( + args=query_args, output_id=output_id, service=collection, max_rows=max_rows + ) def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame: diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 36ee24fd..f0aec9e8 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -4,19 +4,39 @@ parameter (sites, parameter codes, …) plus the cql-text ``filter``, which splits along its top-level OR clauses. Any of them can fan the URL past the server's ~8 KB byte limit. ``ChunkPlan`` picks a fan-out -for each axis that minimizes total sub-requests under the URL budget; -``ChunkedCall`` iterates the joint cartesian product so every -sub-request URL fits. Requests that already fit get a trivial +for each axis that minimizes total sub-requests while keeping every +sub-request URL under the budget; ``ChunkedCall`` fetches the resulting +cartesian product of chunks. Requests that already fit get a trivial single-step plan — ``ChunkedCall`` has one code path either way. -Interruption: any mid-stream transient failure (429, 5xx) surfaces -as a ``ChunkInterrupted`` subclass — ``QuotaExhausted`` for 429, -``ServiceInterrupted`` for 5xx. The exception carries ``.call``, a +Concurrency: ``multi_value_chunked`` fans every pending sub-request out +under one ``asyncio.gather`` sharing a single ``httpx.AsyncClient``; +concurrency is bounded purely by the client's connection pool +(``httpx.Limits(max_connections=N, max_keepalive_connections=N)``), so +the pool throttles. ``API_USGS_CONCURRENT`` resolves +``N``: an integer N > 1 caps connections at N; ``1`` pins a single +connection (one request at a time); the literal ``unbounded`` removes +the cap (``N=None``). The default (16) is the server-friendly sweet +spot; higher values can trip USGS burst-protection 5xx in practice. The +fan-out runs in a short-lived worker thread (an ``anyio`` blocking +portal), so it works whether or not the caller is already inside an +event loop (Jupyter / IPython / async apps). + +Retries: each sub-request is retried on a transient failure (429, +5xx, connect/read timeout) with exponential backoff + full jitter, +honoring a server ``Retry-After`` when present. ``API_USGS_RETRIES`` +sets the cap (default 4; ``0`` disables). A ``Retry-After`` longer +than the per-call ceiling escalates to a resumable interruption. + +Interruption: any mid-stream transient failure — 429, 5xx, or a bare +transport error (connect/read timeout, oversize follow-up URL) — surfaces +as a ``ChunkInterrupted`` subclass: ``QuotaExhausted`` for 429, +``ServiceInterrupted`` for the rest. The exception carries ``.call``, a ``ChunkedCall`` handle that owns the already-completed sub-request -state. Call ``.call.resume()`` once the underlying condition -clears; only the still-pending sub-requests are re-issued. -``Retry-After`` (when the server sets it) is surfaced on the -exception as ``.retry_after``. +state (sparse-indexed, since gathered sub-requests complete out of +order). Call ``.call.resume()`` once the underlying condition clears; +only the still-pending sub-requests are re-issued. ``Retry-After`` (when +the server sets it) is surfaced on the exception as ``.retry_after``. Dedup: list-axis chunks don't overlap; filter-axis chunks can, so ``_combine_chunk_frames`` dedupes by feature ``id``. ``properties``, @@ -27,11 +47,12 @@ from __future__ import annotations +import asyncio import copy import functools import itertools import math -from collections.abc import Callable, Iterator +from collections.abc import Awaitable, Callable, Iterator from contextlib import contextmanager, suppress from contextvars import ContextVar from dataclasses import dataclass @@ -41,6 +62,7 @@ import httpx import pandas as pd +from anyio.from_thread import start_blocking_portal from dataretrieval.utils import HTTPX_DEFAULTS @@ -56,22 +78,13 @@ # leaves ~200 bytes for request-line framing and proxy variance. _WATERDATA_URL_BYTE_LIMIT = 8000 -# Default rule: any list-shaped kwarg with >1 element is chunked across -# sub-requests — each chunk becomes a comma-joined sub-list in the URL. -# The OGC getters expose ~90 such list-shaped params (IDs, codes, -# statuses, ...), all chunkable, so it's shorter to enumerate the -# exceptions than to maintain an allowlist that grows with the API. -# Exceptions, by reason: -# - response shape: ``properties`` defines the columns; sharding -# would yield different schemas per chunk. -# - structured: ``bbox`` is a fixed 4-element coord tuple. -# - intervals: date/time ranges are not enumerable sets. -# - handled elsewhere: ``filter`` becomes its own axis in -# ``_extract_axes`` (joiner ``" OR "``); -# comma-joining CQL clauses would emit -# malformed expressions. -# - scalar by contract: ``limit``, ``skip_geometry``, ``filter_lang`` -# — a list value would be a type-erasure smuggle. +# Any list-shaped kwarg with >1 element is chunked (comma-joined per +# sub-list in the URL); ~90 OGC params qualify, so we denylist the few +# exceptions rather than maintain a growing allowlist. Excluded because: +# ``properties`` defines the column schema; ``bbox`` is a fixed coord +# tuple; date/time params are intervals, not enumerable sets; ``filter`` +# is handled as its own OR-axis in ``_extract_axes``; and ``limit`` / +# ``skip_geometry`` / ``filter_lang`` are scalar by contract. _NEVER_CHUNK = frozenset( { "properties", @@ -93,22 +106,45 @@ # Response header USGS uses to advertise remaining hourly quota. _QUOTA_HEADER = "x-ratelimit-remaining" -# Client shared across all sub-requests of a single chunked call so -# paginated-loop helpers downstream (``_walk_pages``) reuse one -# connection pool across the whole call. ``None`` when not inside a -# chunked call — paginated helpers fall back to their own short-lived -# client in that case. -_chunked_client: ContextVar[httpx.Client | None] = ContextVar( +# ``RetryPolicy`` and ``ConcurrencyPolicy`` are the value-object knobs the +# chunker reads at call time. They live in :mod:`._config`, which layers +# defaults → user file → local file → env vars → Python override (see that +# module's docstring). Re-exported here so legacy callers / tests keep +# ``from dataretrieval.waterdata.chunking import RetryPolicy`` working. +from ._config import ConcurrencyPolicy, RetryPolicy # noqa: E402 + +# Default for direct ``ChunkedCall`` / ``ChunkPlan.execute`` construction +# (and tests): no retrying. The production decorator path explicitly passes +# ``RetryPolicy.from_env()`` so retries are on by default there. +_NO_RETRY = RetryPolicy(max_retries=0) + + +# Shared per-call ``httpx.AsyncClient``, published via :func:`_publish` +# during ``ChunkedCall._run`` so paginated-loop helpers (``_walk_pages``) +# reuse the same connection pool across every sub-request. ``None`` +# outside a chunked call — paginated helpers then open their own +# short-lived client. +_chunked_client: ContextVar[httpx.AsyncClient | None] = ContextVar( "_chunked_client", default=None ) @contextmanager -def _publish_client(client: httpx.Client) -> Iterator[None]: +def _publish(client: httpx.AsyncClient) -> Iterator[None]: """ - Make ``client`` visible to :func:`get_active_client` for the - duration of the ``with`` block via the ``_chunked_client`` - ContextVar. Wraps the set/reset token dance so callers don't have to. + Publish ``client`` on the ``_chunked_client`` ContextVar so the + paginated-loop helpers can borrow it via :func:`get_active_client` + for the duration of the ``with`` block. + + Parameters + ---------- + client : httpx.AsyncClient + The client to publish. + + Yields + ------ + None + Yields once, for the duration of the bind. """ token = _chunked_client.set(client) try: @@ -117,19 +153,19 @@ def _publish_client(client: httpx.Client) -> Iterator[None]: _chunked_client.reset(token) -def get_active_client() -> httpx.Client | None: +def get_active_client() -> httpx.AsyncClient | None: """ - Return the chunker's currently-published sync client, or ``None``. + Return the chunker's currently-published client, or ``None``. - Public accessor for the ``_chunked_client`` ContextVar so - sibling modules (notably :func:`dataretrieval.waterdata.utils._client_for`) - don't have to reach into the private ContextVar directly. + Used by the paginated-loop helpers (e.g. + :func:`dataretrieval.waterdata.utils._client_for`) to reuse the + per-call connection pool. Returns ------- - httpx.Client or None - The client published by :func:`_publish_client` if currently - inside a :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. + httpx.AsyncClient or None + The client published via :func:`_publish` if currently inside a + :class:`ChunkedCall` run; ``None`` otherwise. """ return _chunked_client.get() @@ -140,7 +176,24 @@ def get_active_client() -> httpx.Client | None: _LIST_SEP = "," _OR_SEP = " OR " -_FetchOnce = Callable[[dict[str, Any]], tuple[pd.DataFrame, httpx.Response]] +# ``_Fetch`` is the per-sub-request fetcher the decorator wraps and +# ``ChunkedCall`` drives: an ``async def fetch(args) -> (df, response)``. +_Fetch = Callable[[dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]]] + +# Caller-supplied transform applied to the combined chunk result, so a +# resumed call returns the same shape as an un-interrupted one rather than +# the chunker's raw ``(frame, httpx.Response)``. This keeps the chunker +# generic: the OGC getters inject their post-processing (type coercion, +# column arrangement, ``BaseMetadata``) through ``utils._finalize_ogc``. +# The default is identity, so direct ``ChunkedCall`` use is unaffected. +_Finalize = Callable[[pd.DataFrame, httpx.Response], tuple[pd.DataFrame, Any]] + + +def _passthrough_result( + frame: pd.DataFrame, response: httpx.Response +) -> tuple[pd.DataFrame, Any]: + """Default :data:`_Finalize`: return the raw combined pair unchanged.""" + return frame, response class _RetryableTransportError(RuntimeError): @@ -244,9 +297,10 @@ class ChunkInterrupted(RuntimeError): later ``call.resume()`` (use ``exc.call.partial_frame`` for the live view). partial_response : httpx.Response or None - Aggregated response covering the completed sub-requests at - raise time; ``None`` if nothing had completed yet. Same - snapshot semantics as ``partial_frame``. + Raw aggregate response covering the completed sub-requests at + raise time; ``None`` if nothing had completed yet. Same snapshot + semantics as ``partial_frame``. (Raw, not finalized — use + ``exc.call.resume()`` for the finalized ``(df, metadata)`` result.) Examples -------- @@ -300,13 +354,10 @@ def __init__( self.total_chunks = total_chunks self.call = call self.retry_after = retry_after - # Snapshot partial state at raise time so the exception's view - # stays stable across later ``call.resume()`` advances; the - # live view lives on ``call.partial_frame``/``.partial_response``. - # ``partial_frame`` gets a defensive ``.copy()`` because - # ``_combine_chunk_frames`` may return a chunk frame verbatim - # in the single-completed-chunk fast path; ``partial_response`` - # already comes via ``copy.copy`` from ``_combine_chunk_responses``. + # Snapshot partial state at raise time so the exception's view stays + # stable across later ``call.resume()`` advances (the live view is on + # ``call.partial_frame`` / ``.partial_response``). ``.copy()`` guards + # the single-chunk fast path, where the frame may be returned verbatim. if call is None: self.partial_frame: pd.DataFrame = pd.DataFrame() self.partial_response: httpx.Response | None = None @@ -435,13 +486,12 @@ def _set_response_url(response: httpx.Response, url: str | httpx.URL) -> None: Overwrite the URL surfaced by a response without back-propagating the change into any aliased original. - On real ``httpx.Response`` instances ``.url`` is a read-only - property that resolves through the bound request; rather than - mutate the existing request's URL (which would be visible through - any shallow copy that shares the same ``.request``), we replace - the response's request with a fresh :class:`httpx.Request` carrying - the new URL. On lightweight test mocks ``.url`` is a plain - writable attribute — that path is tried first. + Try the direct assignment first: on lightweight test mocks ``.url`` + is a plain writable attribute. On real ``httpx.Response`` it's + read-only (it resolves through the bound request), so swap in a + fresh :class:`httpx.Request` carrying the new URL — mutating the + existing one would leak through any shallow copy that shares the + same ``.request``. """ try: response.url = url # type: ignore[misc] @@ -636,7 +686,7 @@ def __init__( axes = _extract_axes(args) # No chunkable axes → skip ``build_request`` entirely; the # common Water Data call shape shouldn't pay for an unused - # request prep on the passthrough hot path. ``fetch_once`` + # request prep on the passthrough hot path. The fetcher # will run with the user's args verbatim; if that produces # an over-budget URL, the server (or httpx itself) rejects. if not axes: @@ -767,26 +817,37 @@ def iter_sub_args(self) -> Iterator[dict[str, Any]]: sub_args[axis.arg_key] = axis.render(chunk) yield sub_args - def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response]: + def execute( + self, + fetch: _Fetch, + retry_policy: RetryPolicy = _NO_RETRY, + finalize: _Finalize = _passthrough_result, + ) -> tuple[pd.DataFrame, Any]: """ - Run the plan and return the combined ``(frame, response)``. + Run the plan and return the combined, finalized result. - Thin wrapper around ``ChunkedCall(self, fetch_once).resume()``; + Thin wrapper around ``ChunkedCall(self, fetch).resume()``; see :class:`ChunkedCall` for the per-sub-request semantics. Parameters ---------- - fetch_once : Callable - Function that issues a single sub-request, given the + fetch : Callable + ``async def`` that issues a single sub-request, given the substituted args dict, and returns ``(frame, response)``. + retry_policy : RetryPolicy, optional + Per-sub-request retry-with-backoff policy. Defaults to + :data:`_NO_RETRY`; the decorator passes ``RetryPolicy.from_env()``. + finalize : Callable, optional + Transform applied to the combined ``(frame, response)`` (see + :data:`_Finalize`). Defaults to :func:`_passthrough_result`. Returns ------- df : pandas.DataFrame Combined data from every successful sub-request. - response : httpx.Response - Aggregated response (canonical URL, last page's headers, - cumulative elapsed time). + response + The finalized aggregate — a raw :class:`httpx.Response` by + default, or whatever ``finalize`` produces. Raises ------ @@ -796,7 +857,7 @@ def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response] :class:`ServiceInterrupted` for 5xx). The resumable handle is on ``exc.call``. """ - return ChunkedCall(self, fetch_once).resume() + return ChunkedCall(self, fetch, retry_policy, finalize).resume() def _classify_chunk_error( @@ -850,6 +911,109 @@ def _classify_chunk_error( return None +def _retryable(exc: BaseException) -> tuple[bool, float | None]: + """ + Decide whether ``exc`` is a transient worth an automatic retry. + + Only the *top-level* exception is inspected — unlike + :func:`_classify_chunk_error`, which walks the ``__cause__`` chain. + The distinction matters because ``_paginate`` raises an + initial-request transient (429 / 5xx / :class:`httpx.TransportError`) + *raw*, but wraps a mid-pagination failure as a ``RuntimeError``. So a + raw transient means a sub-request that made no progress and is cheap to + re-issue, whereas a mid-pagination failure is left to escalate to a + resumable :class:`ChunkInterrupted` rather than re-walked from page 1 + (which would re-spend the quota just exhausted). ``httpx.InvalidURL`` + is never retried — a too-long cursor won't fix on a retry. + + Returns + ------- + tuple[bool, float or None] + ``(retryable, retry_after)`` — the server ``Retry-After`` hint + (seconds) when the transient carried one, else ``None``. + """ + if isinstance(exc, (RateLimited, ServiceUnavailable)): + return True, exc.retry_after + if isinstance(exc, httpx.TransportError): + return True, None + return False, None + + +def _retry_delay(exc: BaseException, attempt: int, policy: RetryPolicy) -> float | None: + """ + Decide the backoff for a just-failed ``attempt`` (1-based), or ``None`` + to give up and re-raise. + + Returns ``None`` in three cases — the error isn't a retryable + transient, the policy is exhausted, or the server's ``Retry-After`` + exceeds the cap (escalates to a resumable :class:`ChunkInterrupted` + instead). Otherwise returns the seconds to wait and emits the + progress-bar retry note. + + Parameters + ---------- + exc : BaseException + The exception raised by the just-failed attempt. + attempt : int + The just-failed attempt number (1-based). + policy : RetryPolicy + The retry-with-backoff policy governing the decision. + + Returns + ------- + float or None + Seconds to wait before retrying, or ``None`` to give up and + re-raise. + """ + retryable, retry_after = _retryable(exc) + if not retryable or not policy.should_retry(attempt, retry_after): + return None + delay = policy.backoff(attempt, retry_after) + # Surface the imminent retry on the active progress reporter, if any. + reporter = _progress.current() + if reporter is not None: + reporter.note_retry(attempt=attempt, wait=delay) + return delay + + +async def _retry( + afn: Callable[[], Awaitable[tuple[pd.DataFrame, httpx.Response]]], + policy: RetryPolicy, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Call ``afn`` with bounded retry-with-backoff on transient failures. + + A non-retryable or policy-exhausted failure (see :func:`_retry_delay`) + propagates unchanged so the caller's existing handling wraps it as a + resumable :class:`ChunkInterrupted`. The whole retry *decision* lives + in :func:`_retry_delay`; this driver only awaits the sleep between + attempts. + + Parameters + ---------- + afn : Callable + Zero-arg awaitable callable that issues a single sub-request and + returns ``(frame, response)``. + policy : RetryPolicy + The retry-with-backoff policy governing the retries. + + Returns + ------- + tuple of (pandas.DataFrame, httpx.Response) + The ``(frame, response)`` pair from the first successful call. + """ + attempt = 0 + while True: + try: + return await afn() + except Exception as exc: # noqa: BLE001 — re-raised unless retryable + attempt += 1 + delay = _retry_delay(exc, attempt, policy) + if delay is None: + raise + await asyncio.sleep(delay) + + def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: """ Concatenate per-chunk frames, dropping empties and deduping by ``id``. @@ -934,7 +1098,7 @@ def _combine_chunk_responses( One response per completed sub-request, in execution order. canonical_url : str or None URL of the unchunked original request. ``None`` skips the URL - override — used by the passthrough path (``fetch_once``'s + override — used by the passthrough path (the fetcher's response already carries the original-query URL) and by the worst-case overflow path (no buildable canonical URL exists). @@ -977,62 +1141,151 @@ class ChunkedCall: Stateful handle for a chunked call. Holds the in-flight state (per-sub-request frames and responses) - and exposes a single :meth:`resume` entry point that drives the - call from wherever it is to completion — used both for the first - invocation (from :meth:`ChunkPlan.execute`) and for subsequent + and the async fetcher. A single :meth:`resume` entry point drives + the call from wherever it is to completion — used both for the + first invocation (from :meth:`ChunkPlan.execute`) and for subsequent retries after a :class:`ChunkInterrupted`. + :meth:`_run` gathers every pending sub-request over one shared + :class:`httpx.AsyncClient`, applies the failure-precedence rules, and + combines; :meth:`resume` drives it through an ``anyio`` blocking + portal so it works whether or not the caller is already inside an + event loop. Concurrency is bounded purely by the client's connection + pool, so a single connection (``API_USGS_CONCURRENT=1``) is just a + degenerate gather. + A ``ChunkedCall`` is created internally when a :class:`ChunkPlan` executes; callers reach it via :attr:`ChunkInterrupted.call` on the exception raised by a mid-stream failure. - :meth:`resume` is idempotent: it iterates + :meth:`resume` is idempotent: :meth:`_run` iterates :meth:`ChunkPlan.iter_sub_args` (deterministic order) and skips any index whose result is already in ``self._chunks``. The - completion set is a ``dict[int, (df, response)]`` keyed by - sub-args index; a subsequent ``resume`` only re-issues - sub-requests whose index isn't already present. + completion set is a sparse ``dict[int, (df, response)]`` so the + gather can record scattered completions (e.g. indices [0, 2, 5] + after siblings [1, 3, 4] failed) and a subsequent ``resume`` only + re-issues the missing indices. Parameters ---------- plan : ChunkPlan The chunking plan to execute. - fetch_once : Callable - Function that issues a single sub-request, given the + fetch : Callable + ``async def`` that issues a single sub-request, given the substituted args dict, and returns ``(frame, response)``. Attributes ---------- plan : ChunkPlan The plan being driven (read-only after construction). - fetch_once : Callable - The per-sub-request fetch function. + fetch : Callable + The async per-sub-request fetch function. + finalize : Callable + Transform applied to the combined result (see :data:`_Finalize`) at + the terminal :meth:`_run` return, so a completed call yields the + caller's finished shape. The ``partial_*`` accessors deliberately + skip it and stay raw. partial_frame : pandas.DataFrame - Combined frame of completed sub-requests (live; recomputed per - access). + Raw combined frame of completed sub-requests (live; recomputed per + access). Not finalized — call :meth:`resume` for the finished shape. partial_response : httpx.Response or None - Aggregated response with canonical URL restored, or ``None`` - when nothing has completed yet (live; recomputed per access). + Raw aggregate response (canonical URL restored), or ``None`` when + nothing has completed yet (live; recomputed per access). """ - def __init__(self, plan: ChunkPlan, fetch_once: _FetchOnce) -> None: + def __init__( + self, + plan: ChunkPlan, + fetch: _Fetch, + retry_policy: RetryPolicy = _NO_RETRY, + finalize: _Finalize = _passthrough_result, + ) -> None: self.plan = plan - self.fetch_once = fetch_once - # Completed (frame, response) pairs keyed by sub-args index; - # ``resume()`` skips indices already present. + self.fetch = fetch + self.retry_policy = retry_policy + self.finalize = finalize + # Completed (frame, response) pairs keyed by sub-args index; sparse + # (gathered sub-requests complete out of order — see class docstring). + # ``_run``'s ``track`` closure is the only writer, so ``dict`` insertion + # order is completion order (relied on by :meth:`_combine_raw`). self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} - def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]: - return [self._chunks[i] for i in sorted(self._chunks)] + def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: + """ + Build the matching :class:`ChunkInterrupted` carrying this + call when ``exc`` is a recognized transient transport failure; + return ``None`` for unrecognized failures so the caller can + re-raise. Encapsulates the + ``classify → instantiate-with-call-state`` recipe so + :class:`ChunkedCall`'s private fields stay private. + + Parameters + ---------- + exc : BaseException + The exception raised by a sub-request. + + Returns + ------- + ChunkInterrupted or None + The matching :class:`ChunkInterrupted` subclass carrying this + call for a recognized transient failure; ``None`` otherwise. + """ + classification = _classify_chunk_error(exc) + if classification is None: + return None + interrupted_class, retry_after = classification + return interrupted_class( + completed_chunks=self.completed_chunks, + total_chunks=self.plan.total, + call=self, + retry_after=retry_after, + cause=exc, + ) + + @property + def completed_chunks(self) -> int: + """Number of sub-requests completed so far.""" + return len(self._chunks) + + def _combine_raw(self) -> tuple[pd.DataFrame, httpx.Response]: + """Assemble the raw ``(frame, response)`` from completed sub-requests, + before :attr:`finalize` runs. + + Frames concatenate in sub-args *index* order (``sorted`` keys — + deterministic, independent of parallel completion order). The + aggregated response takes its headers from the most-recently- + *completed* sub-request: the ``track`` closure in :meth:`_run` + is the only writer of ``self._chunks`` and ``dict`` preserves + insertion order, so the chunks' natural order is completion + order and the last one carries the freshest + ``x-ratelimit-remaining``. + + Returns + ------- + tuple of (pandas.DataFrame, httpx.Response) + The concatenated frame and the aggregated response, before + :attr:`finalize` is applied. + """ + frames = [self._chunks[i][0] for i in sorted(self._chunks)] + responses = [response for _, response in self._chunks.values()] + return ( + _combine_chunk_frames(frames), + _combine_chunk_responses(responses, self.plan.canonical_url), + ) @property def partial_frame(self) -> pd.DataFrame: """ - Concatenated, deduplicated frame of sub-requests that have - completed so far. + Raw combined frame of sub-requests that have completed so far. Live — recomputed on each access so it reflects current state - across resume attempts. + across resume attempts. Deliberately the *raw* combined frame + (``_combine_raw``), NOT the finalized result: this is a cheap, + side-effect-free snapshot for inspecting partial progress, so + reading it (or building a :class:`ChunkInterrupted` around it) + never triggers ``finalize`` work — which for OGC getters includes + a schema network fetch on an empty frame. Use ``call.resume()`` + for the finalized result. Returns ------- @@ -1042,15 +1295,17 @@ def partial_frame(self) -> pd.DataFrame: """ if not self._chunks: return pd.DataFrame() - return _combine_chunk_frames([frame for frame, _ in self._ordered_chunks()]) + return self._combine_raw()[0] @property def partial_response(self) -> httpx.Response | None: """ - Aggregated response with the canonical URL restored to the + Raw aggregate response with the canonical URL restored to the user's full original query. - Live — recomputed on each access. + Live — recomputed on each access. Like :attr:`partial_frame`, this + is the *raw* aggregate (an :class:`httpx.Response`), not the + finalized result, so inspecting it is side-effect-free. Returns ------- @@ -1060,106 +1315,190 @@ def partial_response(self) -> httpx.Response | None: """ if not self._chunks: return None - return _combine_chunk_responses( - [resp for _, resp in self._ordered_chunks()], self.plan.canonical_url - ) + return self._combine_raw()[1] + + def _pending(self) -> Iterator[tuple[int, dict[str, Any]]]: + """ + Yield ``(index, sub_args)`` for sub-requests not yet completed. + + Walks :meth:`ChunkPlan.iter_sub_args` in deterministic order + and skips any index already in ``self._chunks``. :meth:`_run` + uses this to pick up exactly the sub-requests it still owes — + first run and every resume alike. + + Yields + ------ + tuple of (int, dict) + The sub-args ``index`` and its ``sub_args`` dict for each + sub-request not yet completed. + """ + for index, sub_args in enumerate(self.plan.iter_sub_args()): + if index not in self._chunks: + yield index, sub_args - def resume(self) -> tuple[pd.DataFrame, httpx.Response]: + def resume(self) -> tuple[pd.DataFrame, Any]: """ - Drive the chunked call to completion via the sync ``fetch_once``. + Drive the chunked call to completion and return the combined result. - Opens one ``httpx.Client`` for the run and publishes it on - the ``_chunked_client`` ``ContextVar`` so paginated-loop - helpers downstream (``_walk_pages``) reuse the same connection - pool across every sub-request instead of handshaking fresh on - each. The client is closed when ``resume`` returns or raises; - a follow-up ``resume`` call (after a ``ChunkInterrupted``) - opens a new one. + Runs :meth:`_run` through an ``anyio`` blocking portal (a + short-lived worker thread), so it works whether or not the caller + is already inside an event loop (Jupyter / IPython / async apps). + The portal copies the calling context, so the active progress + reporter still reaches the sub-requests. Idempotent: only sub-requests whose index isn't already in ``self._chunks`` are re-issued. Sub-args order matches - :meth:`ChunkPlan.iter_sub_args` and is deterministic. + :meth:`ChunkPlan.iter_sub_args` and is deterministic, so a + partial completion (sparse indices) resumes correctly. Returns ------- df : pandas.DataFrame Combined data from every successful sub-request. - response : httpx.Response - Aggregated response (canonical URL, last page's headers, - cumulative elapsed time). + response + The finalized aggregate — a raw :class:`httpx.Response` + (canonical URL, most-recently-completed sub-request's headers, + cumulative elapsed time) by default, or whatever + :attr:`finalize` produces (e.g. ``BaseMetadata`` for the OGC + getters). Raises ------ ChunkInterrupted - On a mid-stream transient failure - (:class:`QuotaExhausted` for 429, - :class:`ServiceInterrupted` for 5xx). The resumable handle - is on ``exc.call`` — wait for the underlying condition to - clear and call ``exc.call.resume()`` again. + On a mid-stream transient failure — 429, 5xx, or a bare + transport error: :class:`QuotaExhausted` for 429, + :class:`ServiceInterrupted` for the rest. The resumable + handle is on ``exc.call`` — wait for the underlying + condition to clear and call ``exc.call.resume()`` again. """ - with httpx.Client(**HTTPX_DEFAULTS) as client, _publish_client(client): - reporter = _progress.current() - if reporter is not None: - reporter.set_chunks(self.plan.total) - for i, sub_args in enumerate(self.plan.iter_sub_args()): - if i in self._chunks: - continue - if reporter is not None: - reporter.start_chunk(i + 1) - self._issue(i, sub_args) - ordered = self._ordered_chunks() - frames = [frame for frame, _ in ordered] - responses = [resp for _, resp in ordered] - return ( - _combine_chunk_frames(frames), - _combine_chunk_responses(responses, self.plan.canonical_url), - ) - - def _issue(self, index: int, sub_args: dict[str, Any]) -> None: + concurrency = ConcurrencyPolicy.from_env().max_connections + with start_blocking_portal() as portal: + return portal.call(functools.partial(self._run, concurrency)) + + async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: """ - Issue one sub-request and record its ``(frame, response)`` pair - under ``index``. - - On failure, classify the exception and either wrap it as a - resumable :class:`ChunkInterrupted` carrying this call, or - re-raise it unchanged to preserve its type. Catches - ``RuntimeError`` (the layer's typed contract: - :class:`RateLimited`, :class:`ServiceUnavailable`, or the - mid-pagination wrapper), :class:`httpx.HTTPError` - (transport-level failures like ``ConnectError`` / - ``TimeoutException``), and :class:`httpx.InvalidURL` (which - inherits directly from ``Exception``, not ``HTTPError``); all - three feed :func:`_classify_chunk_error`. + Gather every pending sub-request over one shared + :class:`httpx.AsyncClient` and return the combined, finalized result. + + Pending sub-requests (:meth:`_pending`) fan out under + ``asyncio.gather`` with ``return_exceptions=True`` so completed + sub-requests survive a sibling's transient failure. On a + recognized transient (:class:`RateLimited`, :class:`ServiceUnavailable`, + or a bare ``httpx.HTTPError`` / ``httpx.InvalidURL``) a + :class:`ChunkInterrupted` subclass is raised carrying ``self`` on + ``.call``; ``exc.call.resume()`` then re-issues only the unfinished + indices through this same runner. + + Concurrency is bounded by the client's connection pool — + ``httpx.Limits(max_connections=N, max_keepalive_connections=N)`` + where ``N = max_concurrent`` (``None`` for unbounded). The gather + dispatches *every* pending sub-request and the pool throttles, so + ``N=1`` is just a single-connection gather (one request at a time) + and ``total <= 1`` is just a one-element gather. + The shared client is published on :data:`_chunked_client` so + the paginated-loop helpers reuse its connection pool. + + Parameters + ---------- + max_concurrent : int or None + Maximum simultaneous connections (the pool cap). ``None`` + disables the cap. + + Returns + ------- + df : pandas.DataFrame + Combined data from every sub-request. + response + The finalized aggregate — a raw :class:`httpx.Response` + (canonical URL, most-recently-completed sub-request's headers, + cumulative elapsed time) by default, or whatever + :attr:`finalize` produces (e.g. ``BaseMetadata`` for OGC getters). + + Raises + ------ + ChunkInterrupted + On a transient sub-request failure. ``.call`` is ``self``, + holding the sparse completed sub-requests; ``.call.resume()`` + re-issues the unfinished ones. """ - try: - self._chunks[index] = self.fetch_once(sub_args) - except (RuntimeError, httpx.HTTPError, httpx.InvalidURL) as exc: - classification = _classify_chunk_error(exc) - if classification is None: - raise - interrupted_class, retry_after = classification - raise interrupted_class( - completed_chunks=len(self._chunks), - total_chunks=self.plan.total, - call=self, - retry_after=retry_after, - cause=exc, - ) from exc + # ``httpx.Limits()`` defaults to ``max_connections=100`` — at higher + # concurrency the pool would silently bottleneck the fan-out behind + # that cap. Set it to the resolved concurrency so the pool *is* the + # throttle (``None`` for truly unbounded). + limits = httpx.Limits( + max_connections=max_concurrent, max_keepalive_connections=max_concurrent + ) + + async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: + with _publish(client): + reporter = _progress.current() + if reporter is not None: + reporter.set_chunks(self.plan.total) + + async def track( + index: int, args: dict[str, Any] + ) -> tuple[pd.DataFrame, httpx.Response]: + """One sub-request (with retry) + result-store + progress tick.""" + result = await _retry(lambda: self.fetch(args), self.retry_policy) + self._chunks[index] = result + if reporter is not None: + # Chunks finish out of order under gather, so tick the + # completed *count* rather than a positional index. + reporter.start_chunk(self.completed_chunks) + return result + + # Dispatch every pending sub-request concurrently; the + # connection pool (``limits``) is the only throttle. + # ``return_exceptions`` keeps completed pairs after a sibling + # fails, so partial state stays recoverable via :meth:`resume`. + # Failure precedence, in order: + # 1. Cancellation / interrupt signals (CancelledError, + # KeyboardInterrupt, SystemExit — non-Exception) propagate + # unmodified; wrapping them as a transient would swallow + # the user's stop signal. + # 2. A non-transient failure (a real bug — unrecognized by + # ``wrap_failure``) surfaces raw, so it isn't masked behind + # a resumable handle for a transient sibling that landed + # later. + # 3. Only when every failure is a recognized transient do we + # raise the first as a resumable ``ChunkInterrupted``. + results = await asyncio.gather( + *(track(index, args) for index, args in self._pending()), + return_exceptions=True, + ) + failures = [r for r in results if isinstance(r, BaseException)] + for exc in failures: + if not isinstance(exc, Exception): + raise exc + first_transient: tuple[ChunkInterrupted, BaseException] | None = None + for exc in failures: + interrupted = self.wrap_failure(exc) + if interrupted is None: + raise exc + if first_transient is None: + first_transient = (interrupted, exc) + if first_transient is not None: + interrupted, exc = first_transient + raise interrupted from exc + + return self.finalize(*self._combine_raw()) def multi_value_chunked( *, build_request: Callable[..., httpx.Request], url_limit: int | None = None, -) -> Callable[[_FetchOnce], _FetchOnce]: +) -> Callable[[_Fetch], Callable[..., tuple[pd.DataFrame, Any]]]: """ - Decorate a fetch function to transparently chunk over-budget requests. + Decorate an async fetcher to transparently chunk over-budget requests. - Splits multi-value list params and cql-text filters across - sub-requests so each fits the URL byte limit. Builds a - :class:`ChunkPlan` and runs it: passthrough requests are a trivial - single-step plan, so the decorated function has one code path - either way. + Returns a callable that builds a :class:`ChunkPlan` from ``args``, + constructs a :class:`ChunkedCall` over the decorated + ``async def fetch(args) -> (df, response)``, and drives it to + completion via :meth:`ChunkedCall.resume`. The plan splits multi-value + list params and the cql-text filter so each sub-request URL fits the + byte limit; an already-fitting request is a one-step plan. See the + module docstring for the concurrency model. Parameters ---------- @@ -1176,18 +1515,18 @@ def multi_value_chunked( Returns ------- Callable - A decorator that wraps a ``fetch_once(args) -> (df, response)`` - callable into one that accepts the same shape but executes the - underlying plan transparently. + A *synchronous* wrapper ``wrapper(args, *, finalize=...) -> + (df, response)`` that executes the underlying plan transparently + over the decorated async fetcher. Raises ------ RequestTooLarge If no plan can fit ``url_limit``. ChunkInterrupted - On a mid-execution 429 (:class:`QuotaExhausted`) or 5xx - (:class:`ServiceInterrupted`). See :class:`ChunkedCall` for - the resume semantics. + On a mid-execution transient — 429, 5xx, or a bare transport + error: :class:`QuotaExhausted` for 429, :class:`ServiceInterrupted` + for the rest. See :class:`ChunkedCall` for the resume semantics. See Also -------- @@ -1195,14 +1534,20 @@ def multi_value_chunked( ChunkedCall : Per-sub-request execution and resume semantics. """ - def decorator(fetch_once: _FetchOnce) -> _FetchOnce: - @functools.wraps(fetch_once) + def decorator(fetch: _Fetch) -> Callable[..., tuple[pd.DataFrame, Any]]: + @functools.wraps(fetch) def wrapper( args: dict[str, Any], - ) -> tuple[pd.DataFrame, httpx.Response]: + *, + finalize: _Finalize = _passthrough_result, + ) -> tuple[pd.DataFrame, Any]: limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit plan = ChunkPlan(args, build_request, limit) - return plan.execute(fetch_once) + retry_policy = RetryPolicy.from_env() + # The connection-pool cap is resolved inside ``resume()`` from + # ``API_USGS_CONCURRENT``; ``1`` is a single-connection gather, + # ``total <= 1`` a one-element gather — no special branch. + return plan.execute(fetch, retry_policy, finalize) return wrapper diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 66ed1723..80adbba5 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -4,25 +4,29 @@ import functools import json import logging -import os +import numbers import re from collections.abc import ( + AsyncIterator, + Awaitable, Callable, Iterable, Iterator, Mapping, ) -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager +from contextvars import ContextVar from datetime import datetime, timedelta from typing import Any, TypeVar, get_args from zoneinfo import ZoneInfo import httpx import pandas as pd +from anyio.from_thread import start_blocking_portal from dataretrieval import __version__ from dataretrieval.utils import HTTPX_DEFAULTS, BaseMetadata -from dataretrieval.waterdata import _progress, chunking +from dataretrieval.waterdata import _config, _progress, chunking from dataretrieval.waterdata.chunking import ( _QUOTA_HEADER, RateLimited, @@ -335,8 +339,10 @@ def _default_headers(): ------- dict A dictionary containing default headers including 'Accept-Encoding', - 'Accept', 'User-Agent', and 'lang'. If the environment variable - 'API_USGS_PAT' is set, its value is included as the 'X-Api-Key' header. + 'Accept', 'User-Agent', and 'lang'. If an API token is configured + (via ``$API_USGS_PAT`` or any other config layer — see + :mod:`dataretrieval.waterdata._config`), it's included as the + ``X-Api-Key`` header. """ headers = { "Accept-Encoding": "compress, gzip", @@ -344,7 +350,7 @@ def _default_headers(): "User-Agent": f"python-dataretrieval/{__version__}", "lang": "en-US", } - token = os.getenv("API_USGS_PAT") + token = _config.current().api_token if token: headers["X-Api-Key"] = token return headers @@ -524,8 +530,9 @@ def _paginated_failure_message(pages_collected: int, cause: BaseException) -> st Returns ------- str - A message suitable for the ``RuntimeError`` that ``_walk_pages`` - and ``get_stats_data`` raise from the original exception. + A message suitable for the ``RuntimeError`` that + ``_walk_pages`` and ``get_stats_data`` raise from the + original exception. """ cause_str = str(cause).removesuffix(".") # Some ``httpx`` exceptions (e.g. ``TimeoutException()`` with no args) @@ -799,31 +806,33 @@ def _get_resp_data( return df -@contextmanager -def _client_for(client: httpx.Client | None) -> Iterator[httpx.Client]: +@asynccontextmanager +async def _client_for( + client: httpx.AsyncClient | None, +) -> AsyncIterator[httpx.AsyncClient]: """ - Yield a usable client, picking the best available source. + Yield a usable async client, picking the best available source. Resolution order: 1. ``client`` if the caller supplied one (borrowed; not closed here — the caller owns its lifecycle). - 2. The chunker's shared client if we're inside a - ``ChunkedCall.resume()`` block (per - :func:`chunking.get_active_client`). Borrowed; - ``ChunkedCall.resume`` closes it on exit. - 3. A fresh short-lived ``httpx.Client`` opened here and closed + 2. The chunker's shared async client if we're inside a + :class:`~dataretrieval.waterdata.chunking.ChunkedCall` run (per + :func:`chunking.get_active_client`). Borrowed; the chunker + closes it on exit. + 3. A fresh short-lived ``httpx.AsyncClient`` opened here and closed on context exit. Parameters ---------- - client : httpx.Client or None + client : httpx.AsyncClient or None A caller-owned client to borrow, or ``None`` to defer to the chunker's shared client or a temporary one. Yields ------ - httpx.Client + httpx.AsyncClient The chosen client. """ if client is not None: @@ -833,7 +842,7 @@ def _client_for(client: httpx.Client | None) -> Iterator[httpx.Client]: if shared is not None: yield shared return - with httpx.Client(**HTTPX_DEFAULTS) as new: + async with httpx.AsyncClient(**HTTPX_DEFAULTS) as new: yield new @@ -879,23 +888,46 @@ def _aggregate_paginated_response( _Cursor = TypeVar("_Cursor") +# Optional cap on the total rows a single paginated call accumulates before it +# stops following ``next`` links. ``None`` (the default the data getters use) +# means "no cap — fetch the whole series". Set via :func:`_row_cap` so the deep +# ``_paginate`` loop can honor it without threading the value through the +# generic chunker; this mirrors the ``_progress`` ambient-reporter pattern. +_row_cap_var: ContextVar[int | None] = ContextVar("waterdata_row_cap", default=None) + -def _paginate( +@contextmanager +def _row_cap(max_rows: int | None) -> Iterator[None]: + """Cap the rows any :func:`_paginate` under this context will + accumulate (``None`` = uncapped). Used by :func:`get_reference_table` + to preview large tables without downloading every page.""" + token = _row_cap_var.set(max_rows) + try: + yield + finally: + _row_cap_var.reset(token) + + +async def _paginate( initial_req: httpx.Request, *, parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], - follow_up: Callable[[_Cursor, httpx.Client], httpx.Response], - client: httpx.Client | None = None, + follow_up: Callable[[_Cursor, httpx.AsyncClient], Awaitable[httpx.Response]], + client: httpx.AsyncClient | None = None, ) -> tuple[pd.DataFrame, httpx.Response]: """ - Drive a paginated request to completion. - - Common shape behind :func:`_walk_pages` and :func:`get_stats_data`: - send the initial request, then loop calling ``follow_up`` until - ``parse_response`` reports a ``None`` cursor, accumulating frames - and elapsed time. Any mid-pagination failure raises - ``RuntimeError`` wrapping the cause — the API exposes no resume - cursor, so the caller's only recovery is to retry the whole call. + Drive a paginated request to completion over an + :class:`httpx.AsyncClient`. + + The common shape behind :func:`_walk_pages` and + :func:`get_stats_data`: send the initial request, then loop calling + ``follow_up`` until ``parse_response`` reports a ``None`` cursor, + accumulating frames and elapsed time. Any mid-pagination failure + raises ``RuntimeError`` wrapping the cause — the API exposes no + resume cursor, so the caller's only recovery is to retry the whole + call. Issuing HTTP asynchronously lets the multiple sub-requests of a + chunked call run concurrently under + :meth:`~dataretrieval.waterdata.chunking.ChunkedCall._run`. Parameters ---------- @@ -906,9 +938,9 @@ def _paginate( DataFrame and the cursor (URL, token, …) used to drive ``follow_up`` for the next page; ``None`` terminates the loop. follow_up : callable - ``(cursor, client) -> httpx.Response``. Builds and sends - the next-page request. - client : httpx.Client, optional + ``(cursor, client) -> Awaitable[httpx.Response]``. Builds and + sends the next-page request. + client : httpx.AsyncClient, optional Caller-borrowed client. ``None`` (default) means use the chunker's shared client (if inside a chunked call) or open a temporary one. @@ -919,10 +951,10 @@ def _paginate( Concatenation of every page's parsed frame. response : httpx.Response A shallow copy of the first-page response, with ``.headers`` - rebuilt as a fresh ``httpx.Headers`` reflecting the last - page and ``.elapsed`` set to cumulative wall-clock. The - canonical URL is preserved from the first page. The original - first-page response is not mutated. + rebuilt as a fresh ``httpx.Headers`` reflecting the last page and + ``.elapsed`` set to cumulative wall-clock. The canonical URL is + preserved from the first page. The original first-page response + is not mutated. Raises ------ @@ -943,12 +975,9 @@ def _paginate( """ logger.debug("Requesting: %s", initial_req.url) reporter = _progress.current() - with _client_for(client) as client: - resp = client.send(initial_req) + async with _client_for(client) as sess: + resp = await sess.send(initial_req) _raise_for_non_200(resp) - # Keep the original-request response as the "canonical" one for - # ``md.url`` reproducibility; ``.headers`` and ``.elapsed`` get - # overwritten with latest/cumulative values below. initial_response = resp total_elapsed = _safe_elapsed(resp) @@ -957,23 +986,29 @@ def _paginate( except Exception as e: # noqa: BLE001 # Initial-page parse failures (malformed JSON, missing # ``features``, schema drift) get the same wrapped-message - # treatment as follow-up failures so callers see a - # consistent diagnostic regardless of which page broke. + # treatment as follow-up failures so callers see a consistent + # diagnostic regardless of which page broke. logger.warning("Initial response parse failed.") raise RuntimeError(_paginated_failure_message(0, e)) from e dfs = [df] + # Stop following ``next`` links once the optional row cap is reached + # (see :func:`_row_cap`); ``None`` means uncapped. The concatenation + # is sliced to the cap below so a final over-budget page can't exceed it. + cap = _row_cap_var.get() + nrows = len(df) if reporter is not None: reporter.set_rate_remaining( resp.headers.get(_QUOTA_HEADER), limit=resp.headers.get("x-ratelimit-limit"), ) reporter.add_page(rows=len(df)) - while cursor is not None: + while cursor is not None and (cap is None or nrows < cap): try: - resp = follow_up(cursor, client) + resp = await follow_up(cursor, sess) _raise_for_non_200(resp) df, cursor = parse_response(resp) dfs.append(df) + nrows += len(df) total_elapsed += _safe_elapsed(resp) if reporter is not None: reporter.set_rate_remaining( @@ -995,7 +1030,10 @@ def _paginate( final_response = _aggregate_paginated_response( initial_response, resp, total_elapsed ) - return pd.concat(dfs, ignore_index=True), final_response + result = pd.concat(dfs, ignore_index=True) + if cap is not None: + result = result.head(cap) + return result, final_response def _ogc_parse_response( @@ -1003,9 +1041,10 @@ def _ogc_parse_response( ) -> tuple[pd.DataFrame, str | None]: """Parse one OGC API page: extract the DataFrame and the next-page URL. - Coerces falsy cursors (empty href, etc.) to ``None`` so the - paginate loop's ``while cursor is not None`` terminates instead - of spinning on a meaningless value. + The parse strategy :func:`_walk_pages` hands to + :func:`_paginate`. Coerces falsy cursors (empty href, etc.) to + ``None`` so the paginate loop's ``while cursor is not None`` + terminates instead of spinning on a meaningless value. """ body = resp.json() return ( @@ -1014,19 +1053,19 @@ def _ogc_parse_response( ) -def _walk_pages( +async def _walk_pages( geopd: bool, req: httpx.Request, - client: httpx.Client | None = None, + client: httpx.AsyncClient | None = None, ) -> tuple[pd.DataFrame, httpx.Response]: """ - Iterate through paginated OGC API responses and aggregate into one - DataFrame. + Iterate paginated OGC API responses asynchronously and aggregate + them into one DataFrame. - Thin wrapper that hands off to :func:`_paginate` with OGC-specific - strategies: pages are parsed via :func:`_get_resp_data` and the - next-page cursor is the URL from the response's ``links`` array - (per :func:`_next_req_url`). + Thin wrapper that hands off to :func:`_paginate` with + OGC-specific strategies: pages are parsed via :func:`_get_resp_data` + (through :func:`_ogc_parse_response`) and the next-page cursor is the + URL from the response's ``links`` array (per :func:`_next_req_url`). Parameters ---------- @@ -1034,7 +1073,7 @@ def _walk_pages( Whether geopandas is installed (drives geometry handling). req : httpx.Request The initial HTTP request to send. - client : httpx.Client, optional + client : httpx.AsyncClient, optional Caller-borrowed client; ``None`` defers client management to :func:`_paginate`. @@ -1058,10 +1097,10 @@ def _walk_pages( headers = req.headers content = req.content if method == "POST" else None - def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: - return client.request(method, cursor, headers=headers, content=content) + async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: + return await sess.request(method, cursor, headers=headers, content=content) - return _paginate( + return await _paginate( req, parse_response=functools.partial(_ogc_parse_response, geopd=geopd), follow_up=follow_up, @@ -1233,8 +1272,50 @@ def _sort_rows(df: pd.DataFrame) -> pd.DataFrame: return df +def _finalize_ogc( + frame: pd.DataFrame, + response: httpx.Response, + *, + properties: list[str] | None, + output_id: str, + convert_type: bool, + service: str, + max_rows: int | None = None, +) -> tuple[pd.DataFrame, BaseMetadata]: + """Shape a combined OGC result into the user-facing ``(df, md)``. + + The single home for the OGC getters' result shaping: empties + normalized, types coerced (when ``convert_type``), the wire ``id`` + renamed and columns ordered, rows sorted, optionally truncated to + ``max_rows``, and the response wrapped as + :class:`~dataretrieval.utils.BaseMetadata`. + + Injected into the chunker as its ``finalize`` hook (see + :data:`~dataretrieval.waterdata.chunking._Finalize`) so the + un-interrupted return *and* a resumed ``ChunkInterrupted.call.resume()`` + produce the same shape — closing the gap where resume used to hand back + the chunker's raw frame and bare ``httpx.Response``. + + ``max_rows`` is applied here (after dedup/sort, on the *combined* frame) + rather than only per-sub-request, so a chunked call's total is bounded + to exactly ``max_rows`` and a resumed call honors the cap too — the + per-``_paginate`` ``_row_cap`` is only an early-stop download bound. + """ + frame = _deal_with_empty(frame, properties, service) + if convert_type: + frame = _type_cols(frame) + frame = _arrange_cols(frame, properties, output_id) + frame = _sort_rows(frame) + if max_rows is not None: + frame = frame.head(max_rows) + return frame, BaseMetadata(response) + + def get_ogc_data( - args: dict[str, Any], output_id: str, service: str + args: dict[str, Any], + output_id: str, + service: str, + max_rows: int | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """ Retrieves OGC (Open Geospatial Consortium) data from a specified @@ -1253,6 +1334,11 @@ def get_ogc_data( service : str The OGC API collection name (e.g., ``"daily"``, ``"monitoring-locations"``, ``"continuous"``). + max_rows : int, optional + Stop paginating once this many rows have been collected and + truncate the result to exactly ``max_rows``. ``None`` (default) + fetches the full result. Intended for cheap previews of large, + un-chunked tables (e.g. :func:`get_reference_table`). Returns ------- @@ -1267,6 +1353,19 @@ def get_ogc_data( - Handles optional arguments such as `convert_type`. - Applies column cleanup and reordering based on service and properties. """ + # Enforce a genuine positive integer: a float (even ``10.0``) or ``bool`` + # would pass a bare ``< 1`` check and then crash deep in + # ``pd.DataFrame.head`` with an opaque ``TypeError`` after HTTP I/O has + # already fired. ``numbers.Integral`` (not ``int``) so numpy integers — + # e.g. ``max_rows`` derived from a numpy/pandas computation — are accepted; + # ``bool`` is an ``Integral`` subtype, so exclude it explicitly. + if max_rows is not None and ( + not isinstance(max_rows, numbers.Integral) + or isinstance(max_rows, bool) + or max_rows < 1 + ): + raise ValueError(f"max_rows must be a positive integer (got {max_rows!r}).") + args = args.copy() args["service"] = service args = _switch_arg_id(args, id_name=output_id, service=service) @@ -1279,34 +1378,45 @@ def get_ogc_data( convert_type = args.pop("convert_type", False) args = {k: v for k, v in args.items() if v is not None} - with _progress.progress_context(service=service): - return_list, response = _fetch_once(args) - return_list = _deal_with_empty(return_list, properties, service) - if convert_type: - return_list = _type_cols(return_list) - return_list = _arrange_cols(return_list, properties, output_id) - return_list = _sort_rows(return_list) - - return return_list, BaseMetadata(response) + # Post-processing is injected into the chunker rather than applied here, + # so it runs on *every* exit: the normal return AND a later + # ``exc.call.resume()`` after a ChunkInterrupted (which never re-enters + # this function). ``_finalize_ogc`` is the single source of result shape; + # it also applies ``max_rows`` to the *combined* frame so the cap is the + # exact total even when the plan chunks or the call is resumed, while + # ``_row_cap`` below only early-stops each sub-request's pagination. + finalize = functools.partial( + _finalize_ogc, + properties=properties, + output_id=output_id, + convert_type=convert_type, + service=service, + max_rows=max_rows, + ) + with _progress.progress_context(service=service), _row_cap(max_rows): + return _fetch_once(args, finalize=finalize) -@chunking.multi_value_chunked( - build_request=_construct_api_requests, -) -def _fetch_once( +@chunking.multi_value_chunked(build_request=_construct_api_requests) +async def _fetch_once( args: dict[str, Any], ) -> tuple[pd.DataFrame, httpx.Response]: - """Send one prepared-args OGC request; return the frame + response. + """Send one prepared-args OGC request asynchronously; return the + frame + response. ``@chunking.multi_value_chunked`` models every multi-value list parameter and the cql-text filter as a chunkable axis, greedy-halves the biggest chunk across all axes until each sub-request URL fits, and iterates the cartesian product. With no chunkable inputs the - decorator passes args through unchanged. The return shape - is ``(frame, response)``. + decorator passes args through unchanged. The decorator gathers every + sub-request over one shared :class:`httpx.AsyncClient` (concurrency + bounded by the connection pool, sized from ``API_USGS_CONCURRENT``) + and returns a *synchronous* wrapper, so ``get_ogc_data`` keeps calling + ``_fetch_once(args, finalize=...)`` synchronously. The return shape is + ``(frame, response)``. """ req = _construct_api_requests(**args) - return _walk_pages(geopd=GEOPANDAS, req=req) + return await _walk_pages(geopd=GEOPANDAS, req=req) def _handle_stats_nesting( @@ -1469,7 +1579,7 @@ def get_stats_data( args: dict[str, Any], service: str, expand_percentiles: bool, - client: httpx.Client | None = None, + client: httpx.AsyncClient | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """ Retrieves statistical data from a specified endpoint and returns it @@ -1479,6 +1589,12 @@ def get_stats_data( handles pagination, processes results, and formats output according to the specified parameters. + The stats path doesn't go through ``multi_value_chunked`` (its query + shape has no chunkable list axes), so it drives :func:`_paginate` + directly through an ``anyio`` blocking portal. The portal runs the + pagination loop in a short-lived worker thread, so this works whether + or not the caller is already inside an event loop. + Parameters ---------- args : Dict[str, Any] @@ -1491,6 +1607,9 @@ def get_stats_data( each percentile gets its own row in the returned dataframe. If True and user requests a computation_type other than percentiles, a percentile column is still returned. + client : httpx.AsyncClient, optional + Caller-borrowed async client. ``None`` (default) opens a + temporary one inside the portal. Primarily a test seam. Returns ------- @@ -1517,23 +1636,29 @@ def parse_response(resp: httpx.Response) -> tuple[pd.DataFrame, str | None]: # protects against any "" sentinel a future schema might use. return _handle_stats_nesting(body, geopd=GEOPANDAS), body.get("next") or None - def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: + async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: # Build a fresh params dict per page so the caller's ``args`` # is never mutated. - return client.request( + return await sess.request( method, url=url, params={**args, "next_token": cursor}, headers=headers ) - # The stats path doesn't go through ``multi_value_chunked``, so it opens - # its own progress context; ``_paginate`` reports pages/rate-limit into it. - with _progress.progress_context(service=service): - df, response = _paginate( + async def _run() -> tuple[pd.DataFrame, httpx.Response]: + return await _paginate( req, parse_response=parse_response, follow_up=follow_up, client=client, ) + # The stats path opens its own progress context (it doesn't go through + # ``multi_value_chunked``); ``_paginate`` reports pages/rate-limit + # into it. The portal copies the calling context, so the reporter still + # reaches the worker thread. + with _progress.progress_context(service=service): + with start_blocking_portal() as portal: + df, response = portal.call(_run) + if expand_percentiles: df = _expand_percentiles(df) return df, BaseMetadata(response) diff --git a/tests/conftest.py b/tests/conftest.py index afbdfec2..6958c480 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ """ Test scaffolding for the dataretrieval test suite. -Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and -unmatched real requests don't fail the suite (matches the historical -``requests-mock``-style permissiveness the test code was written -against, and keeps mocked-URL setup terse). +* Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and + unmatched requests don't fail the suite (keeps mocked-URL setup terse). +* Pins ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` for every + test by default, so sub-request dispatch is deterministic and a single + transient surfaces immediately (no backoff). Concurrency and retry + tests opt in by re-setting the env vars inside their body via + ``monkeypatch.setenv``. """ from __future__ import annotations @@ -13,9 +16,8 @@ def pytest_collection_modifyitems(config, items): - """Apply relaxed ``pytest-httpx`` strict-mode settings to every - test in the suite — matches the permissive defaults the historical - tests were written against.""" + """Apply relaxed ``pytest-httpx`` strict-mode settings to every test + so unconsumed mocks and unmatched requests don't fail the suite.""" marker = pytest.mark.httpx_mock( assert_all_responses_were_requested=False, assert_all_requests_were_expected=False, @@ -30,3 +32,18 @@ def non_mocked_hosts() -> list[str]: """No hosts are exempted from mocking; every HTTP call must hit a mock registered through the ``httpx_mock`` fixture.""" return [] + + +@pytest.fixture(autouse=True) +def _pin_chunker_env(monkeypatch): + """Pin every test to one connection and no retries. + + Production defaults ``API_USGS_CONCURRENT`` to 16 and + ``API_USGS_RETRIES`` to 4. Pinning ``API_USGS_CONCURRENT=1`` keeps + sub-request dispatch deterministic for the mocked suite, and + ``API_USGS_RETRIES=0`` makes a single transient surface immediately + rather than be retried. Concurrency and retry tests opt in by + overriding the env inside their body. + """ + monkeypatch.setenv("API_USGS_CONCURRENT", "1") + monkeypatch.setenv("API_USGS_RETRIES", "0") diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 21b23757..1e2d242f 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -15,20 +15,28 @@ and then fail in production. """ +import asyncio +import concurrent.futures import datetime import sys +import warnings from unittest import mock from urllib.parse import quote_plus +import httpx +import numpy as np import pandas as pd import pytest if sys.version_info < (3, 10): pytest.skip("Skip entire module on Python < 3.10", allow_module_level=True) +from dataretrieval.waterdata import _config from dataretrieval.waterdata import chunking as _chunking +from dataretrieval.waterdata import utils as _utils from dataretrieval.waterdata.chunking import ( _LIST_SEP, + _NEVER_CHUNK, _OR_SEP, _QUOTA_HEADER, ChunkInterrupted, @@ -36,13 +44,41 @@ QuotaExhausted, RateLimited, RequestTooLarge, + RetryPolicy, ServiceInterrupted, ServiceUnavailable, _chunked_client, + _combine_chunk_frames, + _combine_chunk_responses, _extract_axes, + _request_bytes, + _retry, + _retryable, + _safe_request_bytes, multi_value_chunked, ) -from dataretrieval.waterdata.utils import _construct_api_requests +from dataretrieval.waterdata.utils import _DATE_RANGE_PARAMS, _construct_api_requests + + +def _aiozero(_d): + """An async no-op sleep — monkeypatched over ``asyncio.sleep`` (via + the chunking module's binding) so retry backoff doesn't wait in tests.""" + + async def _noop(): + return None + + return _noop() + + +def _recording_sleep(slept): + """An ``_aiozero`` variant that appends each requested delay to ``slept`` + before resolving — for tests that need to assert what would have been waited.""" + + def _record(delay): + slept.append(delay) + return _aiozero(delay) + + return _record class _FakeReq: @@ -87,9 +123,6 @@ def test_never_chunk_covers_all_date_range_params(): adding a new param to ``_DATE_RANGE_PARAMS`` without also adding it to ``_NEVER_CHUNK`` would silently let the chunker try to comma-join an interval string.""" - from dataretrieval.waterdata.chunking import _NEVER_CHUNK - from dataretrieval.waterdata.utils import _DATE_RANGE_PARAMS - missing = _DATE_RANGE_PARAMS - _NEVER_CHUNK assert not missing, ( f"_DATE_RANGE_PARAMS contains entries not in _NEVER_CHUNK: " @@ -228,7 +261,7 @@ def test_multi_value_chunked_passes_through_when_url_fits(): calls = [] @multi_value_chunked(build_request=_fake_build, url_limit=8000) - def fetch(args): + async def fetch(args): calls.append(args) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} @@ -247,7 +280,7 @@ def test_multi_value_chunked_emits_3d_cartesian_product(): calls = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): calls.append(tuple(tuple(args[k]) for k in ("sites", "pcodes", "stats"))) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} @@ -296,7 +329,7 @@ def test_multi_value_chunked_lazy_url_limit(monkeypatch): calls = [] @multi_value_chunked(build_request=_fake_build) # url_limit defaults to None - def fetch(args): + async def fetch(args): calls.append(args) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} @@ -310,19 +343,19 @@ def fetch(args): def test_chunked_session_shared_across_sub_requests(): """Every sub-request of one chunked call sees the same - ``httpx.Client`` on the ``_chunked_client`` ContextVar, so + ``httpx.AsyncClient`` on the ``_chunked_client`` ContextVar, so downstream paginated helpers (``_walk_pages``) can reuse the connection pool instead of handshaking fresh on each sub-request.""" sessions_seen = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): sessions_seen.append(_chunked_client.get()) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} ) - # Outside a chunked call: no session published. + # Outside a chunked call: no session published (in this thread/context). assert _chunked_client.get() is None fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) @@ -334,19 +367,22 @@ def fetch(args): assert all(s is not None for s in sessions_seen) # And it was the same object every time. assert len({id(s) for s in sessions_seen}) == 1 - # On exit the ContextVar is reset to its default. + # The portal's worker context is torn down on exit, so the calling + # thread's ContextVar still reads its default. assert _chunked_client.get() is None def test_chunked_session_isolated_per_resume(): """A follow-up ``resume`` after an interruption opens a fresh session — the previous one was closed when its ``resume`` returned. - The ContextVar is reset between calls so leakage can't carry + The ContextVar is reset between runs so leakage can't carry a closed session into the retry.""" state = {"i": 0, "blow_up": True} + sessions_seen = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): + sessions_seen.append(_chunked_client.get()) i = state["i"] state["i"] += 1 if i == 1 and state["blow_up"]: @@ -362,13 +398,23 @@ def fetch(args): with pytest.raises(QuotaExhausted) as excinfo: fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - # First resume's session is closed; ContextVar is reset. + # First run published a shared client to its sub-requests; the calling + # thread's ContextVar is unaffected (reads its default). assert _chunked_client.get() is None + first_run_sessions = list(sessions_seen) + assert first_run_sessions and all(s is not None for s in first_run_sessions) state["blow_up"] = False excinfo.value.call.resume() - # Second resume's session is also cleaned up. + # Second run's ContextVar is also reset in the calling thread. assert _chunked_client.get() is None + # The resume opened a FRESH client, distinct from the first run's, so no + # closed client leaks across runs. + resume_sessions = sessions_seen[len(first_run_sessions) :] + assert resume_sessions and all(s is not None for s in resume_sessions) + assert {id(s) for s in resume_sessions}.isdisjoint( + {id(s) for s in first_run_sessions} + ) def _quota_response(remaining: int | str | None) -> mock.Mock: @@ -385,7 +431,7 @@ def test_quota_exhausted_on_mid_call_429(): offset so callers can resume after the window resets.""" state = {"i": 0} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2: @@ -408,10 +454,12 @@ def fetch(args): decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) err = excinfo.value - assert err.completed_chunks == 2 # chunks 0 and 1 completed; 429 hit on i=2 + # Async fan-out: every non-failing sub-request completes (the gather + # runs all of them; only i==2 raises), so 4 of 5 complete. + assert err.completed_chunks == 4 # only the i==2 sub-request failed assert err.total_chunks == 5 assert err.partial_frame is not None - assert set(err.partial_frame["i"]) == {0, 1} + assert set(err.partial_frame["i"]) == {0, 1, 3, 4} def test_quota_exhausted_on_first_chunk_429_has_no_partial_response(): @@ -420,7 +468,7 @@ def test_quota_exhausted_on_first_chunk_429_has_no_partial_response(): is empty) so callers can branch on that to distinguish "abort before any data arrived" from "abort after partial collection".""" - def fetch(args): + async def fetch(args): raise RateLimited("429: Too many requests made.") decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) @@ -437,14 +485,18 @@ def test_quota_exhausted_resume_picks_up_where_429_stopped(): once the window resets, ``e.call.resume()`` re-issues only the sub-requests that hadn't completed and returns the full combined result. Chunks completed before the 429 are not re-fetched.""" - # The fake fetch 429s on the third call, then succeeds on every - # subsequent call. We track which sub-args have been issued so we - # can assert chunks 0/1 aren't re-fetched on resume. + # One sub-request (the chunk containing the failing site) 429s on the + # first gather, then succeeds once the window resets. Under the async + # fan-out every OTHER sub-request completes on the first gather, so + # resume re-issues only the single still-pending chunk. We track which + # sub-args have been issued to assert the completed chunks aren't + # re-fetched. fetched_sites: list[tuple[str, ...]] = [] + failing_site = "S3" * 10 rate_limited_once = {"fired": False} - def fetch(args): - if len(fetched_sites) == 2 and not rate_limited_once["fired"]: + async def fetch(args): + if failing_site in args["sites"] and not rate_limited_once["fired"]: rate_limited_once["fired"] = True raise RateLimited("429: Too many requests made.") site_tuple = tuple(args["sites"]) @@ -455,23 +507,24 @@ def fetch(args): ) decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10] + sites = ["S1" * 10, "S2" * 10, failing_site, "S4" * 10, "S5" * 10] - # First attempt: 429 on the third sub-request. + # First attempt: 429 on the chunk carrying the failing site; the other + # four sub-requests complete. with pytest.raises(QuotaExhausted) as excinfo: decorated({"sites": sites}) err = excinfo.value - assert err.completed_chunks == 2 + assert err.completed_chunks == 4 pre_resume_count = len(fetched_sites) - assert pre_resume_count == 2 # chunks 0 and 1 completed + assert pre_resume_count == 4 # every chunk but the failing one completed - # Resume: re-issues only the still-pending sub-requests. + # Resume: re-issues only the still-pending sub-request. df, _ = err.call.resume() - # Three more fetches happened on resume (chunks 2, 3, 4); chunks 0 - # and 1 were not re-fetched. - assert len(fetched_sites) - pre_resume_count == 3, ( - f"expected 3 new fetches on resume (chunks 2, 3, 4); got " + # Exactly one more fetch happened on resume (the chunk that 429'd); + # the four already-completed chunks were not re-fetched. + assert len(fetched_sites) - pre_resume_count == 1, ( + f"expected 1 new fetch on resume (the failing chunk); got " f"{len(fetched_sites) - pre_resume_count}" ) # Every original site appears in the combined frame exactly once. @@ -483,33 +536,33 @@ def test_quota_exhausted_resume_can_reraise_on_persistent_429(): ``call.resume()`` raises ``QuotaExhausted`` again — the ``ChunkedCall``'s in-flight state carries forward, so a subsequent resume after a longer wait still picks up cleanly.""" - state = {"attempts": 0} - - def fetch(args): - i = state["attempts"] - state["attempts"] += 1 - # First attempt 429s on chunk 2. Resume attempt 429s on what - # would be chunk 2 again (still the first un-completed - # sub-request). - if i == 2 or i == 3: + # Key the failure on the chunk's CONTENT (one persistently-429ing + # site) rather than a global call counter: under the async fan-out + # every other sub-request completes, and the same still-pending + # sub-request re-fails on resume — so the completed count is stable. + failing_site = "S3" * 10 + + async def fetch(args): + if failing_site in args["sites"]: raise RateLimited("429: Too many requests made.") return ( - pd.DataFrame({"i": [i], "sites": list(args["sites"])}), + pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500), ) decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10] + sites = ["S1" * 10, "S2" * 10, failing_site, "S4" * 10, "S5" * 10] with pytest.raises(QuotaExhausted) as first: decorated({"sites": sites}) with pytest.raises(QuotaExhausted) as second: first.value.call.resume() - # Both exceptions report the same completed_chunks count — the - # second resume didn't make progress (it 429'd on the same chunk). - assert first.value.completed_chunks == 2 - assert second.value.completed_chunks == 2 + # Both exceptions report the same completed_chunks count — every + # sub-request but the persistently-429ing one completed on the first + # gather, and the resume re-issued only that one, which 429'd again. + assert first.value.completed_chunks == 4 + assert second.value.completed_chunks == 4 def test_resume_produces_dataset_identical_to_uninterrupted_run(): @@ -527,7 +580,7 @@ def make_fetch(rate_limit_at_call: int | None): keyed by the sub-args's sites.""" state = {"calls": 0, "tripped": False} - def fetch(args): + async def fetch(args): state["calls"] += 1 if state["calls"] == rate_limit_at_call and not state["tripped"]: state["tripped"] = True @@ -585,7 +638,7 @@ def test_chunker_passes_through_non_429_runtime_error(): it must propagate unchanged so callers see the real cause.""" state = {"i": 0} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2: @@ -608,7 +661,7 @@ def test_chunker_wraps_service_unavailable_as_resumable(): ``.call.resume()`` resumes only the still-pending sub-requests.""" state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: @@ -627,7 +680,9 @@ def fetch(args): err = excinfo.value # Resumable: handle on .call with already-completed work preserved. assert err.call is not None - assert err.completed_chunks == 2 + # Async fan-out: only the i==2 sub-request fails; the gather completes + # the other four, so 4 of 5 are recorded before the failure surfaces. + assert err.completed_chunks == 4 assert err.total_chunks == 5 assert not err.call.partial_frame.empty # Upstream recovers; resuming completes the call. @@ -656,15 +711,13 @@ def test_connection_error_wrapped_as_service_interrupted(): and the user would lose the resumable handle to ``.call.resume()``. Verify ``ChunkedCall`` wraps it as ``ServiceInterrupted`` so partial progress is preserved.""" - import httpx as _httpx - state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: - raise _httpx.ConnectError("connection reset") + raise httpx.ConnectError("connection reset") return ( pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500), @@ -675,10 +728,11 @@ def fetch(args): decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) err = excinfo.value - assert err.completed_chunks == 2 + # Async fan-out: only the i==2 sub-request fails; the other four complete. + assert err.completed_chunks == 4 assert err.call is not None # The transport exception is on __cause__ so callers can drill in if needed. - assert isinstance(err.__cause__, _httpx.ConnectError) + assert isinstance(err.__cause__, httpx.ConnectError) # Resume after the upstream recovers. state["blow_up"] = False df, _ = err.call.resume() @@ -691,15 +745,13 @@ def test_invalid_url_wrapped_as_service_interrupted(): ``_classify_chunk_error`` an oversize follow-up URL escapes as raw ``InvalidURL`` and the user loses ``.call.resume()`` access to the partial state. Mirror the ConnectError test.""" - import httpx as _httpx - state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: - raise _httpx.InvalidURL("URL is too long: 65536 bytes > 65000") + raise httpx.InvalidURL("URL is too long: 65536 bytes > 65000") return ( pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500), @@ -710,9 +762,10 @@ def fetch(args): decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) err = excinfo.value - assert err.completed_chunks == 2 + # Async fan-out: only the i==2 sub-request fails; the other four complete. + assert err.completed_chunks == 4 assert err.call is not None - assert isinstance(err.__cause__, _httpx.InvalidURL) + assert isinstance(err.__cause__, httpx.InvalidURL) # The top-level message must surface the underlying cause text so # the user doesn't have to traverse ``__cause__`` to know what # actually failed (previously the message was generic "Service @@ -731,7 +784,7 @@ def test_service_interrupted_exposes_partial_frame_and_response(): crashed with AttributeError on 5xx.""" state = {"i": 0} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2: @@ -764,7 +817,7 @@ def test_partial_frame_snapshot_stable_across_resume(): a name that promises pre-resume state.""" state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: @@ -801,7 +854,7 @@ def test_partial_frame_snapshot_is_a_copy_when_single_chunk(): ``pd.concat`` (which already produces a fresh frame).""" state = {"i": 0, "blow_up": True} - def fetch(args): + async def fetch(args): i = state["i"] state["i"] += 1 if i == 1 and state["blow_up"]: @@ -811,12 +864,13 @@ def fetch(args): _quota_response(500), ) - # 4 sites at url_limit=240 → 2 sub-requests. The 429 fires on the - # SECOND sub-request, so the exception captures exactly ONE - # completed chunk — the path where _combine_chunk_frames aliases. + # 2 sites at url_limit=240 → 2 singleton sub-requests. The 429 fires + # on the SECOND sub-request and the gather completes the other, so the + # exception captures exactly ONE completed chunk — the path where + # _combine_chunk_frames aliases its single non-empty frame. decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) with pytest.raises(QuotaExhausted) as excinfo: - decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + decorated({"sites": ["S1" * 10, "S2" * 10]}) err = excinfo.value assert err.completed_chunks == 1 @@ -834,8 +888,6 @@ def test_combine_chunk_responses_returns_independent_headers(): hooks, metadata extensions) must not back-propagate into the underlying chunk response's headers, which still live on ``ChunkedCall._chunks``.""" - from dataretrieval.waterdata.chunking import _combine_chunk_responses - r0 = mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={"X-Foo": "0"}, url="u0" ) @@ -858,13 +910,6 @@ def test_paginate_terminates_on_empty_string_cursor(): coerce falsy non-None values to None so an empty-string next- cursor (a real-but-unusual end-of-stream sentinel some pagination APIs use) doesn't trap us in an infinite ``follow_up('')`` loop.""" - import datetime as _dt - from unittest import mock as _mock - - import httpx as _httpx - - from dataretrieval.waterdata import utils as _utils - # Synthesize an OGC response with numberReturned > 0 and a "next" # link whose href is an empty string — simulating a server-side # sentinel that ``_next_req_url`` reads as ``""``. @@ -873,23 +918,23 @@ def test_paginate_terminates_on_empty_string_cursor(): "features": [{"id": "1", "properties": {"val": "a"}}], "links": [{"rel": "next", "href": ""}], } - resp = _mock.MagicMock(spec=_httpx.Response) + resp = mock.MagicMock(spec=httpx.Response) resp.status_code = 200 resp.url = "https://example.com/items?limit=1" - resp.elapsed = _dt.timedelta(seconds=0.1) + resp.elapsed = datetime.timedelta(seconds=0.1) resp.headers = {} resp.json.return_value = body_with_empty_next - client = _mock.MagicMock(spec=_httpx.Client) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp - req = _mock.MagicMock(spec=_httpx.Request) + req = mock.MagicMock(spec=httpx.Request) req.method = "GET" req.headers = {} req.content = b"" req.url = "https://example.com/items?limit=1" - df, final = _utils._walk_pages(geopd=False, req=req, client=client) + df, final = asyncio.run(_utils._walk_pages(geopd=False, req=req, client=client)) # Single send + zero follow-ups: the loop terminated on the empty cursor. assert client.send.called @@ -902,10 +947,6 @@ def test_combine_chunk_frames_does_not_collapse_none_ids(): so a blanket dedup would collapse every id-less row into one — silent data loss. The function must dedupe only the id-bearing rows and preserve id-less rows verbatim.""" - import numpy as np - - from dataretrieval.waterdata.chunking import _combine_chunk_frames - # Frame A has real ids; frame B has feature-IDs of None for two # different rows that must both survive. df_a = pd.DataFrame({"id": ["x", "y"], "val": [1, 2]}) @@ -921,8 +962,6 @@ def test_combine_chunk_frames_still_dedupes_overlapping_ids(): """The original dedup contract — overlapping OR-clause partitions that produce duplicate-id rows across chunks must still collapse to one row — has to keep working when ids ARE present.""" - from dataretrieval.waterdata.chunking import _combine_chunk_frames - df_a = pd.DataFrame({"id": ["x", "y"], "val": [1, 2]}) df_b = pd.DataFrame({"id": ["y", "z"], "val": [2, 3]}) combined = _combine_chunk_frames([df_a, df_b]) @@ -936,7 +975,7 @@ def test_retry_after_surfaces_on_quota_exhausted(): can honor the server's hint instead of guessing a wait.""" state = {"i": 0} - def fetch(args): + async def fetch(args): state["i"] += 1 if state["i"] >= 3: try: @@ -975,10 +1014,6 @@ def test_request_bytes_sums_url_and_content(): the chunker just needs to size that single attribute alongside the URL. """ - import httpx - - from dataretrieval.waterdata.chunking import _request_bytes - # GET request with no body req = httpx.Request("GET", "https://x.example/ab") assert _request_bytes(req) == len("https://x.example/ab") @@ -996,9 +1031,6 @@ def test_safe_request_bytes_treats_invalid_url_as_overflow(): contract is that ``_safe_request_bytes`` returns ``url_limit + 1`` (a value strictly greater than the limit) when ``build_request`` raises ``InvalidURL``.""" - import httpx - - from dataretrieval.waterdata.chunking import _safe_request_bytes def build_request(**kwargs): raise httpx.InvalidURL("URL too long") @@ -1013,8 +1045,6 @@ def test_chunk_plan_handles_initial_url_overflow(): crash ``ChunkPlan.__init__``; the planner falls back to a worst-case sub-request URL for ``canonical_url`` and proceeds to halve the over-limit axes normally.""" - import httpx - real_build = _fake_build def overflowing_build(**kwargs): @@ -1043,7 +1073,7 @@ def test_multi_value_chunked_restores_canonical_url(): sub_urls: list[str] = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): + async def fetch(args): # Each sub-response carries the chunked sub_args's URL, so # without canonical restoration the first chunk's URL would # leak through to md.url. @@ -1167,8 +1197,6 @@ def test_combine_chunk_frames_all_empty_preserves_geo_type(): pytest.importorskip("geopandas") import geopandas as gpd - from dataretrieval.waterdata.chunking import _combine_chunk_frames - empty_gdfs = [gpd.GeoDataFrame() for _ in range(3)] combined = _combine_chunk_frames(empty_gdfs) assert isinstance(combined, gpd.GeoDataFrame), ( @@ -1181,8 +1209,6 @@ def test_combine_chunk_frames_single_frame_is_safe_to_mutate(): input on the single-chunk fast path — a caller mutating ``call.partial_frame`` (a live view) must not back-propagate into the underlying ``_chunks[0][0]`` frame.""" - from dataretrieval.waterdata.chunking import _combine_chunk_frames - chunk = pd.DataFrame({"id": ["A", "B"], "value": [1, 2]}) returned = _combine_chunk_frames([chunk]) returned["new_col"] = "x" @@ -1202,6 +1228,250 @@ def test_iter_sub_args_passthrough_yields_a_copy(): assert "new_key" not in plan.args +# --- async fan-out path ---------------------------------------------------- +# +# Every sub-request is gathered over one ``httpx.AsyncClient`` and +# concurrency is bounded purely by that client's connection pool, sized +# from ``API_USGS_CONCURRENT``. The conftest's ``_pin_chunker_env`` +# autouse pins ``API_USGS_CONCURRENT=1`` (a single connection) for the +# whole suite; each test below raises it so the gather can dispatch +# sub-requests under a wider pool. The decorated async fetcher is the +# SAME one used on both first-run and resume. No real ``httpx.AsyncClient`` +# round-trip occurs (the fakes return mock data), even though +# :meth:`ChunkedCall._run` opens one for pool management. + + +def _async_chunked_fetch(monkeypatch, fetch_async, *, max_concurrent=16): + """Decorate a deterministic chunkable async fetch with a wide-pool + gather forced on via ``API_USGS_CONCURRENT``.""" + monkeypatch.setenv("API_USGS_CONCURRENT", str(max_concurrent)) + return multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch_async) + + +def _atom_id(args): + """Build a deterministic id for a sub-args dict — used as the dedup key.""" + return ",".join(args["sites"]) if isinstance(args["sites"], list) else args["sites"] + + +def _ok_response(remaining=None): + headers = {} if remaining is None else {_QUOTA_HEADER: str(remaining)} + return mock.Mock(elapsed=datetime.timedelta(seconds=0.1), headers=headers) + + +def test_async_fan_out_emits_one_call_per_sub_request(monkeypatch): + """The fan-out hits every sub-args exactly once, dispatched + concurrently.""" + seen_args = [] + + async def fetch_async(args): + seen_args.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + # Planner halves the 4-site list, so 2 sub-args → 2 async calls. + assert len(seen_args) > 1 + # Every sub-args atom is union-recovered. + assert sorted({s for tup in seen_args for s in tup}) == sorted( + ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + ) + # Frames concat to one row per sub-request id, in deterministic order. + assert len(df) == len(seen_args) + + +def test_async_fan_out_aggregates_headers_from_latest_completion(monkeypatch): + """Aggregated headers reflect the most recently completed chunk. + + Completion order can differ from index order in parallel mode, so + rate-limit headers should come from whichever sub-request finished + last, not from the highest sub-args index. + """ + + async def fetch_async(args): + if "S1" * 10 in args["sites"]: + await asyncio.sleep(0.02) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=11) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=77) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + _, response = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert response.headers[_QUOTA_HEADER] == "11" + + +def test_async_fan_out_failure_yields_resumable_call(monkeypatch): + """A transient 5xx mid-fan-out raises ``ServiceInterrupted`` whose + ``.call`` is a ``ChunkedCall`` holding the completed sub-requests + in a sparse index map. ``exc.call.resume()`` re-issues only the + unfinished sub-requests — through the same async fetcher and the same + async runner, just on a fresh gather.""" + # One async fetcher serves both first-run and resume. On the first + # gather it lets exactly one sub-request succeed and fails the rest + # transiently; once ``blow_up`` is cleared the resume gather completes + # every still-pending sub-request. ``calls`` counts every invocation + # across both gathers so we can assert resume only re-issued the owed + # sub-requests. + state = {"first_success": False, "blow_up": True} + calls = {"n": 0} + + async def fetch_async(args): + calls["n"] += 1 + if state["blow_up"]: + # Let the first dispatched sub-request through, fail the rest. + if not state["first_success"]: + state["first_success"] = True + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response( + remaining=99 + ) + raise ServiceUnavailable("503: simulated") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + with pytest.raises(ServiceInterrupted) as exc_info: + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + interrupted = exc_info.value + assert interrupted.call is not None, "interruption must be resumable" + # Exactly one sub-request completed; the rest still owe. + assert interrupted.completed_chunks == 1 + assert interrupted.total_chunks > 1 + + # Resume re-issues only the missing sub-requests, via the same async + # runner the first run used. + state["blow_up"] = False + calls_before = calls["n"] + df, _ = interrupted.call.resume() + calls_on_resume = calls["n"] - calls_before + assert calls_on_resume == interrupted.total_chunks - 1 + # Final frame unions every sub-args. + assert len(df) == interrupted.total_chunks + + +def test_async_fan_out_resume_applies_finalize(monkeypatch): + """The ``finalize`` injected for a wide-pool call survives the + interruption (carried on the ``ChunkedCall`` through the anyio portal), + so ``exc.call.resume()`` still returns the finalized shape — guarding + the run -> resume -> finalize path. Partials stay raw (no finalize in + the exception ctor).""" + + def finalize(frame, response): + return frame.assign(finalized=True), ("MD", response) + + state = {"first_success": False, "blow_up": True} + + async def fetch_async(args): + if state["blow_up"]: + if not state["first_success"]: + state["first_success"] = True + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response( + remaining=99 + ) + raise ServiceUnavailable("503: simulated") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + with pytest.raises(ServiceInterrupted) as exc_info: + fetch({"sites": sites}, finalize=finalize) + + # Partial snapshot stays raw — building the exception must not finalize. + assert "finalized" not in exc_info.value.partial_frame.columns + # Resume applies the finalize carried on the ChunkedCall. + state["blow_up"] = False + df, md = exc_info.value.call.resume() + assert "finalized" in df.columns + assert md[0] == "MD" + + +def test_wide_concurrency_uses_async_fetcher_with_no_warning(monkeypatch): + """A wide ``API_USGS_CONCURRENT`` is honored directly by the single + async fetcher: every sub-request fans out across it and NO + ``UserWarning`` is emitted.""" + calls = [] + monkeypatch.setenv("API_USGS_CONCURRENT", "16") + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + async def fetch(args): + calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + with warnings.catch_warnings(): + warnings.simplefilter("error") # any UserWarning would fail the test + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + assert len(calls) > 1 # the gather fanned out across every sub-request + assert len(df) == len(calls) + + +def test_async_fan_out_runs_inside_running_event_loop(monkeypatch): + """The parallel fan-out works even when the caller is already inside a + running event loop (Jupyter / async apps): the anyio blocking portal + runs it in a worker thread, so it does not raise a nested + ``asyncio.run`` error.""" + monkeypatch.setenv("API_USGS_CONCURRENT", "16") + async_calls = [] + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + async def fetch(args): # the single async fetcher + async_calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + async def driver(): # call the sync getter from within a running loop + # The sync wrapper drives the async core through the anyio portal in + # a worker thread, so it works even inside a running event loop + # without raising a nested-``asyncio.run`` error. + return fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + df, _ = asyncio.run(driver()) + assert len(async_calls) > 1 # every sub-request ran on the async path + assert len(df) == len(async_calls) + + +def test_async_fan_out_cancellation_wins_over_transient_sibling(monkeypatch): + """``asyncio.CancelledError`` raised by any sub-request must + propagate unmodified, even when a sibling raises a recognized + transient (which would otherwise wrap as a resumable + :class:`ChunkInterrupted`). Cancellation is asyncio's abort + signal — letting a transient-classification path consume it + would silently swallow the user's stop request. + + ``fetch_async`` has no ``await`` in its body, so the gather schedules + the tasks in submission order and each runs synchronously to its + raise — making ``call_count`` deterministic: 1 = first chunk + (success), 2 = second chunk (transient), 3 = third chunk (cancel). + + Through the sync→async blocking portal an in-flight cancellation + surfaces to the caller as ``concurrent.futures.CancelledError`` (the + thread-boundary cancellation type) rather than ``asyncio.CancelledError`` + — either way it propagates unmodified rather than being swallowed and + wrapped as a resumable ``ChunkInterrupted``. + """ + call_count = {"async": 0} + + async def fetch_async(args): + call_count["async"] += 1 + if call_count["async"] == 1: + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + if call_count["async"] == 2: + raise ServiceUnavailable("503: transient sibling") + if call_count["async"] == 3: + raise asyncio.CancelledError("user cancel") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + # 8 × 20-byte sites force the planner to >=3 sub-args under + # url_limit=240, so the fan-out gather sees at least the + # transient (call 2) AND the cancellation (call 3). + sites = [f"S{i}" * 10 for i in range(1, 9)] + + with pytest.raises((asyncio.CancelledError, concurrent.futures.CancelledError)): + fetch({"sites": sites}) + + def test_combine_chunk_responses_does_not_mutate_input_urls(): """Regression for the _set_response_url aliasing bug. @@ -1212,14 +1482,10 @@ def test_combine_chunk_responses_does_not_mutate_input_urls(): 'input responses are not mutated' invariant. The fix is to swap in a fresh ``httpx.Request`` rather than mutate the existing one. """ - import httpx as _httpx - - from dataretrieval.waterdata.chunking import _combine_chunk_responses - - req1 = _httpx.Request("GET", "https://example.com/chunk0") - req2 = _httpx.Request("GET", "https://example.com/chunk1") - r1 = _httpx.Response(200, request=req1) - r2 = _httpx.Response(200, request=req2) + req1 = httpx.Request("GET", "https://example.com/chunk0") + req2 = httpx.Request("GET", "https://example.com/chunk1") + r1 = httpx.Response(200, request=req1) + r2 = httpx.Response(200, request=req2) out = _combine_chunk_responses( [r1, r2], canonical_url="https://canonical.example/full" @@ -1230,3 +1496,335 @@ def test_combine_chunk_responses_does_not_mutate_input_urls(): assert str(r2.url) == "https://example.com/chunk1" assert str(req1.url) == "https://example.com/chunk0" assert str(req2.url) == "https://example.com/chunk1" + + +# --------------------------------------------------------------------------- +# Retry-with-backoff: RetryPolicy + _retryable + driver + decorator wiring. +# Conftest pins API_USGS_RETRIES=0, so these tests opt in explicitly and +# patch the chunking module's ``asyncio.sleep`` to a no-op (no real backoff). +# --------------------------------------------------------------------------- + + +def _wrap_cause(transport_exc): + """Wrap ``transport_exc`` the way ``_walk_pages`` does — a + ``RuntimeError`` with the typed transport error on ``__cause__`` — so + chain-walking is exercised realistically.""" + try: + raise RuntimeError("Paginated request failed") from transport_exc + except RuntimeError as wrapped: + return wrapped + + +# -- RetryPolicy (pure value object) ---------------------------------------- + + +def test_retry_policy_backoff_honors_retry_after(): + policy = RetryPolicy() + # A server Retry-After overrides the computed backoff verbatim. + assert policy.backoff(attempt=1, retry_after=7.5) == 7.5 + assert policy.backoff(attempt=4, retry_after=2.0) == 2.0 + + +def test_retry_policy_backoff_full_jitter_within_ceiling(): + policy = RetryPolicy(base_backoff=2.0, max_backoff=30.0) + for attempt, ceiling in [(1, 2.0), (2, 4.0), (3, 8.0), (5, 30.0)]: + samples = [policy.backoff(attempt, None) for _ in range(200)] + assert all(0.0 <= s <= ceiling for s in samples) + # Full jitter genuinely varies and reaches below the ceiling. + assert min(samples) < ceiling + + +def test_retry_policy_should_retry_exhaustion(): + policy = RetryPolicy(max_retries=2) + assert policy.should_retry(attempt=1, retry_after=None) + assert policy.should_retry(attempt=2, retry_after=None) + assert not policy.should_retry(attempt=3, retry_after=None) + + +def test_retry_policy_long_retry_after_escalates(): + policy = RetryPolicy(max_retries=5, retry_after_cap=60.0) + assert policy.should_retry(attempt=1, retry_after=30.0) # absorbed inline + assert not policy.should_retry(attempt=1, retry_after=120.0) # escalates + + +def test_retry_policy_from_env(monkeypatch): + monkeypatch.setenv("API_USGS_RETRIES", "2") + assert RetryPolicy.from_env().max_retries == 2 + monkeypatch.setenv("API_USGS_RETRIES", "0") + assert RetryPolicy.from_env().max_retries == 0 + monkeypatch.delenv("API_USGS_RETRIES", raising=False) + assert RetryPolicy.from_env().max_retries == _config._RETRIES_DEFAULT + monkeypatch.setenv("API_USGS_RETRIES", "-1") + with pytest.raises(ValueError): + RetryPolicy.from_env() + monkeypatch.setenv("API_USGS_RETRIES", "lots") + with pytest.raises(ValueError): + RetryPolicy.from_env() + + +def test_retry_policy_rejects_invalid_settings(): + with pytest.raises(ValueError): + RetryPolicy(max_retries=-1) + with pytest.raises(ValueError): + RetryPolicy(base_backoff=-0.5) + with pytest.raises(ValueError): + RetryPolicy(max_backoff=-1.0) + + +def test_retry_policy_override_takes_effect(): + """A Python-side override is the highest-precedence config layer and + flows through ``RetryPolicy.from_env()`` for the duration of the block.""" + from dataretrieval.waterdata._config import WaterDataConfig, override + + custom = WaterDataConfig( + retry=RetryPolicy(base_backoff=0.0, max_backoff=0.0, max_retries=2) + ) + with override(custom): + policy = RetryPolicy.from_env() + assert policy.base_backoff == 0.0 + assert policy.max_backoff == 0.0 + assert policy.max_retries == 2 + # And the override unwinds at block exit. + assert RetryPolicy.from_env().base_backoff == _config._RETRY_BASE_BACKOFF + + +# -- _retryable taxonomy ---------------------------------------------------- + + +def test_retryable_taxonomy(): + assert _retryable(RateLimited("429", retry_after=5.0)) == (True, 5.0) + assert _retryable(ServiceUnavailable("503")) == (True, None) + assert _retryable(httpx.ReadTimeout("slow")) == (True, None) + assert _retryable(httpx.ConnectError("down")) == (True, None) + # InvalidURL is resumable but NOT retryable (a too-long cursor won't fix). + assert _retryable(httpx.InvalidURL("too long")) == (False, None) + # Plain non-transient (e.g. a 4xx programmer error wrapped as RuntimeError). + assert _retryable(RuntimeError("400")) == (False, None) + + +def test_retryable_skips_wrapped_midpagination_transient(): + # A transient surfaced mid-pagination is re-wrapped as RuntimeError by + # _paginate; it must NOT be auto-retried (re-walking from page 1 + # would re-spend quota) — it escalates to the resumable handle instead. + # Only the raw, top-level (initial-request) transient is retryable. + assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (False, None) + assert _retryable(RateLimited("429", retry_after=3.0)) == (True, 3.0) + + +# -- async driver (the single retry driver; sync facade drives it) ---------- +# +# The retry loop lives in ``_retry``. These tests pin its behavioral +# contracts (transient-then-success, exhausted-reraises, +# non-retryable-not-retried, long-retry-after-escalates), run via +# ``asyncio.run``; the sleep is patched to a no-op so backoff doesn't +# actually wait. + + +def test_retry_transient_then_recovers(monkeypatch): + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) + calls = {"n": 0} + + async def afn(): + calls["n"] += 1 + if calls["n"] <= 2: + raise RateLimited("429") + return "ok" + + out = asyncio.run(_retry(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) + assert out == "ok" + assert calls["n"] == 3 # two failures + one success + + +def test_retry_exhausted_reraises(monkeypatch): + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) + calls = {"n": 0} + + async def afn(): + calls["n"] += 1 + raise ServiceUnavailable("503") + + with pytest.raises(ServiceUnavailable): + asyncio.run(_retry(afn, RetryPolicy(max_retries=2, base_backoff=0.0))) + assert calls["n"] == 3 # first attempt + 2 retries + + +def test_retry_non_retryable_not_retried(monkeypatch): + slept: list[float] = [] + + monkeypatch.setattr(_chunking.asyncio, "sleep", _recording_sleep(slept)) + calls = {"n": 0} + + async def afn(): + calls["n"] += 1 + raise RuntimeError("400: bad request") + + with pytest.raises(RuntimeError): + asyncio.run(_retry(afn, RetryPolicy(max_retries=3))) + assert calls["n"] == 1 and slept == [] + + +def test_retry_long_retry_after_escalates(monkeypatch): + slept: list[float] = [] + + monkeypatch.setattr(_chunking.asyncio, "sleep", _recording_sleep(slept)) + calls = {"n": 0} + + async def afn(): + calls["n"] += 1 + raise RateLimited("429", retry_after=999.0) + + with pytest.raises(RateLimited): + asyncio.run(_retry(afn, RetryPolicy(max_retries=5, retry_after_cap=60.0))) + assert calls["n"] == 1 and slept == [] # no inline wait for a long window + + +# -- async driver (sleep-patched original) ---------------------------------- + + +def test_retry_transient_then_success(monkeypatch): + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) + calls = {"n": 0} + + async def afn(): + calls["n"] += 1 + if calls["n"] == 1: + raise httpx.ReadTimeout("slow") + return "ok" + + out = asyncio.run(_retry(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) + assert out == "ok" and calls["n"] == 2 + + +# -- end-to-end through the decorator -------------------------------------- + + +def test_chunker_retries_transient_then_completes(monkeypatch): + """A transient on one sub-request is retried transparently; the + decorated call completes with no ChunkInterrupted.""" + monkeypatch.setenv("API_USGS_RETRIES", "3") + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) + state = {"failed": False} + + async def fetch(args): + # Fail the first sub-request once, then succeed everywhere. + if not state["failed"]: + state["failed"] = True + raise RateLimited("429: Too many requests made.") + return pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + df, _ = decorated({"sites": sites}) + assert sorted(df["sites"]) == sorted(sites) # all recovered despite the 429 + + +def test_chunker_exhausted_retries_still_resumable(monkeypatch): + """When retries are exhausted the failure still surfaces as a + resumable ChunkInterrupted — retries don't swallow the escape hatch.""" + monkeypatch.setenv("API_USGS_RETRIES", "2") + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) + attempts = {"n": 0} + + async def fetch(args): + sites = list(args["sites"]) + if "S1" * 10 in sites: + attempts["n"] += 1 + raise ServiceUnavailable("503: service unavailable") + return pd.DataFrame({"sites": sites}), _quota_response(500) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(ServiceInterrupted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert excinfo.value.call is not None + assert attempts["n"] == 3 # first attempt + 2 retries before giving up + + +def test_async_fan_out_retries_transient_then_completes(monkeypatch): + """The parallel path retries a transient sub-request and completes.""" + monkeypatch.setenv("API_USGS_RETRIES", "3") + + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) + state = {"failed": False} + + async def fetch_async(args): + if not state["failed"]: + state["failed"] = True + raise ServiceUnavailable("503: transient") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert len(df) > 1 # every sub-args atom recovered after the retry + + +def test_async_fan_out_surfaces_fatal_over_transient(monkeypatch): + """A non-transient bug in one sub-request surfaces raw rather than + being masked behind a resumable interruption from a transient sibling.""" + monkeypatch.setenv("API_USGS_RETRIES", "2") + + monkeypatch.setattr(_chunking.asyncio, "sleep", _aiozero) + + async def fetch_async(args): + # One chunk carries a deterministic programmer error; the rest are + # transient. The real bug must win over the resumable transient. + if "S1" * 10 in args["sites"]: + raise ValueError("deterministic bug") + raise ServiceUnavailable("503: transient") + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + with pytest.raises(ValueError, match="deterministic bug"): + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + +# --- finalize hook (resume finalizes; partials stay raw) ------------------- +# +# Regression for the bug where ``exc.call.resume()`` returned the chunker's +# raw ``(frame, httpx.Response)`` instead of the post-processed shape a normal +# getter call yields. The fix injects a ``finalize`` transform applied at the +# terminal resume()/resume_async() returns. The partial_* accessors stay RAW +# so building/inspecting a ChunkInterrupted never triggers finalize's side +# effects (for OGC, _deal_with_empty issues a schema network GET on an empty +# frame — that must NOT fire inside the exception constructor). + + +def test_resume_finalizes_but_partials_stay_raw(monkeypatch): + """resume() applies the injected ``finalize``; ``partial_frame`` / + ``partial_response`` are the raw snapshot, and constructing the + ``ChunkInterrupted`` must not invoke ``finalize`` at all (no side effects + such as a schema fetch in the exception ctor).""" + calls = {"finalize": 0} + + def finalize(frame, response): + # Stand in for the OGC pipeline: mark the frame and wrap the response. + calls["finalize"] += 1 + return frame.assign(finalized=True), ("METADATA", response) + + # Fail the 2nd issued sub-request once (the 1st completes, so partial + # state is non-empty), then succeed on resume. Conftest pins a single + # connection and no retries, so the failure surfaces immediately. + state = {"n": 0} + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + async def fetch(args): + state["n"] += 1 + if state["n"] == 2: + raise ServiceUnavailable("503: simulated") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + with pytest.raises(ServiceInterrupted) as exc_info: + fetch({"sites": sites}, finalize=finalize) + + interrupted = exc_info.value + assert interrupted.completed_chunks >= 1 + # Building the exception did NOT run finalize — no network/side effects. + assert calls["finalize"] == 0 + # Partial snapshot is the RAW combined frame/response (not finalized). + assert "finalized" not in interrupted.partial_frame.columns + assert not isinstance(interrupted.partial_response, tuple) + + # Resume DOES finalize and yields the same shape a normal call would. + df, md = interrupted.call.resume() + assert "finalized" in df.columns + assert md[0] == "METADATA" + assert calls["finalize"] >= 1 diff --git a/tests/waterdata_config_test.py b/tests/waterdata_config_test.py new file mode 100644 index 00000000..bd98f413 --- /dev/null +++ b/tests/waterdata_config_test.py @@ -0,0 +1,281 @@ +"""Tests for the layered config loader. + +Verifies the precedence chain (defaults → user file → local file → env vars +→ Python override) and the round-trip through the INI / configparser layer. +""" + +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest + +from dataretrieval.waterdata import _config +from dataretrieval.waterdata._config import ( + CONCURRENCY_UNBOUNDED, + ConcurrencyPolicy, + RetryPolicy, + WaterDataConfig, + override, +) + +# ---- isolate from real config files --------------------------------------- + + +@pytest.fixture(autouse=True) +def _isolate(monkeypatch, tmp_path): + """Point user-config + local-config + cwd at empty tmp dirs so no real + file or shell env leaks in. Tests that want a layer install it explicitly + via :func:`_write` / :func:`monkeypatch.setenv`. + + Also clears the four ``API_USGS_*`` env vars (the autouse conftest + fixture sets two, but per-layer tests need a blank slate).""" + # Redirect user-config path. + user_dir = tmp_path / "user_config" + user_dir.mkdir() + monkeypatch.setattr(_config, "_user_config_path", lambda: user_dir / "config.cfg") + # Make Path.cwd() return the tmp dir so the local-config file lookup + # lands here too. + cwd = tmp_path / "cwd" + cwd.mkdir() + monkeypatch.chdir(cwd) + # Clear the env vars set by other autouse fixtures. + for name in ( + _config.ENV_RETRIES, + _config.ENV_CONCURRENT, + _config.ENV_PAT, + _config.ENV_PROGRESS, + ): + monkeypatch.delenv(name, raising=False) + + +def _write(path: Path, body: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(textwrap.dedent(body).lstrip()) + + +# ---- layer 1: defaults ---------------------------------------------------- + + +def test_defaults_when_nothing_configured(): + cfg = WaterDataConfig.load() + assert cfg.retry == RetryPolicy() + assert cfg.concurrency == ConcurrencyPolicy() + assert cfg.api_token is None + assert cfg.progress is None + + +# ---- layer 2: user file --------------------------------------------------- + + +def test_user_file_overrides_defaults(): + _write( + _config._user_config_path(), + """ + [default] + api_token = user-token + progress = on + + [retry] + max_retries = 7 + base_backoff = 0.25 + + [concurrency] + max_connections = 32 + """, + ) + cfg = WaterDataConfig.load() + assert cfg.api_token == "user-token" + assert cfg.progress is True + assert cfg.retry.max_retries == 7 + assert cfg.retry.base_backoff == 0.25 + # Unset fields still take the dataclass defaults. + assert cfg.retry.max_backoff == _config._RETRY_MAX_BACKOFF + assert cfg.concurrency.max_connections == 32 + + +# ---- layer 3: local file overrides user file ------------------------------ + + +def test_local_file_overrides_user_file(): + _write( + _config._user_config_path(), + """ + [default] + api_token = user-token + + [retry] + max_retries = 7 + """, + ) + _write( + Path.cwd() / _config._LOCAL_CONFIG_NAME, + """ + [default] + api_token = local-token + + [retry] + max_retries = 99 + base_backoff = 1.5 + """, + ) + cfg = WaterDataConfig.load() + assert cfg.api_token == "local-token" + assert cfg.retry.max_retries == 99 + assert cfg.retry.base_backoff == 1.5 + + +# ---- layer 4: env overrides files ----------------------------------------- + + +def test_env_overrides_files(monkeypatch): + _write( + _config._user_config_path(), + """ + [default] + api_token = file-token + + [retry] + max_retries = 5 + + [concurrency] + max_connections = 8 + """, + ) + monkeypatch.setenv(_config.ENV_RETRIES, "2") + monkeypatch.setenv(_config.ENV_CONCURRENT, CONCURRENCY_UNBOUNDED) + monkeypatch.setenv(_config.ENV_PAT, "env-token") + monkeypatch.setenv(_config.ENV_PROGRESS, "off") + + cfg = WaterDataConfig.load() + assert cfg.api_token == "env-token" + assert cfg.progress is False + assert cfg.retry.max_retries == 2 # env wins over file's 5 + assert cfg.concurrency.max_connections is None # "unbounded" → None + + +def test_env_only_sets_what_it_provides(monkeypatch): + """An env var sets only its own field; other file-set fields are + preserved (the deep-update keeps sibling keys).""" + _write( + _config._user_config_path(), + """ + [retry] + max_retries = 5 + base_backoff = 1.0 + max_backoff = 10.0 + """, + ) + monkeypatch.setenv(_config.ENV_RETRIES, "2") + cfg = WaterDataConfig.load() + assert cfg.retry.max_retries == 2 # env overrides + assert cfg.retry.base_backoff == 1.0 # file-set, preserved + assert cfg.retry.max_backoff == 10.0 # file-set, preserved + + +# ---- layer 5: Python override wins above all ------------------------------ + + +def test_python_override_wins(monkeypatch): + monkeypatch.setenv(_config.ENV_RETRIES, "2") + custom = WaterDataConfig( + retry=RetryPolicy(max_retries=99, base_backoff=0.0, max_backoff=0.0) + ) + with override(custom): + # current() short-circuits to the active override (no file/env load). + assert _config.current() is custom + # And the legacy from_env() factories pick it up too. + assert RetryPolicy.from_env().max_retries == 99 + assert RetryPolicy.from_env().base_backoff == 0.0 + # On exit, the layered loader resumes. + assert RetryPolicy.from_env().max_retries == 2 # back to env + + +def test_override_is_contextvar_scoped(): + """Nested ``override`` blocks pop correctly; the outer override is + restored at inner exit.""" + outer = WaterDataConfig(retry=RetryPolicy(max_retries=1)) + inner = WaterDataConfig(retry=RetryPolicy(max_retries=2)) + with override(outer): + assert _config.current() is outer + with override(inner): + assert _config.current() is inner + assert _config.current() is outer + + +# ---- parsing / validation ------------------------------------------------- + + +def test_env_concurrency_unbounded_keyword(monkeypatch): + monkeypatch.setenv(_config.ENV_CONCURRENT, "UNBOUNDED") # case-insensitive + assert WaterDataConfig.load().concurrency.max_connections is None + + +def test_env_concurrency_invalid_value(monkeypatch): + monkeypatch.setenv(_config.ENV_CONCURRENT, "abc") + with pytest.raises(ValueError, match=_config.ENV_CONCURRENT): + WaterDataConfig.load() + + +def test_env_retries_negative_is_rejected(monkeypatch): + monkeypatch.setenv(_config.ENV_RETRIES, "-1") + with pytest.raises(ValueError): + WaterDataConfig.load() + + +def test_progress_parser_recognizes_truthy_falsy_and_auto(): + assert _config._parse_progress("on") is True + assert _config._parse_progress("true") is True + assert _config._parse_progress("1") is True + assert _config._parse_progress("off") is False + assert _config._parse_progress("FALSE") is False # case-insensitive + assert _config._parse_progress("auto") is None + assert _config._parse_progress("") is None + assert _config._parse_progress("nonsense") is None + + +def test_missing_files_are_silent(): + """No user or local file → no error, just falls through to defaults.""" + assert not _config._user_config_path().exists() + assert not (Path.cwd() / _config._LOCAL_CONFIG_NAME).exists() + cfg = WaterDataConfig.load() + assert cfg == WaterDataConfig() # all defaults + + +def test_unknown_keys_in_file_are_ignored(): + """A stray key in a sub-table shouldn't crash construction.""" + _write( + _config._user_config_path(), + """ + [retry] + max_retries = 3 + not_a_field = 42 + + [concurrency] + max_connections = 4 + also_not_a_field = "x" + """, + ) + cfg = WaterDataConfig.load() + assert cfg.retry.max_retries == 3 + assert cfg.concurrency.max_connections == 4 + + +# ---- direct dataclass validation ----------------------------------------- + + +def test_retry_policy_rejects_negative_settings(): + with pytest.raises(ValueError): + RetryPolicy(max_retries=-1) + with pytest.raises(ValueError): + RetryPolicy(base_backoff=-0.5) + + +def test_concurrency_policy_rejects_zero_or_negative(): + with pytest.raises(ValueError): + ConcurrencyPolicy(max_connections=0) + with pytest.raises(ValueError): + ConcurrencyPolicy(max_connections=-1) + # ``None`` is fine — explicit "unbounded". + ConcurrencyPolicy(max_connections=None) diff --git a/tests/waterdata_filters_test.py b/tests/waterdata_filters_test.py index 32879318..a447cada 100644 --- a/tests/waterdata_filters_test.py +++ b/tests/waterdata_filters_test.py @@ -6,6 +6,7 @@ import pandas as pd import pytest +from dataretrieval.waterdata import get_continuous from dataretrieval.waterdata.filters import ( _check_numeric_filter_pitfall, _split_top_level_or, @@ -152,12 +153,10 @@ def test_long_filter_fans_out_into_multiple_requests(): sub-requests via the joint planner; every original clause is preserved across sub-requests; results concatenate to one row per sub-request given the one-row-per-chunk mock.""" - from dataretrieval.waterdata import get_continuous - expr = _filter_chunking_clauses() sent_filters: list[str] = [] - def fake_walk_pages(*, geopd, req): + async def fake_walk_pages(*, geopd, req): idx = len(sent_filters) sent_filters.append(_query_params(req).get("filter", [None])[0]) return pd.DataFrame({"id": [f"chunk-{idx}"], "value": [idx]}), _fake_response() @@ -168,7 +167,8 @@ def fake_walk_pages(*, geopd, req): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages + "dataretrieval.waterdata.utils._walk_pages", + side_effect=fake_walk_pages, ), ): df, _ = get_continuous( @@ -190,12 +190,10 @@ def fake_walk_pages(*, geopd, req): def test_long_filter_deduplicates_cross_chunk_overlap(): """Features returned by multiple sub-requests with the same ``id`` are deduplicated in the concatenated result.""" - from dataretrieval.waterdata import get_continuous - expr = _filter_chunking_clauses() call_count = {"n": 0} - def fake_walk_pages(*_args, **_kwargs): + async def fake_walk_pages(*_args, **_kwargs): call_count["n"] += 1 return ( pd.DataFrame({"id": ["shared-feature"], "value": [1]}), @@ -208,7 +206,8 @@ def fake_walk_pages(*_args, **_kwargs): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages + "dataretrieval.waterdata.utils._walk_pages", + side_effect=fake_walk_pages, ), ): df, _ = get_continuous( @@ -232,12 +231,10 @@ def test_empty_chunks_do_not_downgrade_geodataframe(): import geopandas as gpd from shapely.geometry import Point - from dataretrieval.waterdata import get_continuous - expr = _filter_chunking_clauses() call_count = {"n": 0} - def fake_walk_pages(*_args, **_kwargs): + async def fake_walk_pages(*_args, **_kwargs): call_count["n"] += 1 if call_count["n"] == 2: return pd.DataFrame(), _fake_response() @@ -256,7 +253,8 @@ def fake_walk_pages(*_args, **_kwargs): side_effect=_filter_size_aware_build, ), mock.patch( - "dataretrieval.waterdata.utils._walk_pages", side_effect=fake_walk_pages + "dataretrieval.waterdata.utils._walk_pages", + side_effect=fake_walk_pages, ), ): df, _ = get_continuous( @@ -273,8 +271,6 @@ def fake_walk_pages(*_args, **_kwargs): def test_cql_json_filter_is_not_chunked(): """Chunking applies only to cql-text; cql-json is passed through unchanged.""" - from dataretrieval.waterdata import get_continuous - clause = "(time >= '2023-01-01T00:00:00Z' AND time <= '2023-01-01T00:30:00Z')" expr = " OR ".join([clause] * 300) sent_filters = [] @@ -290,9 +286,11 @@ def fake_construct_api_requests(**kwargs): ), mock.patch( "dataretrieval.waterdata.utils._walk_pages", - return_value=( - pd.DataFrame({"id": ["row-1"], "value": [1]}), - _fake_response(), + new=mock.AsyncMock( + return_value=( + pd.DataFrame({"id": ["row-1"], "value": [1]}), + _fake_response(), + ) ), ), ): @@ -428,8 +426,6 @@ def test_get_continuous_surfaces_pitfall_to_caller(): """End-to-end: the check runs at the ``get_continuous`` boundary, not as a deep internal-only protection, so callers see the error before any HTTP traffic.""" - from dataretrieval.waterdata import get_continuous - with mock.patch("dataretrieval.waterdata.utils._construct_api_requests") as build: with pytest.raises(ValueError, match="lexicographic"): get_continuous( diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index faa61630..08f6ca26 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -6,12 +6,15 @@ reporter. """ +import asyncio +import datetime import io import sys import types from unittest import mock import httpx +import pandas as pd import pytest from dataretrieval.waterdata import _progress @@ -20,7 +23,20 @@ current, progress_context, ) -from dataretrieval.waterdata.utils import _walk_pages +from dataretrieval.waterdata.chunking import ChunkedCall, ChunkPlan +from dataretrieval.waterdata.utils import _paginate, _walk_pages + + +def _run_walk_pages(*, geopd, req, client): + """Drive the async ``_walk_pages`` to completion synchronously. + + The chunker core is async-only now, so these tests build an + ``AsyncMock(spec=httpx.AsyncClient)`` whose ``.send``/``.request`` are + awaitable and run the coroutine via ``asyncio.run``. The progress + reporter is bound on a contextvar, which the coroutine inherits when + ``asyncio.run`` copies the calling context. + """ + return asyncio.run(_walk_pages(geopd=geopd, req=req, client=client)) @pytest.fixture(autouse=True) @@ -65,6 +81,56 @@ def test_page_count_is_pluralized(): assert "2 pages" in stream.getvalue() +def test_note_retry_renders_then_clears_on_next_page(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_chunks(3) + reporter.start_chunk(1) + reporter.note_retry(attempt=2, wait=8.0) + assert "retrying (attempt 2, waiting 8s)" in stream.getvalue() + # The next page redraws without the note (last frame is after the + # final carriage return). + reporter.add_page(rows=5) + assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] + + +def test_note_retry_subsecond_wait_shows_decimal(): + # A sub-second backoff must not collapse to a misleading "0s". + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.note_retry(attempt=1, wait=0.3) + out = stream.getvalue() + assert "waiting 0.3s" in out and "waiting 0s" not in out + + +def test_note_retry_cleared_on_close(): + # An exhausted retry leaves retry_note set with no following page; + # close() must clear it so the persisted last line isn't a stale note. + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.add_page(rows=1) + reporter.note_retry(attempt=3, wait=5.0) + reporter.close() + assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] + + +def test_note_retry_is_noop_when_disabled(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=False) + reporter.note_retry(attempt=1, wait=1.0) + assert stream.getvalue() == "" + + +def test_note_retry_accepts_integer_wait(): + # An int ``wait`` (e.g. whole seconds) must render without raising: + # ``round(int, 1)`` returns an int and ``int.is_integer()`` only exists + # on Python 3.12+, while the package floor is 3.9. Renders like the float. + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.note_retry(attempt=1, wait=5) + assert "retrying (attempt 1, waiting 5s)" in stream.getvalue() + + def test_chunk_segment_only_shown_when_multiple_chunks(): single = io.StringIO() reporter = ProgressReporter(stream=single, enabled=True) @@ -305,7 +371,7 @@ def test_walk_pages_reports_pages_and_rate_limit(): ) resp2 = _resp([{"id": "2", "properties": {"v": "b"}}], rate_remaining="4998") - client = mock.MagicMock(spec=httpx.Client) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp1 client.request.return_value = resp2 @@ -316,7 +382,7 @@ def test_walk_pages_reports_pages_and_rate_limit(): stream = io.StringIO() with progress_context(service="daily", stream=stream, enabled=True): - df, _ = _walk_pages(geopd=False, req=req, client=client) + df, _ = _run_walk_pages(geopd=False, req=req, client=client) assert len(df) == 2 out = stream.getvalue() @@ -330,7 +396,7 @@ def test_walk_pages_reports_pages_and_rate_limit(): def test_walk_pages_without_context_does_not_error(): # No active reporter: pagination must still work and stay silent. resp = _resp([{"id": "1", "properties": {"v": "a"}}]) - client = mock.MagicMock(spec=httpx.Client) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp req = mock.MagicMock(spec=httpx.Request) @@ -338,7 +404,7 @@ def test_walk_pages_without_context_does_not_error(): req.headers = {} req.url = "https://example.com/p1" - df, _ = _walk_pages(geopd=False, req=req, client=client) + df, _ = _run_walk_pages(geopd=False, req=req, client=client) assert len(df) == 1 assert current() is None @@ -350,7 +416,7 @@ def test_broken_progress_stream_does_not_truncate_pagination(): [{"id": "1", "properties": {"v": "a"}}], next_url="https://example.com/p2" ) resp2 = _resp([{"id": "2", "properties": {"v": "b"}}]) - client = mock.MagicMock(spec=httpx.Client) + client = mock.AsyncMock(spec=httpx.AsyncClient) client.send.return_value = resp1 client.request.return_value = resp2 @@ -360,6 +426,116 @@ def test_broken_progress_stream_does_not_truncate_pagination(): req.url = "https://example.com/p1" with progress_context(stream=_RaisingStream(), enabled=True): - df, _ = _walk_pages(geopd=False, req=req, client=client) + df, _ = _run_walk_pages(geopd=False, req=req, client=client) assert len(df) == 2 # both pages returned despite the broken progress stream + + +# -- pagination integration ---------------------------------------------------- + + +def test_paginate_reports_pages_through_active_reporter(monkeypatch): + """The async paginate path must drive the same progress reporter. + Pages and rate-limit updates from each completed page should land + via the active ``ProgressReporter``, exactly as they would on + ``_walk_pages``.""" + resp1 = _resp( + [{"id": "1", "properties": {"v": "a"}}], + next_url="https://example.com/p2", + rate_remaining="4999", + ) + resp2 = _resp([{"id": "2", "properties": {"v": "b"}}], rate_remaining="4998") + + async def parse_response(resp): + body = resp.json() + nxt = next( + (link["href"] for link in body["links"] if link["rel"] == "next"), None + ) + return mock.MagicMock(empty=False, __len__=lambda self: 1), nxt + + # parse_response is sync (like the page parsers). + def parse_sync(resp): + body = resp.json() + nxt = next( + (link["href"] for link in body["links"] if link["rel"] == "next"), None + ) + return pd.DataFrame(body["features"]), nxt + + async def follow_up(cursor, sess): + return resp2 + + client = mock.AsyncMock(spec=httpx.AsyncClient) + client.send.return_value = resp1 + + req = mock.MagicMock(spec=httpx.Request) + req.method = "GET" + req.headers = {} + req.url = "https://example.com/p1" + + stream = io.StringIO() + + async def run(): + with progress_context(service="continuous", stream=stream, enabled=True): + df, _ = await _paginate( + req, + parse_response=parse_sync, + follow_up=follow_up, + client=client, + ) + return df + + df = asyncio.run(run()) + assert len(df) == 2 + out = stream.getvalue() + assert "Retrieving: continuous ·" in out + assert "2 pages" in out + assert "4,998 requests remaining" in out + assert out.endswith("\n") + + +def test_fan_out_async_sets_chunks_on_active_reporter(monkeypatch): + """The async fan-out core (``ChunkedCall._run``) records + ``plan.total`` on the active reporter so the progress line knows how + many sub-requests are in flight, and ticks ``current_chunk`` via + ``start_chunk(len(completed))`` as each gathered sub-request finishes + — reaching ``plan.total`` in the all-success case.""" + + # Fake build_request whose URL length scales with the sites list, + # mirroring the planner's _request_bytes contract. _FakeReq has the + # same shape as httpx.Request for sizing purposes. + class _FakeReq: + __slots__ = ("url", "content") + + def __init__(self, url): + self.url = url + self.content = b"" + + def build(*, sites): + return _FakeReq("x" * (200 + len(",".join(sites)))) + + sites = ["S" * 10 + str(i) for i in range(4)] + plan = ChunkPlan({"sites": sites}, build, url_limit=240) + assert plan.total > 1, "test setup error: plan must fan out" + + async def fetch_async(args): + return pd.DataFrame({"id": [",".join(args["sites"])]}), mock.Mock( + elapsed=datetime.timedelta(seconds=0.01), + headers={"x-ratelimit-remaining": "999"}, + ) + + stream = io.StringIO() + + async def run(): + # Drive the async execution core directly (the same coroutine the + # sync ``resume()`` facade runs through the anyio portal). + with progress_context(service="daily", stream=stream, enabled=True) as rep: + await ChunkedCall(plan, fetch_async)._run(4) + return rep.total_chunks, rep.current_chunk + + total_recorded, current_recorded = asyncio.run(run()) + assert total_recorded == plan.total + # Each sub-request that completes bumps current_chunk via + # start_chunk(len(completed)), so by the time the gather finishes + # current_chunk reflects the total number of successful chunks — + # plan.total in the all-success case. + assert current_recorded == plan.total diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index 09f66aa5..9e0d4c70 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -3,6 +3,7 @@ import sys from unittest import mock +import numpy as np import pandas as pd import pytest from pandas import DataFrame @@ -551,6 +552,23 @@ def test_get_reference_table_wrong_name(): get_reference_table("agency-cod") +@pytest.mark.parametrize("bad", [0, -1, 2.5, 10.0, True]) +def test_get_reference_table_rejects_bad_max_rows(bad): + # max_rows must be a genuine positive int; a non-positive value, a float + # (even integral like 10.0), or a bool must raise ValueError up front — + # not crash later inside pandas .head(). Raises before any HTTP request. + with pytest.raises(ValueError, match="positive integer"): + get_reference_table("agency-codes", max_rows=bad) + + +def test_get_reference_table_accepts_numpy_int_max_rows(): + # numpy integers are valid caps: isinstance(np.int64, int) is False, so the + # validation must accept numbers.Integral (not just int) — otherwise a cap + # derived from a numpy/pandas computation is wrongly rejected. + df, _ = get_reference_table("agency-codes", max_rows=np.int64(2)) + assert len(df) == 2 + + def test_get_stats_por(): df, _ = get_stats_por( monitoring_location_id="USGS-12451000", @@ -693,8 +711,6 @@ def test_pandas_series_normalizes_to_list(self): assert isinstance(result, list) def test_numpy_array_normalizes_to_list(self): - import numpy as np - result = _normalize_str_iterable(np.array(["00060", "00010"]), "p") assert result == ["00060", "00010"] assert isinstance(result, list) diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index bb5ece10..7063b767 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -1,3 +1,5 @@ +import asyncio +import datetime import json import logging from unittest import mock @@ -9,19 +11,38 @@ import dataretrieval.waterdata.utils as _utils_module from dataretrieval.waterdata.chunking import RateLimited, ServiceUnavailable from dataretrieval.waterdata.utils import ( + OGC_API_URL, _arrange_cols, + _check_ogc_requests, _error_body, + _finalize_ogc, _format_api_dates, _get_args, + _get_resp_data, _handle_stats_nesting, + _next_req_url, _parse_retry_after, _raise_for_non_200, + _row_cap, _walk_pages, + get_stats_data, ) _LOGGER_NAME = _utils_module.__name__ +def _run_walk_pages(*, geopd, req, client): + """Drive the async ``_walk_pages`` to completion synchronously. + + The chunker core is async-only now, so these tests build an + ``AsyncMock(spec=httpx.AsyncClient)`` whose ``.send``/``.request`` are + awaitable and run the coroutine via ``asyncio.run``. This thin shim + keeps the historical sync-shaped call sites terse while exercising the + real async pagination loop. + """ + return asyncio.run(_walk_pages(geopd=geopd, req=req, client=client)) + + def test_get_args_basic(): local_vars = { "monitoring_location_id": "USGS-123", @@ -74,7 +95,7 @@ def test_walk_pages_multiple_mocked(): resp2.status_code = 200 # Mock client (Session) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) # First call to send() returns resp1, then call to request() in loop returns resp2 mock_client.send.return_value = resp1 mock_client.request.return_value = resp2 @@ -86,7 +107,7 @@ def test_walk_pages_multiple_mocked(): mock_req.url = "https://example.com/page1" # Call _walk_pages - df, final_resp = _walk_pages(geopd=False, req=mock_req, client=mock_client) + df, final_resp = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) assert len(df) == 2 assert list(df["val"]) == ["a", "b"] @@ -96,6 +117,91 @@ def test_walk_pages_multiple_mocked(): assert mock_client.request.call_args[0][1] == "https://example.com/page2" +def test_row_cap_truncates_and_stops_within_first_page(): + # Regression for BUG 2: ``_row_cap`` bounds the TOTAL rows. A first page + # already over the cap is truncated to exactly ``max_rows`` and the + # ``next`` link is never followed. + resp1 = mock.MagicMock() + resp1.json.return_value = { + "numberReturned": 3, + "features": [{"id": str(i), "properties": {"val": i}} for i in range(3)], + "links": [{"rel": "next", "href": "https://example.com/page2"}], + } + resp1.headers = {} + resp1.status_code = 200 + resp1.url = "https://example.com/page1" + + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) + mock_client.send.return_value = resp1 + + mock_req = mock.MagicMock(spec=httpx.Request) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.url = "https://example.com/page1" + + with _row_cap(2): + df, _ = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) + + assert len(df) == 2 # truncated to the cap, not the page's 3 rows + assert not mock_client.request.called # ``next`` link never followed + + +def test_row_cap_stops_across_pages(): + # The cap accumulates across pages: page 1 (1 row) is under the cap so + # page 2 is fetched; once the cap (2) is met the third page is NOT. + def _page(idx, *, has_next): + resp = mock.MagicMock() + nxt = f"https://example.com/page{idx + 1}" + resp.json.return_value = { + "numberReturned": 1, + "features": [{"id": str(idx), "properties": {"val": idx}}], + "links": [{"rel": "next", "href": nxt}] if has_next else [], + } + resp.headers = {} + resp.status_code = 200 + resp.url = f"https://example.com/page{idx}" + return resp + + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) + mock_client.send.return_value = _page(1, has_next=True) + # page 2 still advertises a ``next`` (page 3) that must never be fetched. + mock_client.request.return_value = _page(2, has_next=True) + + mock_req = mock.MagicMock(spec=httpx.Request) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.url = "https://example.com/page1" + + with _row_cap(2): + df, _ = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) + + assert len(df) == 2 + assert mock_client.request.call_count == 1 # fetched page 2, stopped before 3 + + +def test_finalize_ogc_truncates_combined_to_max_rows(): + # max_rows is enforced on the *combined* frame in _finalize_ogc (after + # dedup/sort), so it bounds the total exactly even when a chunked call's + # per-sub-request pages overshoot the per-_paginate early-stop. + frame = pd.DataFrame({"id": [str(i) for i in range(10)]}) + resp = mock.MagicMock() + resp.url = "https://example.com/q" + resp.elapsed = datetime.timedelta(seconds=0.1) + resp.headers = {} + + df, md = _finalize_ogc( + frame, + resp, + properties=None, + output_id="thing_id", + convert_type=False, + service="things", + max_rows=3, + ) + assert len(df) == 3 + assert hasattr(md, "url") # wrapped as BaseMetadata + + def _resp_ok(features): """Build a 200-OK mock response carrying the given features list.""" links = [{"rel": "next", "href": "https://example.com/page2"}] if features else [] @@ -115,7 +221,7 @@ def _walk_pages_with_failure(failure_resp_or_exc): """Run _walk_pages where page 1 succeeds and page 2 fails as given.""" resp1 = _resp_ok([{"id": "1", "properties": {"val": "a"}}]) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = resp1 if isinstance(failure_resp_or_exc, BaseException): mock_client.request.side_effect = failure_resp_or_exc @@ -127,7 +233,7 @@ def _walk_pages_with_failure(failure_resp_or_exc): mock_req.headers = {} mock_req.url = "https://example.com/page1" - return _walk_pages(geopd=False, req=mock_req, client=mock_client) + return _run_walk_pages(geopd=False, req=mock_req, client=mock_client) def test_walk_pages_raises_on_connection_error_mid_pagination(): @@ -206,7 +312,7 @@ def test_walk_pages_wraps_initial_page_parse_error(): # Body is unparseable JSON (gateway HTML page, truncated stream). resp.json.side_effect = json.JSONDecodeError("Expecting value", "...", 0) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = resp mock_req = mock.MagicMock(spec=httpx.Request) @@ -215,7 +321,7 @@ def test_walk_pages_wraps_initial_page_parse_error(): mock_req.url = "https://example.com/page1" with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: - _walk_pages(geopd=False, req=mock_req, client=mock_client) + _run_walk_pages(geopd=False, req=mock_req, client=mock_client) # The JSONDecodeError causing it is on __cause__ so callers can drill in. assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) @@ -228,8 +334,6 @@ def test_get_resp_data_handles_missing_features_key(): ``_paginate`` as a generic transport error. ``_handle_stats_nesting`` was already hardened against this; ``_get_resp_data`` now mirrors that defensiveness and returns an empty frame instead.""" - from dataretrieval.waterdata.utils import _get_resp_data - resp = mock.Mock() resp.json.return_value = {"numberReturned": 1, "links": []} df = _get_resp_data(resp, geopd=False) @@ -244,12 +348,10 @@ def test_walk_pages_does_not_mutate_initial_response(): ``.elapsed`` before pagination completed (a Session response hook, a logging middleware) must continue to see the original first-page values — NOT the rewritten cumulative values.""" - import datetime as _dt - page1 = mock.MagicMock() page1.status_code = 200 page1.url = "https://example.com/page1" - page1.elapsed = _dt.timedelta(seconds=1) + page1.elapsed = datetime.timedelta(seconds=1) page1.headers = {"x-ratelimit-remaining": "999"} page1.json.return_value = { "numberReturned": 1, @@ -262,7 +364,7 @@ def test_walk_pages_does_not_mutate_initial_response(): page2 = mock.MagicMock() page2.status_code = 200 page2.url = "https://example.com/page2" - page2.elapsed = _dt.timedelta(seconds=2) + page2.elapsed = datetime.timedelta(seconds=2) page2.headers = {"x-ratelimit-remaining": "998"} page2.json.return_value = { "numberReturned": 1, @@ -270,7 +372,7 @@ def test_walk_pages_does_not_mutate_initial_response(): "links": [], } - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = page1 mock_client.request.return_value = page2 @@ -279,7 +381,7 @@ def test_walk_pages_does_not_mutate_initial_response(): mock_req.headers = {} mock_req.url = "https://example.com/page1" - df, final = _walk_pages(geopd=False, req=mock_req, client=mock_client) + df, final = _run_walk_pages(geopd=False, req=mock_req, client=mock_client) assert len(df) == 2 # The original first-page response object must be unmutated: @@ -290,7 +392,7 @@ def test_walk_pages_does_not_mutate_initial_response(): # The returned aggregate carries page-2 headers + cumulative elapsed. assert final.headers["x-ratelimit-remaining"] == "998" - assert final.elapsed == _dt.timedelta(seconds=3) + assert final.elapsed == datetime.timedelta(seconds=3) # And mutating the aggregate's headers doesn't leak into either page. final.headers["X-Trace-Id"] = "abc" assert "X-Trace-Id" not in page1.headers @@ -316,15 +418,13 @@ def _run_get_stats_data_with_failure(failure_resp_or_exc, monkeypatch): `monkeypatch` stubs ``_handle_stats_nesting`` so the synthetic minimal response body doesn't need to parse — these tests only assert on the pagination loop's error surfacing.""" - from dataretrieval.waterdata.utils import get_stats_data - monkeypatch.setattr( _utils_module, "_handle_stats_nesting", mock.MagicMock(return_value=pd.DataFrame()), ) - mock_client = mock.MagicMock(spec=httpx.Client) + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) mock_client.send.return_value = _stats_initial_ok() if isinstance(failure_resp_or_exc, BaseException): mock_client.request.side_effect = failure_resp_or_exc @@ -441,8 +541,6 @@ def test_get_resp_data_empty_preserves_geopd_type(): ``GeoDataFrame`` (not a plain ``DataFrame``) when geopd is True, so paginating across a sparse intermediate page doesn't downgrade the final concat result.""" - from dataretrieval.waterdata.utils import _get_resp_data - fake_gpd = mock.MagicMock() class _Sentinel: @@ -472,8 +570,6 @@ def test_get_resp_data_always_materializes_id_column(): ``_arrange_cols`` rename to the service-specific output_id (``daily_id``, ``channel_measurements_id``, etc.) isn't a silent no-op.""" - from dataretrieval.waterdata.utils import _get_resp_data - resp = mock.MagicMock() resp.json.return_value = { "numberReturned": 2, @@ -598,8 +694,6 @@ def test_format_api_dates_rejects_mapping(): """`time={"2024-01-01": "x"}` would silently materialize as the keys list, accepting input the user clearly didn't intend. """ - import pytest - with pytest.raises(TypeError, match="date input must be a string or sequence"): _format_api_dates({"2024-01-01": "ignored"}) @@ -742,8 +836,6 @@ def test_next_req_url_rejects_cross_host(): auth-like artifacts) were minted for the original host; following a server-supplied cross-host URL would leak them — and the URL itself could be sensitive.""" - from dataretrieval.waterdata.utils import _next_req_url - resp = mock.MagicMock() resp.url = httpx.URL("https://api.waterdata.usgs.gov/page1") body = { @@ -761,9 +853,6 @@ def test_check_ogc_requests_raises_typed_on_5xx(httpx_mock): ``_raise_for_non_200`` so callers see ``ServiceUnavailable`` / ``RateLimited`` / ``RuntimeError`` — the same typed contract as the main data path.""" - from dataretrieval.waterdata.chunking import ServiceUnavailable - from dataretrieval.waterdata.utils import OGC_API_URL, _check_ogc_requests - httpx_mock.add_response( method="GET", url=f"{OGC_API_URL}/collections/daily/schema",