diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 16ebac7a8b..f788d50997 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -362,10 +362,25 @@ def _get_workspace_path_for_umo(umo: str) -> Path: return Path(get_astrbot_workspaces_path()) / normalized_umo +def _is_group_session(event: AstrMessageEvent) -> bool: + """Return whether the event belongs to a group session. + + Args: + event: Message event to inspect. + + Returns: + True if the event has a group id. + """ + return bool(event.get_group_id()) + + def _apply_workspace_extra_prompt( event: AstrMessageEvent, req: ProviderRequest, ) -> None: + if _is_group_session(event): + return + extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / ( "EXTRA_PROMPT.md" ) @@ -498,13 +513,11 @@ async def _ensure_persona_and_skills( skill_manager = SkillManager() skills = skill_manager.list_skills(active_only=True, runtime=runtime) skills = _filter_skills_for_current_config(skills, cfg) - workspace_skills = ( - skill_manager.list_workspace_skills( + workspace_skills = [] + if runtime == "local" and not _is_group_session(event): + workspace_skills = skill_manager.list_workspace_skills( _get_workspace_path_for_umo(event.unified_msg_origin) ) - if runtime == "local" - else [] - ) if skills or workspace_skills: if persona and persona.get("skills") is not None: diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 3723ec0d49..13fb73595a 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -873,6 +873,70 @@ async def test_ensure_skills_includes_workspace_skills( in req.system_prompt ) + @pytest.mark.asyncio + async def test_ensure_skills_skips_workspace_skills_for_group_sessions( + self, + monkeypatch, + tmp_path, + mock_event, + mock_context, + ): + module = ama + data_dir = tmp_path / "data" + global_skills_dir = tmp_path / "global_skills" + plugins_dir = tmp_path / "plugins" + workspaces_dir = tmp_path / "workspaces" + for path in (data_dir, global_skills_dir, plugins_dir): + path.mkdir(parents=True, exist_ok=True) + + global_skill_dir = global_skills_dir / "workspace-skill" + global_skill_dir.mkdir(parents=True) + global_skill_dir.joinpath("SKILL.md").write_text( + "---\ndescription: Global scoped skill.\n---\n", + encoding="utf-8", + ) + + mock_event.get_group_id.return_value = "group123" + mock_event.message_obj.group_id = "group123" + mock_event.unified_msg_origin = "test_platform:GroupMessage:group123" + workspace_root = workspaces_dir / module.normalize_umo_for_workspace( + mock_event.unified_msg_origin + ) + workspace_skill_dir = workspace_root / "skills" / "workspace-skill" + workspace_skill_dir.mkdir(parents=True) + workspace_skill_dir.joinpath("SKILL.md").write_text( + "---\ndescription: Workspace scoped skill.\n---\n", + encoding="utf-8", + ) + + monkeypatch.setattr( + module, + "get_astrbot_workspaces_path", + lambda: str(workspaces_dir), + ) + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.get_astrbot_data_path", + lambda: str(data_dir), + ) + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.get_astrbot_skills_path", + lambda: str(global_skills_dir), + ) + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.get_astrbot_plugin_path", + lambda: str(plugins_dir), + ) + + req = ProviderRequest() + req.conversation = MagicMock(persona_id=None) + + await module._ensure_persona_and_skills( + req, {"computer_use_runtime": "local"}, mock_context, mock_event + ) + + assert "Global scoped skill." in req.system_prompt + assert "Workspace scoped skill." not in req.system_prompt + @pytest.mark.asyncio async def test_ensure_skills_respects_empty_persona_skills_for_workspace( self, @@ -1229,6 +1293,41 @@ async def test_decorate_llm_request_no_conversation(self, mock_event, mock_conte assert req.prompt == "Hello" + @pytest.mark.asyncio + async def test_decorate_llm_request_skips_workspace_extra_prompt_for_group( + self, + monkeypatch, + tmp_path, + mock_event, + mock_context, + sample_config, + ): + """Test group sessions do not load workspace extra prompts.""" + module = ama + workspaces_dir = tmp_path / "workspaces" + mock_event.get_group_id.return_value = "group123" + mock_event.message_obj.group_id = "group123" + mock_event.unified_msg_origin = "test_platform:GroupMessage:group123" + workspace_root = workspaces_dir / module.normalize_umo_for_workspace( + mock_event.unified_msg_origin + ) + workspace_root.mkdir(parents=True) + workspace_root.joinpath("EXTRA_PROMPT.md").write_text( + "Group workspace injected prompt.", + encoding="utf-8", + ) + monkeypatch.setattr( + module, + "get_astrbot_workspaces_path", + lambda: str(workspaces_dir), + ) + req = ProviderRequest(prompt="Hello", system_prompt="System") + req.conversation = None + + await module._decorate_llm_request(mock_event, req, mock_context, sample_config) + + assert req.system_prompt == "System" + class TestPluginToolFix: """Tests for _plugin_tool_fix function."""