From 3ee5da62d9834e794e8b503e34f59be012a3cf3c Mon Sep 17 00:00:00 2001 From: zzz27578 <2950506809@qq.com> Date: Sat, 20 Jun 2026 01:59:57 +0800 Subject: [PATCH] fix: apply fallback chat models to background wakeups --- astrbot/core/astr_agent_tool_exec.py | 7 ++- tests/unit/test_astr_agent_tool_exec.py | 76 +++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 8c3ed661f9..2e5915bad0 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -543,11 +543,12 @@ async def _wake_main_agent_for_background_result( message_type=session.message_type, ) cron_event.role = event.role + cfg = ctx.get_config(umo=event.unified_msg_origin) or {} + provider_settings = cfg.get("provider_settings") or {} config = MainAgentBuildConfig( tool_call_timeout=run_context.tool_call_timeout, - streaming_response=ctx.get_config() - .get("provider_settings", {}) - .get("stream", False), + streaming_response=provider_settings.get("stream", False), + provider_settings=provider_settings, ) req = ProviderRequest() diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 61fb4048c8..c0a18374a5 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from unittest.mock import AsyncMock import mcp import pytest @@ -19,6 +20,7 @@ class _DummyEvent: def __init__(self, message_components: list[object] | None = None) -> None: self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session" self.message_obj = SimpleNamespace(message=message_components or []) + self.role = "member" def get_extra(self, _key: str): return None @@ -36,6 +38,15 @@ def _build_run_context(message_components: list[object] | None = None): return ContextWrapper(context=ctx) +class _DoneRunner: + async def step_until_done(self, _max_step): + for item in (): + yield item + + def get_final_llm_resp(self): + return SimpleNamespace(role="assistant", completion_text="done") + + def test_build_handoff_toolset_keeps_permission_guards_for_default_tools(): mgr = FunctionToolManager() plugin_tool = FunctionTool( @@ -354,6 +365,71 @@ async def _fake_tool_loop_agent(**kwargs): assert captured["tool_call_timeout"] == 120 +@pytest.mark.asyncio +async def test_background_wakeup_passes_provider_settings_to_main_agent( + monkeypatch: pytest.MonkeyPatch, +): + provider_settings = { + "fallback_chat_models": ["fallback-provider"], + "request_max_retries": 3, + "stream": True, + } + captured: dict = {} + + async def _fake_get_session_conv(**_kwargs): + return SimpleNamespace(history="[]") + + async def _fake_build_main_agent(**kwargs): + captured.update(kwargs) + return SimpleNamespace(agent_runner=_DoneRunner()) + + monkeypatch.setattr( + "astrbot.core.astr_main_agent._get_session_conv", + _fake_get_session_conv, + ) + monkeypatch.setattr( + "astrbot.core.astr_main_agent.build_main_agent", + _fake_build_main_agent, + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.persist_agent_history", + AsyncMock(), + ) + + send_tool = FunctionTool( + name="send_message_to_user", + description="send", + parameters={"type": "object", "properties": {}}, + ) + context = SimpleNamespace( + get_config=lambda **_kwargs: {"provider_settings": provider_settings}, + get_llm_tool_manager=lambda: SimpleNamespace( + get_builtin_tool=lambda _tool_cls: send_tool + ), + conversation_manager=SimpleNamespace(), + ) + run_context = ContextWrapper( + context=SimpleNamespace(event=_DummyEvent([]), context=context), + tool_call_timeout=456, + ) + + await FunctionToolExecutor._wake_main_agent_for_background_result( + run_context, + task_id="task-id", + tool_name="long_tool", + result_text="ok", + tool_args={}, + note="task finished", + summary_name="BackgroundTask", + ) + + config = captured["config"] + assert config.tool_call_timeout == 456 + assert config.streaming_response == provider_settings["stream"] + assert config.provider_settings == provider_settings + assert config.provider_settings["fallback_chat_models"] == ["fallback-provider"] + + @pytest.mark.asyncio async def test_collect_handoff_image_urls_filters_extensionless_file_outside_temp_root( monkeypatch: pytest.MonkeyPatch,