From 59f79cd518bc9793960dcdfa28b4e6b04006d5cd Mon Sep 17 00:00:00 2001 From: VectorPeak <73048950+VectorPeak@users.noreply.github.com> Date: Tue, 30 Jun 2026 14:28:54 +0800 Subject: [PATCH 1/2] fix: reject non-200 download responses --- astrbot/core/utils/io.py | 14 ++++++ tests/unit/test_io_download_file.py | 76 +++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 tests/unit/test_io_download_file.py diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 3e9c20e323..498fa188fb 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -215,6 +215,10 @@ async def download_file( _safe_url_for_log(url), resp.status, ) + raise RuntimeError( + "Failed to download file from " + f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}" + ) total_size = int(resp.headers.get("content-length", 0)) downloaded_size = 0 start_time = time.time() @@ -291,6 +295,16 @@ async def download_file( ssl_context.verify_mode = ssl.CERT_NONE async with aiohttp.ClientSession() as session: async with session.get(url, ssl=ssl_context, timeout=120) as resp: + if resp.status != 200: + logger.error( + "Failed to download file from %s. HTTP status code: %s", + _safe_url_for_log(url), + resp.status, + ) + raise RuntimeError( + "Failed to download file from " + f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}" + ) total_size = int(resp.headers.get("content-length", 0)) downloaded_size = 0 start_time = time.time() diff --git a/tests/unit/test_io_download_file.py b/tests/unit/test_io_download_file.py new file mode 100644 index 0000000000..b547793c1c --- /dev/null +++ b/tests/unit/test_io_download_file.py @@ -0,0 +1,76 @@ +import pytest + +from astrbot.core.utils import io + + +class _FakeContent: + def __init__(self, chunks: list[bytes]): + self._chunks = chunks + + async def read(self, _size: int) -> bytes: + if self._chunks: + return self._chunks.pop(0) + return b"" + + +class _FakeResponse: + def __init__(self, *, status: int, chunks: list[bytes]): + self.status = status + self.headers = {"content-length": str(sum(len(chunk) for chunk in chunks))} + self.content = _FakeContent(chunks) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class _FakeSession: + def __init__(self, response: _FakeResponse): + self._response = response + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def get(self, *_args, **_kwargs): + return self._response + + +def _patch_download_session(monkeypatch, response: _FakeResponse): + monkeypatch.setattr(io.aiohttp, "TCPConnector", lambda **_kwargs: object()) + monkeypatch.setattr( + io.aiohttp, + "ClientSession", + lambda **_kwargs: _FakeSession(response), + ) + + +@pytest.mark.asyncio +async def test_download_file_rejects_non_200_response(monkeypatch, tmp_path): + target_path = tmp_path / "missing.bin" + _patch_download_session( + monkeypatch, + _FakeResponse(status=404, chunks=[b"not found"]), + ) + + with pytest.raises(RuntimeError, match="HTTP status code: 404"): + await io.download_file("https://example.test/missing", str(target_path)) + + assert not target_path.exists() + + +@pytest.mark.asyncio +async def test_download_file_writes_successful_response(monkeypatch, tmp_path): + target_path = tmp_path / "ok.bin" + _patch_download_session( + monkeypatch, + _FakeResponse(status=200, chunks=[b"hello", b" world"]), + ) + + await io.download_file("https://example.test/ok.bin", str(target_path)) + + assert target_path.read_bytes() == b"hello world" From 403bbb6b12992fade8d429a70b66dcc39b1ae1a6 Mon Sep 17 00:00:00 2001 From: VectorPeak <73048950+VectorPeak@users.noreply.github.com> Date: Tue, 30 Jun 2026 14:46:06 +0800 Subject: [PATCH 2/2] fix: share download response handling --- astrbot/core/utils/io.py | 244 +++++++++++++--------------- tests/unit/test_io_download_file.py | 37 ++++- 2 files changed, 146 insertions(+), 135 deletions(-) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 498fa188fb..0c7ebcd985 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -178,6 +178,101 @@ async def _emit_download_progress(progress_callback, payload: dict) -> None: await result +class DownloadFileHTTPError(RuntimeError): + """Raised when a file download returns an unsuccessful HTTP status.""" + + +def _raise_for_download_status(resp, url: str) -> None: + if resp.status == 200: + return + logger.error( + "Failed to download file from %s. HTTP status code: %s", + _safe_url_for_log(url), + resp.status, + ) + raise DownloadFileHTTPError( + "Failed to download file from " + f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}" + ) + + +async def _download_response_to_file( + resp, + file_obj, + url: str, + show_progress: bool, + progress_callback, + show_downloading_label: bool = True, +) -> None: + """Write a successful download response to a local file. + + Args: + resp: aiohttp response object to read from. + file_obj: Open writable binary file object. + url: Source URL used for progress events and sanitized errors. + show_progress: Whether to print progress to stdout. + progress_callback: Optional callback for progress payloads. + show_downloading_label: Whether to use the standard download heading. + + """ + + total_size = int(resp.headers.get("content-length", 0)) + downloaded_size = 0 + start_time = time.time() + if show_progress: + if show_downloading_label: + print( + f"Downloading: {_safe_url_for_log(url)} | " + f"Size: {total_size / 1024:.2f} KB" + ) + else: + print(f"Size: {total_size / 1024:.2f} KB | URL: {_safe_url_for_log(url)}") + await _emit_download_progress( + progress_callback, + { + "url": url, + "downloaded": 0, + "total": total_size, + "percent": 0, + "speed": 0, + }, + ) + while True: + chunk = await resp.content.read(8192) + if not chunk: + break + file_obj.write(chunk) + downloaded_size += len(chunk) + elapsed_time = time.time() - start_time if time.time() - start_time > 0 else 1 + speed = downloaded_size / 1024 / elapsed_time # KB/s + percent = downloaded_size / total_size if total_size > 0 else 0 + await _emit_download_progress( + progress_callback, + { + "url": url, + "downloaded": downloaded_size, + "total": total_size, + "percent": percent, + "speed": speed, + }, + ) + if show_progress: + print( + f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s", + end="", + ) + await _emit_download_progress( + progress_callback, + { + "url": url, + "downloaded": downloaded_size, + "total": total_size, + "percent": 1, + "speed": 0, + }, + ) + + async def download_file( url: str, path: str, @@ -209,73 +304,15 @@ async def download_file( connector=connector, ) as session: async with session.get(url, timeout=1800) as resp: - if resp.status != 200: - logger.error( - "Failed to download file from %s. HTTP status code: %s", - _safe_url_for_log(url), - resp.status, - ) - raise RuntimeError( - "Failed to download file from " - f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}" - ) - total_size = int(resp.headers.get("content-length", 0)) - downloaded_size = 0 - start_time = time.time() - if show_progress: - print( - f"Downloading: {_safe_url_for_log(url)} | " - f"Size: {total_size / 1024:.2f} KB" - ) - await _emit_download_progress( - progress_callback, - { - "url": url, - "downloaded": 0, - "total": total_size, - "percent": 0, - "speed": 0, - }, - ) + _raise_for_download_status(resp, url) with open(path, "wb") as f: - while True: - chunk = await resp.content.read(8192) - if not chunk: - break - f.write(chunk) - downloaded_size += len(chunk) - elapsed_time = ( - time.time() - start_time - if time.time() - start_time > 0 - else 1 - ) - speed = downloaded_size / 1024 / elapsed_time # KB/s - percent = downloaded_size / total_size if total_size > 0 else 0 - await _emit_download_progress( - progress_callback, - { - "url": url, - "downloaded": downloaded_size, - "total": total_size, - "percent": percent, - "speed": speed, - }, - ) - if show_progress: - print( - f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s", - end="", - ) - await _emit_download_progress( - progress_callback, - { - "url": url, - "downloaded": downloaded_size, - "total": total_size, - "percent": 1, - "speed": 0, - }, - ) + await _download_response_to_file( + resp, + f, + url, + show_progress, + progress_callback, + ) except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): if not allow_insecure_ssl_fallback: raise @@ -295,73 +332,16 @@ async def download_file( ssl_context.verify_mode = ssl.CERT_NONE async with aiohttp.ClientSession() as session: async with session.get(url, ssl=ssl_context, timeout=120) as resp: - if resp.status != 200: - logger.error( - "Failed to download file from %s. HTTP status code: %s", - _safe_url_for_log(url), - resp.status, - ) - raise RuntimeError( - "Failed to download file from " - f"{_safe_url_for_log(url)}. HTTP status code: {resp.status}" - ) - total_size = int(resp.headers.get("content-length", 0)) - downloaded_size = 0 - start_time = time.time() - if show_progress: - print( - f"Size: {total_size / 1024:.2f} KB | " - f"URL: {_safe_url_for_log(url)}" - ) - await _emit_download_progress( - progress_callback, - { - "url": url, - "downloaded": 0, - "total": total_size, - "percent": 0, - "speed": 0, - }, - ) + _raise_for_download_status(resp, url) with open(path, "wb") as f: - while True: - chunk = await resp.content.read(8192) - if not chunk: - break - f.write(chunk) - downloaded_size += len(chunk) - elapsed_time = ( - time.time() - start_time - if time.time() - start_time > 0 - else 1 - ) - speed = downloaded_size / 1024 / elapsed_time # KB/s - percent = downloaded_size / total_size if total_size > 0 else 0 - await _emit_download_progress( - progress_callback, - { - "url": url, - "downloaded": downloaded_size, - "total": total_size, - "percent": percent, - "speed": speed, - }, - ) - if show_progress: - print( - f"\rProgress: {percent:.2%} Speed: {speed:.2f} KB/s", - end="", - ) - await _emit_download_progress( - progress_callback, - { - "url": url, - "downloaded": downloaded_size, - "total": total_size, - "percent": 1, - "speed": 0, - }, - ) + await _download_response_to_file( + resp, + f, + url, + show_progress, + progress_callback, + show_downloading_label=False, + ) if show_progress: print() diff --git a/tests/unit/test_io_download_file.py b/tests/unit/test_io_download_file.py index b547793c1c..5cea2f0e50 100644 --- a/tests/unit/test_io_download_file.py +++ b/tests/unit/test_io_download_file.py @@ -27,7 +27,7 @@ async def __aexit__(self, exc_type, exc, tb): class _FakeSession: - def __init__(self, response: _FakeResponse): + def __init__(self, response: _FakeResponse | Exception): self._response = response async def __aenter__(self): @@ -37,15 +37,21 @@ async def __aexit__(self, exc_type, exc, tb): return False def get(self, *_args, **_kwargs): + if isinstance(self._response, Exception): + raise self._response return self._response def _patch_download_session(monkeypatch, response: _FakeResponse): + _patch_download_sessions(monkeypatch, [response]) + + +def _patch_download_sessions(monkeypatch, responses: list[_FakeResponse | Exception]): monkeypatch.setattr(io.aiohttp, "TCPConnector", lambda **_kwargs: object()) monkeypatch.setattr( io.aiohttp, "ClientSession", - lambda **_kwargs: _FakeSession(response), + lambda **_kwargs: _FakeSession(responses.pop(0)), ) @@ -57,7 +63,32 @@ async def test_download_file_rejects_non_200_response(monkeypatch, tmp_path): _FakeResponse(status=404, chunks=[b"not found"]), ) - with pytest.raises(RuntimeError, match="HTTP status code: 404"): + with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"): + await io.download_file("https://example.test/missing", str(target_path)) + + assert not target_path.exists() + + +@pytest.mark.asyncio +async def test_download_file_rejects_non_200_response_after_ssl_fallback( + monkeypatch, + tmp_path, +): + class FakeSSLError(Exception): + pass + + target_path = tmp_path / "missing.bin" + _patch_download_sessions( + monkeypatch, + [ + FakeSSLError(), + _FakeResponse(status=404, chunks=[b"not found"]), + ], + ) + monkeypatch.setattr(io.aiohttp, "ClientConnectorSSLError", FakeSSLError) + monkeypatch.setattr(io.aiohttp, "ClientConnectorCertificateError", FakeSSLError) + + with pytest.raises(io.DownloadFileHTTPError, match="HTTP status code: 404"): await io.download_file("https://example.test/missing", str(target_path)) assert not target_path.exists()