Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,25 @@ def _get_workspace_path_for_umo(umo: str) -> Path:
return Path(get_astrbot_workspaces_path()) / normalized_umo


def _is_group_session(event: AstrMessageEvent) -> bool:
"""Return whether the event belongs to a group session.

Args:
event: Message event to inspect.

Returns:
True if the event has a group id.
"""
return bool(event.get_group_id())


def _apply_workspace_extra_prompt(
event: AstrMessageEvent,
req: ProviderRequest,
) -> None:
if _is_group_session(event):
return

extra_prompt_path = _get_workspace_path_for_umo(event.unified_msg_origin) / (
"EXTRA_PROMPT.md"
)
Expand Down Expand Up @@ -498,13 +513,11 @@ async def _ensure_persona_and_skills(
skill_manager = SkillManager()
skills = skill_manager.list_skills(active_only=True, runtime=runtime)
skills = _filter_skills_for_current_config(skills, cfg)
workspace_skills = (
skill_manager.list_workspace_skills(
workspace_skills = []
if runtime == "local" and not _is_group_session(event):
workspace_skills = skill_manager.list_workspace_skills(
_get_workspace_path_for_umo(event.unified_msg_origin)
)
if runtime == "local"
else []
)

if skills or workspace_skills:
if persona and persona.get("skills") is not None:
Expand Down
99 changes: 99 additions & 0 deletions tests/unit/test_astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,70 @@ async def test_ensure_skills_includes_workspace_skills(
in req.system_prompt
)

@pytest.mark.asyncio
async def test_ensure_skills_skips_workspace_skills_for_group_sessions(
self,
monkeypatch,
tmp_path,
mock_event,
mock_context,
):
module = ama
data_dir = tmp_path / "data"
global_skills_dir = tmp_path / "global_skills"
plugins_dir = tmp_path / "plugins"
workspaces_dir = tmp_path / "workspaces"
for path in (data_dir, global_skills_dir, plugins_dir):
path.mkdir(parents=True, exist_ok=True)

global_skill_dir = global_skills_dir / "workspace-skill"
global_skill_dir.mkdir(parents=True)
global_skill_dir.joinpath("SKILL.md").write_text(
"---\ndescription: Global scoped skill.\n---\n",
encoding="utf-8",
)

mock_event.get_group_id.return_value = "group123"
mock_event.message_obj.group_id = "group123"
mock_event.unified_msg_origin = "test_platform:GroupMessage:group123"
workspace_root = workspaces_dir / module.normalize_umo_for_workspace(
mock_event.unified_msg_origin
)
workspace_skill_dir = workspace_root / "skills" / "workspace-skill"
workspace_skill_dir.mkdir(parents=True)
workspace_skill_dir.joinpath("SKILL.md").write_text(
"---\ndescription: Workspace scoped skill.\n---\n",
encoding="utf-8",
)

monkeypatch.setattr(
module,
"get_astrbot_workspaces_path",
lambda: str(workspaces_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_skills_path",
lambda: str(global_skills_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_plugin_path",
lambda: str(plugins_dir),
)

req = ProviderRequest()
req.conversation = MagicMock(persona_id=None)

await module._ensure_persona_and_skills(
req, {"computer_use_runtime": "local"}, mock_context, mock_event
)

assert "Global scoped skill." in req.system_prompt
assert "Workspace scoped skill." not in req.system_prompt

@pytest.mark.asyncio
async def test_ensure_skills_respects_empty_persona_skills_for_workspace(
self,
Expand Down Expand Up @@ -1229,6 +1293,41 @@ async def test_decorate_llm_request_no_conversation(self, mock_event, mock_conte

assert req.prompt == "Hello"

@pytest.mark.asyncio
async def test_decorate_llm_request_skips_workspace_extra_prompt_for_group(
self,
monkeypatch,
tmp_path,
mock_event,
mock_context,
sample_config,
):
"""Test group sessions do not load workspace extra prompts."""
module = ama
workspaces_dir = tmp_path / "workspaces"
mock_event.get_group_id.return_value = "group123"
mock_event.message_obj.group_id = "group123"
mock_event.unified_msg_origin = "test_platform:GroupMessage:group123"
workspace_root = workspaces_dir / module.normalize_umo_for_workspace(
mock_event.unified_msg_origin
)
workspace_root.mkdir(parents=True)
workspace_root.joinpath("EXTRA_PROMPT.md").write_text(
"Group workspace injected prompt.",
encoding="utf-8",
)
monkeypatch.setattr(
module,
"get_astrbot_workspaces_path",
lambda: str(workspaces_dir),
)
req = ProviderRequest(prompt="Hello", system_prompt="System")
req.conversation = None

await module._decorate_llm_request(mock_event, req, mock_context, sample_config)

assert req.system_prompt == "System"


class TestPluginToolFix:
"""Tests for _plugin_tool_fix function."""
Expand Down
Loading