diff --git a/embodichain/gen_sim/action_agent_pipeline/__init__.py b/embodichain/gen_sim/action_agent_pipeline/__init__.py new file mode 100644 index 00000000..0517d273 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/__init__.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +"""Action-agent graph compilation and atomic-action runtime.""" + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/action_agent_pipeline/agents/__init__.py b/embodichain/gen_sim/action_agent_pipeline/agents/__init__.py new file mode 100644 index 00000000..4f45d84b --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/agents/__init__.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__ = [ + "agent_base", + "compile_agent", + "llm", + "task_agent", +] diff --git a/embodichain/gen_sim/action_agent_pipeline/agents/agent_base.py b/embodichain/gen_sim/action_agent_pipeline/agents/agent_base.py new file mode 100644 index 00000000..fc967f65 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/agents/agent_base.py @@ -0,0 +1,94 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from abc import ABCMeta +import os + +from embodichain.utils.utility import load_txt + + +def _resolve_prompt_path(file_name: str, config_dir: str | None = None) -> str: + # If absolute path, use directly + if os.path.isabs(file_name): + if os.path.exists(file_name): + return file_name + raise FileNotFoundError(f"Prompt file not found: {file_name}") + + # Try config directory first (for task-specific prompts) + if config_dir: + config_path = os.path.join(config_dir, file_name) + if os.path.exists(config_path): + return config_path + + # Try action_agent_pipeline/prompts directory for reusable prompts. + agents_prompts_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "prompts" + ) + agents_path = os.path.join(agents_prompts_dir, file_name) + if os.path.exists(agents_path): + return agents_path + + # If still not found, raise error with search paths + searched_paths = [] + if config_dir: + searched_paths.append(f" - {config_dir}/{file_name}") + searched_paths.append(f" - {agents_prompts_dir}/{file_name}") + + raise FileNotFoundError( + f"Prompt file not found: {file_name}\n" + f"Searched in:\n" + "\n".join(searched_paths) + ) + + +class AgentBase(metaclass=ABCMeta): + def __init__(self, **kwargs) -> None: + + assert ( + "prompt_kwargs" in kwargs.keys() + ), "Key prompt_kwargs must exist in config." + + for key, value in kwargs.items(): + setattr(self, key, value) + + # Get config directory if provided + config_dir = kwargs.get("config_dir", None) + if config_dir: + config_dir = os.path.dirname(os.path.abspath(config_dir)) + + # Preload and store prompt contents inside self.prompt_kwargs + for key, val in self.prompt_kwargs.items(): + if val["type"] == "text": + file_path = _resolve_prompt_path(val["name"], config_dir) + val["content"] = load_txt(file_path) + else: + raise ValueError( + f"Now only support `text` type but {val['type']} is given." + ) + + def generate(self, *args, **kwargs): + pass + + def act(self, *args, **kwargs): + pass + + def get_composed_observations(self, **kwargs): + ret = {} + for key, val in self.prompt_kwargs.items(): + ret[key] = val["content"] + ret.update(kwargs) + return ret diff --git a/embodichain/gen_sim/action_agent_pipeline/agents/compile_agent.py b/embodichain/gen_sim/action_agent_pipeline/agents/compile_agent.py new file mode 100644 index 00000000..bd16c6e6 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/agents/compile_agent.py @@ -0,0 +1,119 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import hashlib +import json +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.action_agent_pipeline.agents.agent_base import AgentBase +from embodichain.gen_sim.action_agent_pipeline.utils.llm_json import ( + extract_json_object, + normalize_json_content, +) +from embodichain.data import database_agent_prompt_dir + +__all__ = ["CompileAgent"] + +COMPILED_GRAPH_SCHEMA_VERSION = "nominal_graph_v1" + + +class CompileAgent(AgentBase): + """Compile and execute nominal atomic-action graph specs.""" + + query_prefix = "# " + query_suffix = "." + prompt_kwargs: dict[str, dict[str, Any]] + + def __init__(self, **kwargs) -> None: + for key, value in kwargs.items(): + setattr(self, key, value) + self.prompt_kwargs = kwargs.get("prompt_kwargs", {}) + + def generate(self, **kwargs): + log_dir = kwargs.get( + "log_dir", Path(database_agent_prompt_dir) / self.task_name + ) + file_path = Path(log_dir) / "agent_compiled_graph.json" + task_graph = extract_json_object(kwargs["task_graph"]) + task_graph_hash = _stable_json_hash(task_graph) + + if not kwargs.get("regenerate", False) and file_path.exists(): + existing_bundle = extract_json_object(file_path.read_text(encoding="utf-8")) + metadata = existing_bundle.get("metadata", {}) + if ( + metadata.get("schema_version") == COMPILED_GRAPH_SCHEMA_VERSION + and metadata.get("task_graph_hash") == task_graph_hash + ): + print(f"Compiled graph artifact already exists at {file_path}.") + return file_path, kwargs, None + + content = normalize_json_content( + { + "task_graph": task_graph, + "metadata": { + "schema_version": COMPILED_GRAPH_SCHEMA_VERSION, + "task_graph_hash": task_graph_hash, + }, + } + ) + + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content, encoding="utf-8") + print(f"Compiled graph artifact saved to {file_path}") + return file_path, kwargs, content + + def act(self, graph_file_path, **kwargs): + graph_file_path = Path(graph_file_path) + if graph_file_path.suffix != ".json": + raise ValueError("CompileAgent executes compiled graph JSON artifacts.") + + from embodichain.gen_sim.action_agent_pipeline.runtime.graph_compiler import ( + compile_agent_graph_from_file, + ) + + runtime_kwargs = _runtime_kwargs(kwargs, getattr(self, "prompt_kwargs", {})) + graph = compile_agent_graph_from_file(graph_file_path) + result = graph.run(**runtime_kwargs) + print("Compiled agent graph executed successfully.") + return result + + def get_composed_observations(self, **kwargs): + return dict(kwargs) + + +def _stable_json_hash(content: dict[str, Any]) -> str: + payload = json.dumps( + content, ensure_ascii=False, sort_keys=True, separators=(",", ":") + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _runtime_kwargs( + kwargs: dict[str, Any], + prompt_kwargs: dict[str, dict[str, Any]], +) -> dict[str, Any]: + prompt_only_keys = set(prompt_kwargs) + prompt_only_keys.update( + { + "task_graph", + "observations", + "regenerate", + } + ) + return {key: value for key, value in kwargs.items() if key not in prompt_only_keys} diff --git a/embodichain/gen_sim/action_agent_pipeline/agents/llm.py b/embodichain/gen_sim/action_agent_pipeline/agents/llm.py new file mode 100644 index 00000000..b6c62485 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/agents/llm.py @@ -0,0 +1,71 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.action_agent_pipeline.utils.mllm import create_chat_openai + +__all__ = ["create_llm", "task_llm"] + + +# ------------------------------------------------------------------------------ +# LLM factory +# ------------------------------------------------------------------------------ + + +def create_llm(*, temperature=0.0, model=None, usage_stage=None): + return create_chat_openai( + temperature=temperature, + model=model, + usage_stage=usage_stage, + ) + + +# ------------------------------------------------------------------------------ +# LLM instances +# ------------------------------------------------------------------------------ + + +# Initialize LLM instances, but handle errors gracefully for documentation builds +def _create_llm_safe(*, temperature=0.0, model=None, usage_stage=None): + try: + return create_llm( + temperature=temperature, + model=model, + usage_stage=usage_stage, + ) + except Exception: + return None + + +task_llm = _create_llm_safe( + temperature=0.0, + usage_stage="action_agent.task_graph", +) + +if __name__ == "__main__": + + def call_llm(prompt, temperature=0.0, model=None): + llm = create_llm( + temperature=temperature, + model=model, + usage_stage="action_agent.debug", + ) + response = llm.invoke(prompt) + return response.content + + response = call_llm(prompt="Which model you are?", temperature=0.0) + print(response) diff --git a/embodichain/gen_sim/action_agent_pipeline/agents/task_agent.py b/embodichain/gen_sim/action_agent_pipeline/agents/task_agent.py new file mode 100644 index 00000000..6efbdc32 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/agents/task_agent.py @@ -0,0 +1,72 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.action_agent_pipeline.agents.agent_base import AgentBase +from embodichain.gen_sim.action_agent_pipeline.utils.llm_json import ( + normalize_json_content, +) +from embodichain.gen_sim.action_agent_pipeline.prompts import TaskPrompt +from embodichain.data import database_agent_prompt_dir +from embodichain.utils.utility import load_txt + +__all__ = ["TaskAgent"] + + +class TaskAgent(AgentBase): + """Generate the nominal atomic-action task graph.""" + + prompt_name: str + prompt_kwargs: dict[str, dict[str, Any]] + + def __init__(self, llm, **kwargs) -> None: + super().__init__(**kwargs) + if llm is None: + raise ValueError( + "LLM is None. Configure the shared MLLM entry point " + "`embodichain.gen_sim.action_agent_pipeline.utils.mllm` with " + "OPENAI_API_KEY, optional " + "OPENAI_MODEL/OPENAI_BASE_URL, or the gen-sim LLM config." + ) + self.llm = llm + + def generate(self, **kwargs) -> str: + log_dir = kwargs.get( + "log_dir", Path(database_agent_prompt_dir) / self.task_name + ) + file_path = Path(log_dir) / "agent_task_graph.json" + + if not kwargs.get("regenerate", False) and file_path.exists(): + print(f"Task graph already exists at {file_path}.") + return load_txt(file_path) + + prompt = getattr(TaskPrompt, self.prompt_name)(**kwargs) + response = self.llm.invoke(prompt) + print(f"\033[92m\nTask agent output:\n{response.content}\n\033[0m") + + content = normalize_json_content(response.content) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content, encoding="utf-8") + print(f"Generated task graph saved to {file_path}") + + return content + + def act(self, *args, **kwargs): + return super().act(*args, **kwargs) diff --git a/embodichain/gen_sim/action_agent_pipeline/cli/__init__.py b/embodichain/gen_sim/action_agent_pipeline/cli/__init__.py new file mode 100644 index 00000000..015c4151 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/cli/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/action_agent_pipeline/cli/generate_ur5_basket_config.py b/embodichain/gen_sim/action_agent_pipeline/cli/generate_ur5_basket_config.py new file mode 100644 index 00000000..3a754e11 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/cli/generate_ur5_basket_config.py @@ -0,0 +1,246 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import argparse +from pathlib import Path + +from embodichain.gen_sim.action_agent_pipeline.generation.ur5_basket_config import ( + TargetReplacementSpec, + generate_ur5_basket_config_from_project, +) + +__all__ = ["cli"] + + +def cli() -> None: + parser = argparse.ArgumentParser( + description=( + "Generate a Dual-UR5 basket-placement action-agent config from an " + "exported tabletop gym project." + ) + ) + parser.add_argument( + "--gym_project", + type=str, + required=True, + help=( + "Path to a project root, formatted tabletop scene folder, or " + "gym_config.json/gym_config_merged.json. Directory inputs prefer " + "gym_config_merged.json when available." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Destination directory for generated agent configs.", + ) + parser.add_argument( + "--task_name", + type=str, + default="UR5BreadBasket", + help="Task name passed to run_agent.", + ) + parser.add_argument( + "--task_description", + type=str, + default=None, + help=( + "Simple natural-language relative-placement task. Providing this " + "uses the LLM to generate a constrained config-level prompt/spec." + ), + ) + parser.add_argument( + "--task_file", + type=str, + default=None, + help="Optional text file containing --task_description.", + ) + parser.add_argument( + "--use_llm_roles", + action="store_true", + default=False, + help=( + "Use the shared LLM only to refine object role mapping. The task " + "template and prompts remain deterministic." + ), + ) + parser.add_argument( + "--llm_model", + type=str, + default=None, + help="Optional LLM model override for --use_llm_roles.", + ) + parser.add_argument( + "--target_body_scale", + type=float, + default=0.7, + help=( + "Uniform body_scale for generated target objects. Basket-like " + "containers keep their source body_scale." + ), + ) + parser.add_argument( + "--target_replacement1", + "--target-replacement1", + nargs=2, + metavar=("SOURCE_UID", "PROMPT"), + default=None, + help=( + "Generate /mesh_assets/new1 from PROMPT and use it " + "to replace SOURCE_UID in the generated config." + ), + ) + parser.add_argument( + "--target_replacement2", + "--target-replacement2", + nargs=2, + metavar=("SOURCE_UID", "PROMPT"), + default=None, + help=( + "Generate /mesh_assets/new2 from PROMPT and use it " + "to replace SOURCE_UID in the generated config." + ), + ) + parser.add_argument( + "--sync_replacement_names", + "--sync-replacement-names", + action="store_true", + default=False, + help=( + "Also update replacement target runtime UIDs and generated prompts " + "from the replacement prompts." + ), + ) + parser.add_argument( + "--reuse_target_replacements", + "--reuse-target-replacements", + dest="reuse_target_replacements", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Reuse existing prompt-generated replacement GLBs when the prompt " + "and expected output name match. Defaults to true." + ), + ) + parser.add_argument( + "--prewarm_coacd_cache", + "--prewarm-coacd-cache", + dest="prewarm_coacd_cache", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Precompute environment CoACD cache files during config generation. " + "Defaults to true." + ), + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Overwrite generated files if they already exist.", + ) + parser.add_argument( + "--max_episodes", + type=int, + default=1, + help="max_episodes value written to fast_gym_config.json.", + ) + parser.add_argument( + "--max_episode_steps", + type=int, + default=1000, + help="max_episode_steps value written to fast_gym_config.json.", + ) + args = parser.parse_args() + task_description = _resolve_task_description(args) + target_replacements = _resolve_target_replacements(args) + + paths = generate_ur5_basket_config_from_project( + gym_project=args.gym_project, + output_dir=args.output_dir, + task_name=args.task_name, + task_description=task_description, + use_llm_roles=args.use_llm_roles, + llm_model=args.llm_model, + target_body_scale=args.target_body_scale, + target_replacements=target_replacements, + sync_replacement_names=args.sync_replacement_names, + reuse_target_replacements=args.reuse_target_replacements, + prewarm_coacd_cache=args.prewarm_coacd_cache, + overwrite=args.overwrite, + max_episodes=args.max_episodes, + max_episode_steps=args.max_episode_steps, + ) + + print(f"Generated gym config: {paths.gym_config}") + print(f"Generated agent config: {paths.agent_config}") + print(f"Generated task prompt: {paths.task_prompt}") + print(f"Generated basic background: {paths.basic_background}") + print(f"Generated atom actions: {paths.atom_actions}") + if paths.summary: + print("Generation summary:") + for key, value in paths.summary.items(): + print(f" {key}: {value}") + print( + "Run with:\n" + "python -m embodichain.gen_sim.action_agent_pipeline.cli.run_agent " + f"--task_name {args.task_name} " + f'--gym_config "{paths.gym_config}" ' + f'--agent_config "{paths.agent_config}" ' + "--regenerate" + ) + + +def _resolve_task_description(args: argparse.Namespace) -> str | None: + if args.task_description and args.task_file: + raise ValueError("Use either --task_description or --task_file, not both.") + if args.task_file: + return Path(args.task_file).expanduser().read_text(encoding="utf-8").strip() + if args.task_description: + return args.task_description.strip() + return None + + +def _resolve_target_replacements( + args: argparse.Namespace, +) -> list[TargetReplacementSpec]: + replacements = [] + if args.target_replacement1: + source_uid, prompt = args.target_replacement1 + replacements.append( + TargetReplacementSpec( + source_uid=source_uid, + prompt=prompt, + output_dir_name="new1", + ) + ) + if args.target_replacement2: + source_uid, prompt = args.target_replacement2 + replacements.append( + TargetReplacementSpec( + source_uid=source_uid, + prompt=prompt, + output_dir_name="new2", + ) + ) + return replacements + + +if __name__ == "__main__": + cli() diff --git a/embodichain/gen_sim/action_agent_pipeline/cli/pipeline_records.py b/embodichain/gen_sim/action_agent_pipeline/cli/pipeline_records.py new file mode 100644 index 00000000..79d5a189 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/cli/pipeline_records.py @@ -0,0 +1,383 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Pipeline history and manifest record helpers.""" + +from __future__ import annotations + +import argparse +from collections.abc import Sequence +from datetime import datetime +import hashlib +import json +from pathlib import Path +from typing import Any + +__all__ = [ + "append_pipeline_history", + "build_pipeline_record", + "find_history_entry_by_index", + "history_entry_has_source", + "history_entry_index", + "path_from_history_entry", + "pipeline_history_path", + "read_pipeline_history", + "resolve_record_path", + "resolve_source_gym_config", + "write_pipeline_manifests", +] + + +def pipeline_history_path(args: argparse.Namespace) -> Path: + return Path(args.pipeline_history_path).expanduser().resolve() + + +def read_pipeline_history( + history_path: Path, + *, + schema_version: int, +) -> dict[str, Any]: + if not history_path.exists(): + return {"schema_version": schema_version, "runs": []} + + data = json.loads(history_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise ValueError(f"Pipeline history must be a JSON object: {history_path}") + runs = data.get("runs") + if not isinstance(runs, list): + raise ValueError(f"Pipeline history must contain a runs list: {history_path}") + return { + "schema_version": data.get("schema_version", schema_version), + "runs": runs, + } + + +def find_history_entry_by_index( + runs: list[Any], history_index: int +) -> dict[str, Any] | None: + for entry in runs: + if isinstance(entry, dict) and history_entry_index(entry) == history_index: + return entry + return None + + +def history_entry_index(entry: dict[str, Any]) -> int: + try: + return int(entry.get("index", 0)) + except (TypeError, ValueError): + return 0 + + +def history_entry_has_source(entry: dict[str, Any]) -> bool: + return bool(entry.get("source_gym_config") or entry.get("source_gym_project_dir")) + + +def path_from_history_entry(entry: dict[str, Any], *, repo_root: Path) -> Path: + source = entry.get("source_gym_config") or entry.get("source_gym_project_dir") + if not source: + raise ValueError( + f"Pipeline history entry #{entry.get('index')} has no source gym path." + ) + path = resolve_record_path(str(source), repo_root=repo_root) + if not path.exists(): + raise FileNotFoundError(f"Pipeline history source path does not exist: {path}") + return path + + +def resolve_record_path(value: str | Path, *, repo_root: Path) -> Path: + path = Path(value).expanduser() + if path.is_absolute(): + return path.resolve() + return (repo_root / path).resolve() + + +def write_pipeline_manifests( + *, + args: argparse.Namespace, + resolution: Any, + generated_paths: Any, + target_replacements: Sequence[object], + repo_root: Path, + schema_version: int, + manifest_filename: str, +) -> dict[str, Any]: + history_path = pipeline_history_path(args) + record = build_pipeline_record( + args=args, + resolution=resolution, + generated_paths=generated_paths, + history_path=history_path, + target_replacements=target_replacements, + repo_root=repo_root, + schema_version=schema_version, + ) + record = append_pipeline_history( + history_path, + record, + schema_version=schema_version, + ) + + manifest_path = Path(generated_paths.output_dir) / manifest_filename + manifest_path.write_text( + json.dumps(record, ensure_ascii=False, indent=4) + "\n", + encoding="utf-8", + ) + print(f"Updated pipeline history: {history_path}", flush=True) + print(f"Wrote pipeline manifest: {manifest_path}", flush=True) + return record + + +def build_pipeline_record( + *, + args: argparse.Namespace, + resolution: Any, + generated_paths: Any, + history_path: Path, + target_replacements: Sequence[object], + repo_root: Path, + schema_version: int, +) -> dict[str, Any]: + source_gym_config = resolve_source_gym_config( + Path(resolution.path), + gym_config_preference=("gym_config_merged.json", "gym_config.json"), + ) + source_gym_project_dir = source_gym_config.parent + source_sha256 = _file_sha256(source_gym_config) + record: dict[str, Any] = { + "schema_version": schema_version, + "created_at": datetime.now().astimezone().isoformat(timespec="seconds"), + "task_name": args.task_name, + "source_mode": resolution.mode, + "source_id": f"gym_config_sha256:{source_sha256}", + "source_gym_config_sha256": source_sha256, + "path_base": "repo_root", + "source_gym_project_dir": _record_path(source_gym_project_dir, repo_root), + "source_gym_config": _record_path(source_gym_config, repo_root), + "input_path": _record_path(Path(resolution.path), repo_root), + "config_output_dir": _record_path(Path(generated_paths.output_dir), repo_root), + "generated_gym_config": _record_path( + Path(generated_paths.gym_config), + repo_root, + ), + "generated_agent_config": _record_path( + Path(generated_paths.agent_config), + repo_root, + ), + "generated_task_prompt": _record_path( + Path(generated_paths.task_prompt), + repo_root, + ), + "generated_basic_background": _record_path( + Path(generated_paths.basic_background), + repo_root, + ), + "generated_atom_actions": _record_path( + Path(generated_paths.atom_actions), + repo_root, + ), + "pipeline_history_path": _record_path(history_path, repo_root), + "target_body_scale": args.target_body_scale, + "target_replacements": _target_replacement_records( + args, + target_replacements, + ), + "sync_replacement_names": args.sync_replacement_names, + "reuse_target_replacements": args.reuse_target_replacements, + "prewarm_coacd_cache": args.prewarm_coacd_cache, + "overwrite_config": args.overwrite_config, + "regenerate": args.regenerate, + "skip_run_agent": args.skip_run_agent, + "generation_summary": generated_paths.summary, + } + if args.task_description: + record["task_description"] = args.task_description + record.update(_source_request_record(args, resolution, repo_root=repo_root)) + return record + + +def resolve_source_gym_config( + input_path: Path, + *, + gym_config_preference: Sequence[str], +) -> Path: + input_path = input_path.expanduser().resolve() + if input_path.is_file(): + if input_path.name not in gym_config_preference: + expected = ", ".join(gym_config_preference) + raise ValueError(f"Expected one of {expected}, got: {input_path}") + return input_path + + for filename in gym_config_preference: + path = input_path / filename + if path.is_file(): + return path.resolve() + + matches = [] + for filename in gym_config_preference: + matches.extend(sorted(input_path.rglob(filename))) + unique_matches = sorted({path.resolve() for path in matches}) + if len(unique_matches) == 1: + return unique_matches[0] + if not unique_matches: + expected = " or ".join(gym_config_preference) + raise FileNotFoundError(f"{expected} not found under: {input_path}") + match_text = ", ".join(path.as_posix() for path in unique_matches) + raise ValueError( + f"Multiple gym config files found under {input_path}: {match_text}" + ) + + +def append_pipeline_history( + history_path: Path, + record: dict[str, Any], + *, + schema_version: int, +) -> dict[str, Any]: + history = read_pipeline_history(history_path, schema_version=schema_version) + runs = history["runs"] + next_index = ( + max( + (history_entry_index(entry) for entry in runs if isinstance(entry, dict)), + default=0, + ) + + 1 + ) + record = dict(record) + record["index"] = next_index + + runs.append(record) + history["schema_version"] = schema_version + history_path.parent.mkdir(parents=True, exist_ok=True) + history_path.write_text( + json.dumps(history, ensure_ascii=False, indent=4) + "\n", + encoding="utf-8", + ) + return record + + +def _source_request_record( + args: argparse.Namespace, + resolution: Any, + *, + repo_root: Path, +) -> dict[str, Any]: + record: dict[str, Any] = {} + if args.image_name: + record["image_name"] = args.image_name + if args.image: + record["image"] = _record_path(Path(args.image).expanduser(), repo_root) + if args.use_image2scene: + record.update( + { + "server": args.server, + "background": args.background, + "image2scene_root": _record_path( + Path(args.image2scene_root).expanduser(), + repo_root, + ), + "image2scene_download_dir": str(args.image2scene_download_dir), + "image2scene_output_root": str(args.image2scene_output_root), + "image2scene_gen_config": str(args.image2scene_gen_config), + "image2scene_llm_config": str(args.image2scene_llm_config), + } + ) + if args.image2scene_extract_dir is not None: + record["image2scene_extract_dir"] = str(args.image2scene_extract_dir) + if args.image2scene_merged_output is not None: + record["image2scene_merged_output"] = str(args.image2scene_merged_output) + elif resolution.mode == "image2tabletop": + record.update( + { + "server": args.server, + "gym_project_root": _record_path( + Path(args.gym_project_root).expanduser(), + repo_root, + ), + "overwrite_gym_project": args.overwrite_gym_project, + } + ) + elif resolution.mode == "existing_gym_project": + record["gym_project"] = _record_path( + Path(args.gym_project).expanduser(), + repo_root, + ) + elif resolution.mode == "history" and resolution.base_history is not None: + base_source_path = path_from_history_entry( + resolution.base_history, + repo_root=repo_root, + ) + record.update( + { + "base_task_name": args.base_task_name, + "base_history_index": resolution.base_history.get("index"), + "base_history_task_name": resolution.base_history.get("task_name"), + "base_history_source_id": resolution.base_history.get("source_id"), + "base_history_source_gym_config": _record_path( + base_source_path, + repo_root, + ), + } + ) + return record + + +def _target_replacement_records( + args: argparse.Namespace, + target_replacements: Sequence[object], +) -> list[dict[str, str]]: + requested_by_output_dir = { + output_dir_name: replacement[0] + for output_dir_name, replacement in ( + ("new1", args.target_replacement1), + ("new2", args.target_replacement2), + ) + if replacement and len(replacement) == 2 + } + records = [] + for replacement in target_replacements: + output_dir_name = str(getattr(replacement, "output_dir_name")) + source_uid = str(getattr(replacement, "source_uid")) + record = { + "source_uid": source_uid, + "prompt": str(getattr(replacement, "prompt")), + "output_dir_name": output_dir_name, + } + requested_source_uid = requested_by_output_dir.get(output_dir_name) + if requested_source_uid and requested_source_uid != source_uid: + record["requested_source_uid"] = requested_source_uid + records.append(record) + return records + + +def _record_path(path: Path, repo_root: Path) -> str: + path = path.expanduser() + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + else: + path = path.resolve() + repo_root = repo_root.expanduser().resolve() + try: + return path.relative_to(repo_root).as_posix() + except ValueError: + return path.as_posix() + + +def _file_sha256(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as file: + for chunk in iter(lambda: file.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() diff --git a/embodichain/gen_sim/action_agent_pipeline/cli/run_agent.py b/embodichain/gen_sim/action_agent_pipeline/cli/run_agent.py new file mode 100644 index 00000000..999996e4 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/cli/run_agent.py @@ -0,0 +1,155 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import argparse + +import gymnasium +import numpy as np +import torch +import tqdm + +from embodichain.gen_sim.action_agent_pipeline.env_adapters.tableware.agent_env import ( # noqa: F401 + AtomicActionsAgentEnv, +) +from embodichain.lab.gym.utils.gym_utils import ( + add_env_launcher_args_to_parser, + build_env_cfg_from_args, +) +from embodichain.utils.logger import log_error, log_info, log_warning +from embodichain.utils.utility import load_config + +__all__ = ["cli"] + + +def cli() -> None: + np.set_printoptions(5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + parser = argparse.ArgumentParser() + add_env_launcher_args_to_parser(parser) + parser.add_argument( + "--task_name", + type=str, + help="Name of the task.", + required=True, + ) + parser.add_argument( + "--agent_config", + type=str, + help="Path to the agent configuration file.", + required=True, + ) + parser.add_argument( + "--regenerate", + action="store_true", + help="Whether to regenerate code if already existed.", + default=False, + ) + + args = parser.parse_args() + + if args.num_envs != 1: + log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + raise SystemExit(1) + + env_cfg, gym_config, _ = build_env_cfg_from_args(args) + agent_config = load_config(args.agent_config) + + env = gymnasium.make( + id=gym_config["id"], + cfg=env_cfg, + agent_config=agent_config, + agent_config_path=args.agent_config, + task_name=args.task_name, + ) + + _run_action_agent(args, env, gym_config) + + if args.headless: + env.reset(options={"final": True}) + + +def _run_action_agent(args: argparse.Namespace, env: gymnasium.Env, gym_config: dict): + """Run action-agent graphs without relying on the shared run_env runner.""" + if getattr(args, "preview", False): + log_warning("Preview mode is handled by the shared runner and is skipped here.") + + log_info("Start action-agent data generation.", color="green") + for trajectory_idx in range(gym_config.get("max_episodes", 1)): + _generate_action_agent_trajectory( + args, + env, + trajectory_idx, + ) + _, _ = env.reset() + + +def _generate_action_agent_trajectory( + args: argparse.Namespace, + env: gymnasium.Env, + trajectory_idx: int, +) -> bool: + _, _ = env.reset() + action_list = env.get_wrapper_attr("create_demo_action_list")( + action_sentence=trajectory_idx, + save_path=getattr(args, "save_path", ""), + save_video=getattr(args, "save_video", False), + debug_mode=getattr(args, "debug_mode", False), + regenerate=getattr(args, "regenerate", False), + ) + if action_list is None or len(action_list) == 0: + log_warning("Action is invalid. Skip to next generation.") + return False + + if getattr(action_list, "already_executed", False): + log_info("Action list was already executed by the action-agent runtime.") + _log_task_success(env) + return True + + for action in tqdm.tqdm( + action_list, + desc=f"Executing action list #{trajectory_idx}", + unit="step", + ): + env.step(action) + _log_task_success(env) + return True + + +def _log_task_success(env: gymnasium.Env) -> bool | None: + try: + success_fn = ( + env.get_wrapper_attr("is_task_success") + if hasattr(env, "get_wrapper_attr") + else env.is_task_success + ) + success = success_fn() + except Exception as exc: + log_warning(f"Failed to evaluate task success after execution: {exc}") + return None + + if isinstance(success, torch.Tensor): + success_value = bool(success.detach().cpu().flatten().all().item()) + else: + success_value = bool(np.asarray(success).flatten().all()) + log_info(f"Task success after execution: {success_value}", color="green") + return success_value + + +if __name__ == "__main__": + cli() diff --git a/embodichain/gen_sim/action_agent_pipeline/cli/run_agent_pipeline.py b/embodichain/gen_sim/action_agent_pipeline/cli/run_agent_pipeline.py new file mode 100644 index 00000000..ac08b311 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/cli/run_agent_pipeline.py @@ -0,0 +1,1329 @@ +#!/usr/bin/env python3 +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Run the Image2Tabletop -> config generation -> action-agent pipeline.""" + +from __future__ import annotations + +import argparse +from collections.abc import Callable +from dataclasses import dataclass +from datetime import datetime +import json +import os +from pathlib import Path +import re +import shlex +import subprocess +import sys +from typing import Any + + +def _repo_root() -> Path: + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "setup.py").is_file() and (parent / "embodichain").is_dir(): + return parent + return Path.cwd().resolve() + + +__all__ = ["main"] + +_REPO_ROOT = _repo_root() +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from embodichain.gen_sim.action_agent_pipeline.cli.pipeline_records import ( + find_history_entry_by_index as _records_find_history_entry_by_index, + history_entry_has_source as _records_history_entry_has_source, + history_entry_index as _records_history_entry_index, + path_from_history_entry as _records_path_from_history_entry, + pipeline_history_path as _records_pipeline_history_path, + read_pipeline_history as _records_read_pipeline_history, + resolve_source_gym_config as _records_resolve_source_gym_config, + write_pipeline_manifests as _records_write_pipeline_manifests, +) + +_DEFAULT_SERVER = "http://192.168.3.23:4523" +_DEFAULT_GYM_PROJECT_ROOT = _REPO_ROOT / "gym_project" +_DEFAULT_ACTION_AGENT_WORKSPACE = _DEFAULT_GYM_PROJECT_ROOT / "action_agent_pipeline" +_DEFAULT_IMAGE = _DEFAULT_ACTION_AGENT_WORKSPACE / "images/demo1.jpg" +_DEFAULT_IMAGE_DIR = _DEFAULT_IMAGE.parent +_DEFAULT_EXISTING_GYM_PROJECT = _DEFAULT_GYM_PROJECT_ROOT / "1780562837_gym_project" +_DEFAULT_IMAGE2SCENE_ROOT = _REPO_ROOT / "gym_project/environment/image2tabletop" +_DEFAULT_IMAGE2SCENE_IMAGE = "scene_image/robotwin_example.png" +_DEFAULT_IMAGE2SCENE_DOWNLOAD_DIR = "./downloads" +_DEFAULT_IMAGE2SCENE_OUTPUT_ROOT = "./generated" +_DEFAULT_IMAGE2SCENE_CONFIG = "./gen_config.json" +_DEFAULT_CONFIG_OUTPUT_DIR = _DEFAULT_ACTION_AGENT_WORKSPACE / "configs/demo3_text" +_DEFAULT_PIPELINE_HISTORY = ( + _DEFAULT_ACTION_AGENT_WORKSPACE / "configs/pipeline_history.json" +) +_DEFAULT_TASK_NAME = "Demo3_Text" +_DEFAULT_TASK_TEMPLATE_NAMES = frozenset({"Demo1_Text"}) +_IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".webp", ".bmp") +_GYM_CONFIG_PREFERENCE = ("gym_config_merged.json", "gym_config.json") +_PIPELINE_HISTORY_SCHEMA_VERSION = 1 +_PIPELINE_MANIFEST_FILENAME = "pipeline_manifest.json" +_INDEXED_REPLACEMENT_ALIAS_RE = re.compile( + r"^(?P[A-Za-z][A-Za-z0-9 _-]*?)[ _-]?(?P[0-9]+)$" +) + + +@dataclass(frozen=True) +class ProjectResolution: + path: Path + mode: str + base_history: dict[str, Any] | None = None + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Generate a tabletop gym project from one image, generate action-agent " + "configs from that project, then run the generated task." + ) + ) + image_group = parser.add_mutually_exclusive_group() + image_group.add_argument( + "--image", + default=None, + help=( + f"Input image path. If omitted, defaults to {_DEFAULT_IMAGE.as_posix()} " + f"or {_DEFAULT_IMAGE2SCENE_IMAGE} with --use-image2scene." + ), + ) + image_group.add_argument( + "--image-name", + "--image_name", + dest="image_name", + default=None, + help=( + "Image file name under the default image directory. The suffix is " + 'optional, e.g. "demo6" resolves to demo6.jpg.' + ), + ) + parser.add_argument( + "--server", + default=_DEFAULT_SERVER, + help=f"Image2Tabletop API server. Defaults to {_DEFAULT_SERVER}", + ) + parser.add_argument( + "--use-image2scene", + action="store_true", + default=False, + help=( + "Use gym_project/environment/image2tabletop/demo_api/client/" + "image2scene_pipeline.py as the first stage and continue from its " + "gym_config_merged.json output." + ), + ) + parser.add_argument( + "--background", + default=None, + help=( + "Background description passed to image2scene_pipeline.py. Required " + "with --use-image2scene." + ), + ) + parser.add_argument( + "--image2scene-root", + default=str(_DEFAULT_IMAGE2SCENE_ROOT), + help=( + "Working directory for image2scene_pipeline.py. Defaults to " + f"{_DEFAULT_IMAGE2SCENE_ROOT.as_posix()}" + ), + ) + parser.add_argument( + "--image2scene-download-dir", + default=_DEFAULT_IMAGE2SCENE_DOWNLOAD_DIR, + help=( + "Download directory passed to image2scene_pipeline.py. Relative " + "paths are interpreted under --image2scene-root. Defaults to " + f"{_DEFAULT_IMAGE2SCENE_DOWNLOAD_DIR}." + ), + ) + parser.add_argument( + "--image2scene-output-root", + default=_DEFAULT_IMAGE2SCENE_OUTPUT_ROOT, + help=( + "Generated EC project directory passed to image2scene_pipeline.py. " + "Relative paths are interpreted under --image2scene-root. Defaults " + f"to {_DEFAULT_IMAGE2SCENE_OUTPUT_ROOT}." + ), + ) + parser.add_argument( + "--image2scene-gen-config", + default=_DEFAULT_IMAGE2SCENE_CONFIG, + help=( + "Generation config passed to image2scene_pipeline.py. Relative " + "paths are interpreted under --image2scene-root. Defaults to " + f"{_DEFAULT_IMAGE2SCENE_CONFIG}." + ), + ) + parser.add_argument( + "--image2scene-llm-config", + default=_DEFAULT_IMAGE2SCENE_CONFIG, + help=( + "LLM config passed to image2scene_pipeline.py. Relative paths are " + "interpreted under --image2scene-root. Defaults to " + f"{_DEFAULT_IMAGE2SCENE_CONFIG}." + ), + ) + parser.add_argument( + "--image2scene-extract-dir", + default=None, + help=( + "Optional extract directory passed to image2scene_pipeline.py. " + "Relative paths are interpreted under --image2scene-root." + ), + ) + parser.add_argument( + "--image2scene-merged-output", + default=None, + help=( + "Optional merged output path passed to image2scene_pipeline.py. " + "Relative paths are interpreted under --image2scene-root." + ), + ) + parser.add_argument( + "--gym-project-root", + default=str(_DEFAULT_GYM_PROJECT_ROOT), + help=( + "Directory where Image2Tabletop generated gym projects are written. " + f"Defaults to {_DEFAULT_GYM_PROJECT_ROOT.as_posix()}" + ), + ) + parser.add_argument( + "--use-existing-gym-project", + action="store_true", + default=False, + help=( + "Skip Image2Tabletop API and start from --gym-project. Defaults to " + "false." + ), + ) + parser.add_argument( + "--base-task-name", + "--base_task_name", + dest="base_task_name", + default=None, + help=( + "Start from the latest pipeline history entry with this task name. " + "Use this to chain demos, e.g. demo2 based on Demo1_Text." + ), + ) + parser.add_argument( + "--base-history-index", + "--base_history_index", + dest="base_history_index", + type=int, + default=None, + help=( + "Start from a specific pipeline history index. When used with " + "--base-task-name, the history entry must match that task name." + ), + ) + parser.add_argument( + "--gym-project", + "--gym_project", + dest="gym_project", + default=str(_DEFAULT_EXISTING_GYM_PROJECT), + help=( + "Existing gym project used with --use-existing-gym-project. " + f"Defaults to {_DEFAULT_EXISTING_GYM_PROJECT.as_posix()}" + ), + ) + parser.add_argument( + "--config-output-dir", + "--output_dir", + dest="config_output_dir", + default=str(_DEFAULT_CONFIG_OUTPUT_DIR), + help=( + "Destination directory for generated config files. Defaults to " + f"{_DEFAULT_CONFIG_OUTPUT_DIR.as_posix()}" + ), + ) + parser.add_argument( + "--pipeline-history-path", + "--pipeline_history_path", + dest="pipeline_history_path", + default=str(_DEFAULT_PIPELINE_HISTORY), + help=( + "Global pipeline history JSON path. Defaults to " + f"{_DEFAULT_PIPELINE_HISTORY.as_posix()}" + ), + ) + parser.add_argument( + "--task_name", + "--task-name", + dest="task_name", + default=_DEFAULT_TASK_NAME, + help=f"Task name passed to run_agent. Defaults to {_DEFAULT_TASK_NAME}", + ) + parser.add_argument( + "--task_description", + "--task-description", + dest="task_description", + default="", + help=( + 'Task description passed to config generation. Defaults to "". ' + "Ignored for default-template tasks such as Demo1_Text." + ), + ) + parser.add_argument( + "--target_body_scale", + "--target-body-scale", + dest="target_body_scale", + type=float, + default=0.8, + help=( + "Uniform body_scale for generated target objects. Basket-like " + "containers keep their source body_scale. Defaults to 0.8." + ), + ) + parser.add_argument( + "--target_replacement1", + "--target-replacement1", + nargs="+", + metavar="SOURCE_OR_PROMPT", + default=None, + help=( + "Generate /mesh_assets/new1 from PROMPT. Accepts either " + "PROMPT, which auto-selects the lower-y duplicated rigid " + "object, or SOURCE_UID PROMPT for explicit selection." + ), + ) + parser.add_argument( + "--target_replacement2", + "--target-replacement2", + nargs="+", + metavar="SOURCE_OR_PROMPT", + default=None, + help=( + "Generate /mesh_assets/new2 from PROMPT. Accepts either " + "PROMPT, which auto-selects the higher-y duplicated rigid " + "object, or SOURCE_UID PROMPT for explicit selection." + ), + ) + parser.add_argument( + "--sync_replacement_names", + "--sync-replacement-names", + action="store_true", + default=False, + help=( + "Also update replacement target runtime UIDs and generated prompts " + "from the replacement prompts." + ), + ) + parser.add_argument( + "--reuse-target-replacements", + "--reuse_target_replacements", + dest="reuse_target_replacements", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Reuse existing prompt-generated replacement GLBs when the prompt " + "and expected output name match. Defaults to true." + ), + ) + parser.add_argument( + "--prewarm-coacd-cache", + "--prewarm_coacd_cache", + dest="prewarm_coacd_cache", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Precompute environment CoACD cache files during config generation. " + "Defaults to true." + ), + ) + parser.add_argument( + "--poll-interval", + type=float, + default=10.0, + help="Image2Tabletop job polling interval in seconds. Defaults to 10.0.", + ) + parser.add_argument( + "--skip-health-check", + action="store_true", + default=False, + help="Skip GET /health before submitting the image.", + ) + parser.add_argument( + "--overwrite-gym-project", + action="store_true", + default=False, + help="Replace an existing generated gym project with the same name.", + ) + parser.add_argument( + "--overwrite-config", + action=argparse.BooleanOptionalAction, + default=True, + help="Overwrite generated config files. Defaults to true.", + ) + parser.add_argument( + "--regenerate", + action=argparse.BooleanOptionalAction, + default=True, + help="Pass --regenerate to run_agent. Defaults to true.", + ) + parser.add_argument( + "--skip-run-agent", + action="store_true", + default=False, + help="Stop after generating config files instead of launching run_agent.", + ) + parser.add_argument( + "--llm-usage-output", + default=None, + help=( + "JSONL path for local LLM token usage records. Defaults to " + "/llm_usage.jsonl." + ), + ) + parser.add_argument( + "--llm-usage-summary-output", + default=None, + help=( + "JSON path for the aggregated local LLM token usage summary. " + "Defaults to /llm_usage_summary.json." + ), + ) + parser.add_argument( + "--llm-usage-run-id", + default=None, + help="Optional run id written into local LLM token usage records.", + ) + parser.add_argument( + "--no-llm-usage", + dest="llm_usage", + action="store_false", + default=True, + help="Disable local LLM token usage recording for this pipeline run.", + ) + return parser + + +def _ensure_repo_on_pythonpath() -> None: + if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + + +def _resolve_single_image( + image_input: str, + collect_image_paths: Callable[[Path], list[Path]], +) -> Path: + image_paths = collect_image_paths(Path(image_input)) + if len(image_paths) != 1: + paths = ", ".join(path.as_posix() for path in image_paths) + raise ValueError( + "This pipeline expects exactly one image, but got " + f"{len(image_paths)}: {paths}" + ) + return image_paths[0] + + +def _resolve_image_input(args: argparse.Namespace) -> Path: + if args.image_name: + return _resolve_image_name(args.image_name) + if args.image: + return Path(args.image) + return _DEFAULT_IMAGE + + +def _resolve_image_name(image_name: str) -> Path: + image_path = Path(image_name) + if image_path.parent != Path("."): + raise ValueError( + "--image-name only accepts a file name under " + f"{_DEFAULT_IMAGE_DIR.as_posix()}. Use --image for a full path." + ) + if image_path.suffix: + return _DEFAULT_IMAGE_DIR / image_path + + matches = [ + _DEFAULT_IMAGE_DIR / f"{image_name}{suffix}" for suffix in _IMAGE_SUFFIXES + ] + existing = [path for path in matches if path.exists()] + if len(existing) == 1: + return existing[0] + if not existing: + candidates = ", ".join(path.name for path in matches) + raise FileNotFoundError( + f"Image name {image_name!r} was not found. Tried: {candidates}" + ) + + matched = ", ".join(path.name for path in existing) + raise ValueError( + f"Image name {image_name!r} is ambiguous. Use --image-name with a suffix: " + f"{matched}" + ) + + +def _resolve_under_root(root: Path, path_input: str | None) -> Path | None: + if path_input is None: + return None + path = Path(path_input).expanduser() + if path.is_absolute(): + return path.resolve() + return (root / path).resolve() + + +def _image2scene_subprocess_env() -> dict[str, str]: + from embodichain.gen_sim.action_agent_pipeline.utils.llm_config import ( + get_openai_compatible_llm_config, + ) + from embodichain.gen_sim.action_agent_pipeline.utils.llm_usage import ( + scrub_usage_tracking_env, + ) + + env = scrub_usage_tracking_env() + cfg = get_openai_compatible_llm_config( + required=False, + require_base_url=False, + ) + env_overrides = { + "OPENAI_API_KEY": cfg.get("api_key"), + "OPENAI_MODEL": cfg.get("model"), + "OPENAI_BASE_URL": cfg.get("base_url"), + "EMBODICHAIN_LLM_PROXY": cfg.get("proxy_url"), + } + for name, value in env_overrides.items(): + if value: + env[name] = str(value) + + if cfg.get("model") or cfg.get("base_url"): + print( + "Using shared LLM config for image2scene subprocess: " + f"model={cfg.get('model')!r}, base_url={cfg.get('base_url')!r}", + flush=True, + ) + return env + + +def _resolve_task_description_for_generation(args: argparse.Namespace) -> str | None: + task_description = str(args.task_description or "").strip() + if args.task_name in _DEFAULT_TASK_TEMPLATE_NAMES: + if task_description: + print( + f"Ignoring --task_description for {args.task_name}; " + "using the default basket task template.", + flush=True, + ) + return None + return task_description or None + + +def _collect_merged_gym_configs(download_dir: Path) -> list[Path]: + if not download_dir.exists(): + return [] + return sorted( + path.resolve() for path in download_dir.rglob("gym_config_merged.json") + ) + + +def _latest_path(paths: list[Path]) -> Path: + return max(paths, key=lambda path: path.stat().st_mtime) + + +def _resolve_image2scene_image( + args: argparse.Namespace, image2scene_root: Path +) -> Path: + if args.image_name: + image_name = Path(args.image_name) + if image_name.parent != Path("."): + raise ValueError( + "--image-name only accepts a file name under " + f"{_DEFAULT_IMAGE_DIR.as_posix()} with " + "--use-image2scene. Use --image for a full path." + ) + if image_name.suffix: + return (_DEFAULT_IMAGE_DIR / image_name).resolve() + + matches = [ + _DEFAULT_IMAGE_DIR / f"{args.image_name}{suffix}" + for suffix in _IMAGE_SUFFIXES + ] + existing = [path.resolve() for path in matches if path.exists()] + if len(existing) == 1: + return existing[0] + if not existing: + candidates = ", ".join(path.name for path in matches) + raise FileNotFoundError( + f"Image name {args.image_name!r} was not found. Tried: {candidates}" + ) + + matched = ", ".join(path.name for path in existing) + raise ValueError( + f"Image name {args.image_name!r} is ambiguous. Use --image-name " + f"with a suffix: {matched}" + ) + + image_input = args.image or _DEFAULT_IMAGE2SCENE_IMAGE + image_path = Path(image_input).expanduser() + if image_path.is_absolute(): + return image_path.resolve() + return (image2scene_root / image_path).resolve() + + +def _run_image2scene_pipeline(args: argparse.Namespace) -> Path: + if not args.background: + raise ValueError("--background is required with --use-image2scene.") + + image2scene_root = Path(args.image2scene_root).expanduser().resolve() + if not image2scene_root.is_dir(): + raise FileNotFoundError(f"image2scene root not found: {image2scene_root}") + + script_path = image2scene_root / "demo_api/client/image2scene_pipeline.py" + if not script_path.is_file(): + raise FileNotFoundError(f"image2scene pipeline not found: {script_path}") + + image_path = _resolve_image2scene_image(args, image2scene_root) + download_dir = _resolve_under_root(image2scene_root, args.image2scene_download_dir) + output_root = _resolve_under_root(image2scene_root, args.image2scene_output_root) + gen_config = _resolve_under_root(image2scene_root, args.image2scene_gen_config) + llm_config = _resolve_under_root(image2scene_root, args.image2scene_llm_config) + extract_dir = _resolve_under_root(image2scene_root, args.image2scene_extract_dir) + merged_output = _resolve_under_root( + image2scene_root, args.image2scene_merged_output + ) + + if ( + download_dir is None + or output_root is None + or gen_config is None + or llm_config is None + ): + raise ValueError("image2scene paths must not be empty.") + + before_configs = set(_collect_merged_gym_configs(download_dir)) + command = [ + sys.executable, + str(script_path), + "--server", + args.server, + "--image", + str(image_path), + "--download-dir", + str(download_dir), + "--background", + args.background, + "--output-root", + str(output_root), + "--gen-config", + str(gen_config), + "--llm-config", + str(llm_config), + "--poll-interval", + str(args.poll_interval), + ] + if extract_dir is not None: + command.extend(["--extract-dir", str(extract_dir)]) + if merged_output is not None: + command.extend(["--merged-output", str(merged_output)]) + + print("Running image2scene pipeline:") + print(shlex.join(command), flush=True) + completed = subprocess.run( + command, + cwd=image2scene_root, + check=False, + env=_image2scene_subprocess_env(), + ) + if completed.returncode != 0: + raise RuntimeError( + f"image2scene pipeline failed with exit code {completed.returncode}" + ) + + if merged_output is not None: + if not merged_output.is_file(): + raise FileNotFoundError( + f"image2scene merged output not found: {merged_output}" + ) + print(f"Using image2scene merged gym config: {merged_output}", flush=True) + return merged_output + + after_configs = _collect_merged_gym_configs(download_dir) + new_configs = [path for path in after_configs if path not in before_configs] + if new_configs: + merged_config = _latest_path(new_configs) + elif after_configs: + merged_config = _latest_path(after_configs) + else: + raise FileNotFoundError( + f"gym_config_merged.json not found under: {download_dir}" + ) + + print(f"Using image2scene merged gym config: {merged_config}", flush=True) + return merged_config + + +def _resolve_gym_project(args: argparse.Namespace) -> ProjectResolution: + use_history = args.base_task_name is not None or args.base_history_index is not None + selected_modes = [ + args.use_image2scene, + args.use_existing_gym_project, + use_history, + ] + if sum(bool(mode) for mode in selected_modes) > 1: + raise ValueError( + "Use only one of --use-image2scene, --use-existing-gym-project, " + "or --base-task-name/--base-history-index." + ) + + if args.use_existing_gym_project: + project_path = Path(args.gym_project).expanduser().resolve() + if not project_path.exists(): + raise FileNotFoundError(f"gym project not found: {project_path}") + print(f"Using existing gym project: {project_path}", flush=True) + return ProjectResolution(path=project_path, mode="existing_gym_project") + + if args.use_image2scene: + return ProjectResolution( + path=_run_image2scene_pipeline(args), mode="image2scene" + ) + + if use_history: + history_entry = _resolve_base_history_entry(args) + project_path = _path_from_history_entry(history_entry) + print( + "Using base history " + f"#{history_entry.get('index')} ({history_entry.get('task_name')}): " + f"{project_path}", + flush=True, + ) + return ProjectResolution( + path=project_path, + mode="history", + base_history=history_entry, + ) + + from embodichain.gen_sim.action_agent_pipeline.gym_project_api.image2tabletop_client import ( + check_health, + collect_image_paths, + process_image, + ) + + image_input = _resolve_image_input(args) + image_path = _resolve_single_image(str(image_input), collect_image_paths) + if not args.skip_health_check: + check_health(args.server) + + return ProjectResolution( + path=process_image( + server=args.server, + image_path=image_path, + output_root=Path(args.gym_project_root), + poll_interval=args.poll_interval, + overwrite=args.overwrite_gym_project, + ), + mode="image2tabletop", + ) + + +def _resolve_base_history_entry(args: argparse.Namespace) -> dict[str, Any]: + if args.base_history_index is not None and args.base_history_index <= 0: + raise ValueError("--base-history-index must be a positive integer.") + + history_path = _pipeline_history_path(args) + history = _read_pipeline_history(history_path) + runs = history["runs"] + + if args.base_history_index is not None: + entry = _find_history_entry_by_index(runs, args.base_history_index) + if entry is None: + raise ValueError( + f"Pipeline history index not found: {args.base_history_index}" + ) + if args.base_task_name and entry.get("task_name") != args.base_task_name: + raise ValueError( + "Pipeline history entry " + f"#{args.base_history_index} has task_name={entry.get('task_name')!r}, " + f"expected {args.base_task_name!r}." + ) + return dict(entry) + + if not args.base_task_name: + raise ValueError("--base-task-name is required without --base-history-index.") + + candidates = [ + entry + for entry in runs + if entry.get("task_name") == args.base_task_name + and _history_entry_has_source(entry) + ] + if not candidates: + raise ValueError( + "No pipeline history entry found for task_name=" + f"{args.base_task_name!r} in {history_path}" + ) + return dict(max(candidates, key=_history_entry_index)) + + +def _pipeline_history_path(args: argparse.Namespace) -> Path: + return _records_pipeline_history_path(args) + + +def _read_pipeline_history(history_path: Path) -> dict[str, Any]: + return _records_read_pipeline_history( + history_path, + schema_version=_PIPELINE_HISTORY_SCHEMA_VERSION, + ) + + +def _find_history_entry_by_index( + runs: list[Any], history_index: int +) -> dict[str, Any] | None: + return _records_find_history_entry_by_index(runs, history_index) + + +def _history_entry_index(entry: dict[str, Any]) -> int: + return _records_history_entry_index(entry) + + +def _history_entry_has_source(entry: dict[str, Any]) -> bool: + return _records_history_entry_has_source(entry) + + +def _path_from_history_entry(entry: dict[str, Any]) -> Path: + return _records_path_from_history_entry(entry, repo_root=_REPO_ROOT) + + +def _resolve_target_replacements( + args: argparse.Namespace, + target_replacement_spec_cls: Callable[..., object], + gym_project: Path, +) -> list[object]: + replacements = [] + alias_config = None + if args.target_replacement1: + alias_config = alias_config or _load_replacement_alias_config(gym_project) + source_uid, prompt = _resolve_target_replacement_arg( + args.target_replacement1, + alias_config, + option_name="--target_replacement1", + replacement_number=1, + ) + replacements.append( + target_replacement_spec_cls( + source_uid=source_uid, + prompt=prompt, + output_dir_name="new1", + ) + ) + if args.target_replacement2: + alias_config = alias_config or _load_replacement_alias_config(gym_project) + source_uid, prompt = _resolve_target_replacement_arg( + args.target_replacement2, + alias_config, + option_name="--target_replacement2", + replacement_number=2, + ) + replacements.append( + target_replacement_spec_cls( + source_uid=source_uid, + prompt=prompt, + output_dir_name="new2", + ) + ) + return replacements + + +def _resolve_target_replacement_arg( + values: list[str], + gym_config: dict[str, Any], + *, + option_name: str, + replacement_number: int, +) -> tuple[str, str]: + if len(values) == 1: + prompt = str(values[0]).strip() + if not prompt: + raise ValueError(f"{option_name} prompt must be non-empty.") + source_uid = _auto_replacement_source_uid( + gym_config, + replacement_number=replacement_number, + option_name=option_name, + ) + return source_uid, prompt + + if len(values) == 2: + source_uid, prompt = values + prompt = str(prompt).strip() + if not prompt: + raise ValueError(f"{option_name} prompt must be non-empty.") + source_uid = _resolve_replacement_source_uid( + source_uid, + gym_config, + option_name=option_name, + ) + return source_uid, prompt + + raise ValueError( + f"{option_name} expects either PROMPT or SOURCE_UID PROMPT, got " + f"{len(values)} values: {values!r}. Quote multi-word prompts." + ) + + +def _load_replacement_alias_config(gym_project: Path) -> dict[str, Any]: + config_path = _resolve_replacement_alias_gym_config(gym_project) + data = json.loads(config_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise ValueError(f"Gym config must be a JSON object: {config_path}") + return data + + +def _resolve_replacement_alias_gym_config(input_path: Path) -> Path: + input_path = input_path.expanduser().resolve() + if input_path.is_file(): + sibling_gym_config = input_path.parent / "gym_config.json" + if sibling_gym_config.is_file(): + return sibling_gym_config.resolve() + return _resolve_source_gym_config(input_path) + + direct_gym_config = input_path / "gym_config.json" + if direct_gym_config.is_file(): + return direct_gym_config.resolve() + + source_config = _resolve_source_gym_config(input_path) + sibling_gym_config = source_config.parent / "gym_config.json" + if sibling_gym_config.is_file(): + return sibling_gym_config.resolve() + return source_config + + +def _auto_replacement_source_uid( + gym_config: dict[str, Any], + *, + replacement_number: int, + option_name: str, +) -> str: + if replacement_number not in {1, 2}: + raise ValueError(f"Unsupported replacement number: {replacement_number}") + + duplicate_groups = _duplicated_numbered_rigid_object_groups(gym_config) + if len(duplicate_groups) != 1: + candidates = _format_duplicate_group_candidates(duplicate_groups) + raise ValueError( + f"{option_name} was given without an explicit source uid, so the " + "pipeline expected exactly one duplicated numbered rigid_object " + f"group in gym_config.json. Found {len(duplicate_groups)} group(s): " + f"{candidates}. Use SOURCE_UID PROMPT to disambiguate." + ) + + base_name, positioned_objects = duplicate_groups[0] + if len(positioned_objects) != 2: + candidates = _format_duplicate_group_candidates(duplicate_groups) + raise ValueError( + f"{option_name} auto-selection requires exactly two objects in the " + f"duplicated group {base_name!r}, found {len(positioned_objects)}: " + f"{candidates}. Use SOURCE_UID PROMPT to disambiguate." + ) + + if ( + abs(float(positioned_objects[0]["y"]) - float(positioned_objects[1]["y"])) + < 1e-9 + ): + candidates = _format_duplicate_group_candidates(duplicate_groups) + raise ValueError( + f"{option_name} auto-selection requires distinct y coordinates in " + f"duplicated group {base_name!r}: {candidates}. Use SOURCE_UID PROMPT " + "to disambiguate." + ) + + selected = positioned_objects[replacement_number - 1] + source_uid = selected["object"]["uid"] + print( + f"Resolved {option_name} auto source -> {source_uid!r} " + f"from duplicated rigid_object group {base_name!r} by y={selected['y']}", + flush=True, + ) + return source_uid + + +def _duplicated_numbered_rigid_object_groups( + gym_config: dict[str, Any], +) -> list[tuple[str, list[dict[str, Any]]]]: + grouped: dict[str, list[dict[str, Any]]] = {} + for obj in _rigid_objects(gym_config): + parsed = _parse_numbered_rigid_object_uid(obj["uid"]) + if parsed is None: + continue + base_name, number = parsed + grouped.setdefault(base_name, []).append( + { + "number": number, + "y": _rigid_object_y_coordinate(obj), + "object": obj, + } + ) + + duplicate_groups = [] + for base_name, entries in grouped.items(): + if len(entries) < 2: + continue + duplicate_groups.append( + ( + base_name, + sorted( + entries, + key=lambda entry: ( + float(entry["y"]), + str(entry["object"]["uid"]), + ), + ), + ) + ) + return sorted(duplicate_groups, key=lambda item: item[0]) + + +def _parse_numbered_rigid_object_uid(uid: str) -> tuple[str, int] | None: + match = re.match(r"^(?P.+?)[_-]?(?P[0-9]+)$", uid) + if match is None: + return None + base_name = match.group("base").strip("_-") + if not base_name: + return None + return base_name, int(match.group("number")) + + +def _rigid_object_y_coordinate(obj: dict[str, Any]) -> float: + init_pos = obj.get("init_pos") + if not isinstance(init_pos, (list, tuple)) or len(init_pos) < 2: + raise ValueError( + "Auto replacement source selection requires each duplicated " + f"rigid_object to define init_pos with a y value, got {obj.get('uid')!r}." + ) + try: + return float(init_pos[1]) + except (TypeError, ValueError) as exc: + raise ValueError( + "Auto replacement source selection requires numeric init_pos[1], " + f"got {obj.get('uid')!r}: {init_pos[1]!r}" + ) from exc + + +def _format_duplicate_group_candidates( + groups: list[tuple[str, list[dict[str, Any]]]], +) -> str: + if not groups: + return "" + parts = [] + for base_name, entries in groups: + values = ", ".join( + f"{entry['object']['uid']}#number={entry['number']},y={entry['y']}" + for entry in entries + ) + parts.append(f"{base_name}: {values}") + return "; ".join(parts) + + +def _resolve_replacement_source_uid( + source_input: str, + gym_config: dict[str, Any], + *, + option_name: str, +) -> str: + source_input = str(source_input).strip() + rigid_objects = _rigid_objects(gym_config) + by_uid = {obj["uid"]: obj for obj in rigid_objects} + if source_input in by_uid: + return source_input + + alias = _parse_indexed_replacement_alias(source_input) + if alias is None: + candidates = _format_rigid_object_candidates(rigid_objects) + raise ValueError( + f"{option_name} source {source_input!r} is neither a rigid object uid " + f"nor an indexed alias such as bread1. Rigid object candidates: " + f"{candidates}" + ) + + keyword, alias_index = alias + matches = [ + obj for obj in rigid_objects if _rigid_object_matches_keyword(obj, keyword) + ] + if alias_index > len(matches): + candidates = _format_rigid_object_candidates(matches or rigid_objects) + raise ValueError( + f"{option_name} alias {source_input!r} requested match #{alias_index} " + f"for keyword {keyword!r}, but only found {len(matches)} match(es). " + f"Candidates: {candidates}" + ) + + resolved_uid = matches[alias_index - 1]["uid"] + print( + f"Resolved {option_name} source alias {source_input!r} -> {resolved_uid!r}", + flush=True, + ) + return resolved_uid + + +def _rigid_objects(gym_config: dict[str, Any]) -> list[dict[str, Any]]: + value = gym_config.get("rigid_object", []) + if isinstance(value, dict): + value = [value] + if not isinstance(value, list): + raise ValueError("gym config rigid_object must be a list or object.") + + rigid_objects = [] + for obj in value: + if not isinstance(obj, dict): + continue + uid = str(obj.get("uid", "")).strip() + if not uid: + continue + copied = dict(obj) + copied["uid"] = uid + rigid_objects.append(copied) + if not rigid_objects: + raise ValueError("No rigid_object entries found in gym config.") + return rigid_objects + + +def _parse_indexed_replacement_alias(alias: str) -> tuple[str, int] | None: + match = _INDEXED_REPLACEMENT_ALIAS_RE.match(alias.strip()) + if match is None: + return None + keyword = match.group("keyword").strip(" _-") + index = int(match.group("index")) + if not keyword or index < 1: + return None + return keyword, index + + +def _rigid_object_matches_keyword(obj: dict[str, Any], keyword: str) -> bool: + keyword_tokens = _search_tokens(keyword) + if not keyword_tokens: + return False + object_tokens = set(_search_tokens(_rigid_object_search_text(obj))) + return all(token in object_tokens for token in keyword_tokens) + + +def _rigid_object_search_text(obj: dict[str, Any]) -> str: + values = [ + obj.get("uid", ""), + obj.get("source_uid", ""), + obj.get("category", ""), + obj.get("semantic_label", ""), + obj.get("name", ""), + obj.get("description", ""), + ] + shape = obj.get("shape", {}) + if isinstance(shape, dict): + values.extend( + [ + shape.get("fpath", ""), + shape.get("file_path", ""), + shape.get("category", ""), + ] + ) + return " ".join(str(value) for value in values if value) + + +def _search_tokens(value: str) -> list[str]: + return re.findall(r"[a-z0-9]+", str(value).lower()) + + +def _format_rigid_object_candidates(rigid_objects: list[dict[str, Any]]) -> str: + if not rigid_objects: + return "" + parts = [] + for obj in rigid_objects: + shape = obj.get("shape", {}) + fpath = shape.get("fpath", "") if isinstance(shape, dict) else "" + parts.append(f"{obj.get('uid')} ({fpath})") + return ", ".join(parts) + + +def _write_pipeline_manifests( + *, + args: argparse.Namespace, + resolution: ProjectResolution, + generated_paths: object, + target_replacements: list[object], +) -> dict[str, Any]: + return _records_write_pipeline_manifests( + args=args, + resolution=resolution, + generated_paths=generated_paths, + target_replacements=target_replacements, + repo_root=_REPO_ROOT, + schema_version=_PIPELINE_HISTORY_SCHEMA_VERSION, + manifest_filename=_PIPELINE_MANIFEST_FILENAME, + ) + + +def _resolve_source_gym_config(input_path: Path) -> Path: + return _records_resolve_source_gym_config( + input_path, + gym_config_preference=_GYM_CONFIG_PREFERENCE, + ) + + +def _configure_llm_usage_tracking( + args: argparse.Namespace, +) -> tuple[Path, Path] | None: + if not args.llm_usage: + from embodichain.gen_sim.action_agent_pipeline.utils.llm_usage import ( + disable_usage_tracking, + ) + + disable_usage_tracking() + return None + + from embodichain.gen_sim.action_agent_pipeline.utils.llm_usage import ( + configure_usage_tracking, + ) + + output_dir = Path(args.config_output_dir).expanduser().resolve() + usage_path = ( + Path(args.llm_usage_output).expanduser().resolve() + if args.llm_usage_output + else output_dir / "llm_usage.jsonl" + ) + summary_path = ( + Path(args.llm_usage_summary_output).expanduser().resolve() + if args.llm_usage_summary_output + else output_dir / "llm_usage_summary.json" + ) + run_id = args.llm_usage_run_id or ( + f"{args.task_name}_{datetime.now().astimezone().strftime('%Y%m%d_%H%M%S')}" + ) + configure_usage_tracking( + usage_path=usage_path, + run_id=run_id, + process_name="run_agent_pipeline", + reset=True, + ) + print(f"Recording local LLM token usage: {usage_path}", flush=True) + print(f"Local LLM token usage summary: {summary_path}", flush=True) + return usage_path, summary_path + + +def _write_llm_usage_summary(usage_paths: tuple[Path, Path] | None) -> None: + if usage_paths is None: + return + + from embodichain.gen_sim.action_agent_pipeline.utils.llm_usage import ( + write_usage_summary, + ) + + usage_path, summary_path = usage_paths + summary = write_usage_summary( + usage_path=usage_path, + summary_path=summary_path, + ) + total = summary["total"] + print( + "Local LLM token usage total: " + f"calls={total['calls']}, " + f"input={total['input_tokens']}, " + f"output={total['output_tokens']}, " + f"total={total['total_tokens']}", + flush=True, + ) + + +def _run_agent_command( + *, + task_name: str, + gym_config: Path, + agent_config: Path, + regenerate: bool, +) -> int: + command = [ + sys.executable, + "-m", + "embodichain.gen_sim.action_agent_pipeline.cli.run_agent", + "--task_name", + task_name, + "--gym_config", + str(gym_config), + "--agent_config", + str(agent_config), + ] + if regenerate: + command.append("--regenerate") + + env = os.environ.copy() + if env.get("EMBODICHAIN_LLM_USAGE_PATH"): + env["EMBODICHAIN_LLM_USAGE_PROCESS"] = "run_agent" + + print("Running task:") + print(shlex.join(command), flush=True) + return subprocess.run(command, check=False, env=env).returncode + + +def main() -> int: + args = _build_parser().parse_args() + + _ensure_repo_on_pythonpath() + from embodichain.gen_sim.action_agent_pipeline.generation.ur5_basket_config import ( + TargetReplacementSpec, + generate_ur5_basket_config_from_project, + ) + + resolution = _resolve_gym_project(args) + usage_paths = _configure_llm_usage_tracking(args) + target_replacements = _resolve_target_replacements( + args, + TargetReplacementSpec, + resolution.path, + ) + task_description = _resolve_task_description_for_generation(args) + args.task_description = task_description or "" + + paths = generate_ur5_basket_config_from_project( + gym_project=resolution.path, + output_dir=args.config_output_dir, + task_name=args.task_name, + task_description=task_description, + target_body_scale=args.target_body_scale, + target_replacements=target_replacements, + sync_replacement_names=args.sync_replacement_names, + reuse_target_replacements=args.reuse_target_replacements, + prewarm_coacd_cache=args.prewarm_coacd_cache, + overwrite=args.overwrite_config, + ) + _write_pipeline_manifests( + args=args, + resolution=resolution, + generated_paths=paths, + target_replacements=target_replacements, + ) + + print(f"Using gym project/config: {resolution.path}", flush=True) + print(f"Generated gym config: {paths.gym_config}", flush=True) + print(f"Generated agent config: {paths.agent_config}", flush=True) + if args.skip_run_agent: + _write_llm_usage_summary(usage_paths) + return 0 + + return_code = _run_agent_command( + task_name=args.task_name, + gym_config=paths.gym_config, + agent_config=paths.agent_config, + regenerate=args.regenerate, + ) + _write_llm_usage_summary(usage_paths) + return return_code + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/embodichain/gen_sim/action_agent_pipeline/env_adapters/__init__.py b/embodichain/gen_sim/action_agent_pipeline/env_adapters/__init__.py new file mode 100644 index 00000000..015c4151 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/env_adapters/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/__init__.py b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/__init__.py new file mode 100644 index 00000000..015c4151 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/agent_env.py b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/agent_env.py new file mode 100644 index 00000000..907af3c6 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/agent_env.py @@ -0,0 +1,54 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.gen_sim.action_agent_pipeline.env_adapters.tableware.base_agent_env import ( + BaseAgentEnv, +) +from embodichain.gen_sim.action_agent_pipeline.env_adapters.tableware.success import ( + evaluate_configured_success, +) +from embodichain.lab.gym.utils.registration import register_env + +__all__ = ["AtomicActionsAgentEnv"] + + +@register_env("AtomicActionsAgent-v3", max_episode_steps=600) +class AtomicActionsAgentEnv(BaseAgentEnv, EmbodiedEnv): + """Config-driven agent environment for atomic-action tasks.""" + + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + if bool(getattr(self, "ignore_terminations_during_agent", False)): + self.cfg.ignore_terminations = True + super()._init_agents(**kwargs) + + def reset(self, seed: int | None = None, options: dict | None = None): + obs, info = super().reset(seed=seed, options=options) + super().get_states() + return obs, info + + def is_task_success(self, **kwargs) -> torch.Tensor: + return evaluate_configured_success(self) + + def compute_task_state(self, **kwargs) -> tuple[torch.Tensor, torch.Tensor, dict]: + success = self.is_task_success() + fail = torch.zeros_like(success) + return success, fail, {} diff --git a/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/base_agent_env.py b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/base_agent_env.py new file mode 100644 index 00000000..b239834d --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/base_agent_env.py @@ -0,0 +1,335 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from copy import deepcopy + +import torch +from embodichain.utils import logger + +_TASK_PROMPT_KEYS = frozenset({"task_prompt", "basic_background", "atom_actions"}) + + +class BaseAgentEnv: + + def _init_agents(self, agent_config, task_name, agent_config_path=None): + from embodichain.gen_sim.action_agent_pipeline.agents.task_agent import ( + TaskAgent, + ) + from embodichain.gen_sim.action_agent_pipeline.agents.compile_agent import ( + CompileAgent, + ) + from embodichain.gen_sim.action_agent_pipeline.agents.llm import ( + task_llm, + ) + + task_agent_config = self._agent_config_with_prompt_keys( + agent_config["Agent"], + _TASK_PROMPT_KEYS, + ) + compile_agent_config = self._agent_config_with_prompt_keys( + agent_config["Agent"], + frozenset(), + ) + self.task_agent = TaskAgent( + task_llm, + **task_agent_config, + **agent_config["TaskAgent"], + task_name=task_name, + config_dir=agent_config_path, + ) + self.compile_agent = CompileAgent( + **compile_agent_config, + **agent_config["CompileAgent"], + task_name=task_name, + config_dir=agent_config_path, + ) + + def _agent_config_with_prompt_keys(self, agent_config, allowed_keys): + filtered = deepcopy(agent_config) + prompt_kwargs = filtered.get("prompt_kwargs", {}) or {} + filtered["prompt_kwargs"] = { + key: value for key, value in prompt_kwargs.items() if key in allowed_keys + } + return filtered + + def get_states(self): + # TODO: only support num_env = 1 for now + # store robot states in each env.reset + self.init_qpos = self.robot.get_qpos().squeeze(0) + + self._agent_arm_slots = self._resolve_agent_arm_slots() + for side in ("left", "right"): + self._initialize_agent_arm_slot(side, self._agent_arm_slots.get(side)) + + self.open_state = torch.as_tensor( + getattr( + self, + "agent_open_state", + getattr(self, "gripper_open_state", [0.05]), + ), + dtype=self.init_qpos.dtype, + device=self.init_qpos.device, + ).flatten() + self.close_state = torch.as_tensor( + getattr( + self, + "agent_close_state", + getattr(self, "gripper_close_state", [0.0]), + ), + dtype=self.init_qpos.dtype, + device=self.init_qpos.device, + ).flatten() + self.left_arm_current_gripper_state = self._initial_gripper_state("left") + self.right_arm_current_gripper_state = self._initial_gripper_state("right") + + self.update_obj_info() + + def _resolve_agent_arm_slots(self) -> dict[str, dict[str, str | None] | None]: + configured_slots = getattr(self, "agent_arm_slots", None) + if configured_slots is not None: + return self._normalize_agent_arm_slots(configured_slots) + + if hasattr(self, "single_arm_name") or hasattr(self, "single_eef_name"): + slot = getattr(self, "agent_single_arm_slot", "right") + return self._normalize_agent_arm_slots( + { + slot: { + "arm": getattr(self, "single_arm_name", "right_arm"), + "eef": getattr(self, "single_eef_name", "right_eef"), + } + } + ) + + control_parts = getattr(self.robot, "control_parts", {}) or {} + if "arm" in control_parts and "hand" in control_parts: + slot = getattr(self, "agent_single_arm_slot", "left") + return self._normalize_agent_arm_slots( + {slot: {"arm": "arm", "eef": "hand"}} + ) + + return self._normalize_agent_arm_slots( + { + "left": {"arm": "left_arm", "eef": "left_eef"}, + "right": {"arm": "right_arm", "eef": "right_eef"}, + } + ) + + def _normalize_agent_arm_slots( + self, slots + ) -> dict[str, dict[str, str | None] | None]: + normalized = {"left": None, "right": None} + for side in normalized: + slot_cfg = slots.get(side) if isinstance(slots, dict) else None + if slot_cfg is None: + continue + if isinstance(slot_cfg, str): + normalized[side] = {"arm": slot_cfg, "eef": None} + continue + normalized[side] = { + "arm": slot_cfg.get("arm", slot_cfg.get("arm_control_part")), + "eef": slot_cfg.get( + "eef", + slot_cfg.get("hand", slot_cfg.get("eef_control_part")), + ), + } + return normalized + + def _initialize_agent_arm_slot( + self, side: str, slot_cfg: dict[str, str | None] | None + ) -> None: + arm_name = slot_cfg.get("arm") if slot_cfg else None + eef_name = slot_cfg.get("eef") if slot_cfg else None + arm_joints = self._get_control_part_joint_ids(arm_name) + eef_joints = self._get_control_part_joint_ids(eef_name) + + setattr(self, f"{side}_arm_joints", arm_joints) + setattr(self, f"{side}_eef_joints", eef_joints) + + if arm_name is None or not arm_joints: + setattr(self, f"{side}_arm_init_qpos", self.init_qpos.new_empty(0)) + setattr(self, f"{side}_arm_init_xpos", None) + setattr(self, f"{side}_arm_base_pose", None) + setattr(self, f"{side}_arm_current_qpos", self.init_qpos.new_empty(0)) + setattr(self, f"{side}_arm_current_xpos", None) + return + + init_qpos = self.init_qpos[arm_joints] + init_xpos = self.robot.compute_fk( + init_qpos, name=arm_name, to_matrix=True + ).squeeze(0) + base_pose = self.robot.get_control_part_base_pose( + arm_name, to_matrix=True + ).squeeze(0) + + setattr(self, f"{side}_arm_init_qpos", init_qpos) + setattr(self, f"{side}_arm_init_xpos", init_xpos) + setattr(self, f"{side}_arm_base_pose", base_pose) + setattr(self, f"{side}_arm_current_qpos", init_qpos) + setattr(self, f"{side}_arm_current_xpos", init_xpos) + + def _get_control_part_joint_ids(self, control_part: str | None) -> list[int]: + if control_part is None: + return [] + if control_part not in (getattr(self.robot, "control_parts", {}) or {}): + return [] + return list(self.robot.get_joint_ids(name=control_part)) + + def _initial_gripper_state(self, side: str) -> torch.Tensor: + if len(getattr(self, f"{side}_eef_joints", []) or []) == 0: + return self.open_state.new_empty(0) + return self.open_state + + def update_obj_info(self): + # store some useful obj information + obj_info = getattr(self, "obj_info", {}) + obj_uids = self.sim.get_rigid_object_uid_list() + for obj_name in obj_uids: + obj = self.sim.get_rigid_object(obj_name) + obj_pose = obj.get_local_pose(to_matrix=True).squeeze(0) + + if obj_name not in obj_info: + obj_height = obj_pose[2, 3] # Extract the height (z-coordinate) + obj_info[obj_name] = { + "pose": obj_pose, # Store the full pose (4x4 matrix) + "height": obj_height, # Store the initial height (z-coordinate) + } + else: + obj_info[obj_name]["pose"] = obj_pose + + self.obj_info = obj_info + + # -------------------- Common getters / setters -------------------- + + def get_obs_for_agent(self): + obs = self.get_obs() + rgb = obs["sensor"]["cam_high"]["color"].squeeze(0) + + # Get validation camera data + camera_data = self.event_manager.get_functor("validation_cameras")(self, None) + result = {"rgb": rgb} + result.update({k: v.squeeze(0) for k, v in camera_data.items()}) + return result + + def get_current_qpos_agent(self): + return self.left_arm_current_qpos, self.right_arm_current_qpos + + def set_current_qpos_agent(self, arm_qpos, is_left): + if is_left: + self.left_arm_current_qpos = arm_qpos + else: + self.right_arm_current_qpos = arm_qpos + + def get_current_xpos_agent(self): + return self.left_arm_current_xpos, self.right_arm_current_xpos + + def set_current_xpos_agent(self, arm_xpos, is_left): + if is_left: + self.left_arm_current_xpos = arm_xpos + else: + self.right_arm_current_xpos = arm_xpos + + def get_current_gripper_state_agent(self): + return self.left_arm_current_gripper_state, self.right_arm_current_gripper_state + + def set_current_gripper_state_agent(self, arm_gripper_state, is_left): + if is_left: + self.left_arm_current_gripper_state = arm_gripper_state + else: + self.right_arm_current_gripper_state = arm_gripper_state + + # -------------------- IK / FK -------------------- + def get_arm_ik(self, target_xpos, is_left, qpos_seed=None): + control_part = self.get_agent_arm_control_part(is_left) + ret, qpos = self.robot.compute_ik( + name=control_part, pose=target_xpos, joint_seed=qpos_seed + ) + return ret.all().item(), qpos.squeeze(0) + + def get_arm_fk(self, qpos, is_left): + control_part = self.get_agent_arm_control_part(is_left) + xpos = self.robot.compute_fk( + name=control_part, qpos=torch.as_tensor(qpos), to_matrix=True + ) + return xpos.squeeze(0) + + def get_agent_arm_control_part(self, is_left: bool) -> str: + return self._get_agent_control_part(is_left=is_left, key="arm") + + def get_agent_eef_control_part(self, is_left: bool) -> str | None: + return self._get_agent_control_part(is_left=is_left, key="eef", required=False) + + def _get_agent_control_part( + self, is_left: bool, key: str, required: bool = True + ) -> str | None: + if not hasattr(self, "_agent_arm_slots"): + self._agent_arm_slots = self._resolve_agent_arm_slots() + side = "left" if is_left else "right" + slot_cfg = getattr(self, "_agent_arm_slots", {}).get(side) + control_part = slot_cfg.get(key) if slot_cfg else None + if control_part is None and required: + logger.log_error( + f"{side}_{key} is not configured for agent control.", + error_type=ValueError, + ) + return control_part + + # -------------------- get compiled graph for action list -------------------- + def generate_graph_for_actions(self, regenerate=False, **kwargs): + logger.log_info( + "Generate graph for creating action list for " + f"{self.compile_agent.task_name}.", + color="green", + ) + + print(f"\033[92m\nStart task graph generation.\n\033[0m") + task_agent_input = self.task_agent.get_composed_observations( + env=self, + regenerate=regenerate, + observations=self.get_obs_for_agent(), + **kwargs, + ) + task_graph = self.task_agent.generate(**task_agent_input) + + print(f"\033[94m\nStart graph compilation.\n\033[0m") + compile_agent_input = self.compile_agent.get_composed_observations( + env=self, + regenerate=regenerate, + task_graph=task_graph, + **kwargs, + ) + graph_file_path, kwargs, graph_content = self.compile_agent.generate( + **compile_agent_input + ) + + return graph_file_path, kwargs, graph_content + + # -------------------- get action list -------------------- + def create_demo_action_list(self, regenerate=False, *args, **kwargs): + graph_file_path, compile_kwargs, _ = self.generate_graph_for_actions( + regenerate=regenerate + ) + atomic_action_kwargs = { + "allow_grasp_annotation": True, + "force_grasp_reannotate": False, + } + for key in atomic_action_kwargs: + if key in kwargs: + atomic_action_kwargs[key] = kwargs[key] + compile_kwargs.update(atomic_action_kwargs) + action_list = self.compile_agent.act(graph_file_path, **compile_kwargs) + return action_list diff --git a/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/success.py b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/success.py new file mode 100644 index 00000000..23de84b9 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/env_adapters/tableware/success.py @@ -0,0 +1,237 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any + +import torch + +__all__ = ["evaluate_configured_success"] + + +def evaluate_configured_success( + env, + spec: Mapping[str, Any] | None = None, +) -> torch.Tensor: + """Evaluate action-agent task success predicates from env config.""" + success_spec = spec or getattr(env, "agent_success", None) + if success_spec is None: + return _constant(env, False) + return _evaluate_spec(env, success_spec) + + +def _evaluate_spec( + env, + spec: Mapping[str, Any] | Sequence[Mapping[str, Any]], +) -> torch.Tensor: + if isinstance(spec, Sequence) and not isinstance(spec, (str, bytes, Mapping)): + return _evaluate_all(env, spec) + if not isinstance(spec, Mapping): + raise TypeError(f"Success spec must be a mapping, got {type(spec)}.") + + op = str(spec.get("op", "")).lower() + if not op and "terms" in spec and "type" not in spec and "func" not in spec: + op = "all" + if op in {"all", "and"}: + return _evaluate_all(env, spec.get("terms", [])) + if op in {"any", "or"}: + return _evaluate_any(env, spec.get("terms", [])) + if op == "not": + term = spec.get("term") + terms = spec.get("terms") + if term is None and isinstance(terms, Sequence) and len(terms) == 1: + term = terms[0] + if term is None: + raise ValueError("Success op 'not' requires exactly one term.") + return ~_evaluate_spec(env, term) + + term_type = str(spec.get("type", spec.get("func", ""))).lower() + if term_type in {"object_position_near", "object_near_position"}: + return _object_position_near(env, spec) + if term_type in {"object_xy_near", "object_near_xy"}: + return _object_xy_near(env, spec) + if term_type == "object_in_container": + return _object_in_container(env, spec) + if term_type in {"object_on_object", "object_on", "on_object"}: + return _object_on_object(env, spec) + if term_type in {"object_not_fallen", "not_fallen"}: + return _object_not_fallen(env, spec) + if term_type in {"object_axis_offset_near", "object_relative_axis_near"}: + return _object_axis_offset_near(env, spec) + if term_type in {"object_axis_near", "object_coordinate_near"}: + return _object_axis_near(env, spec) + if term_type in {"object_lifted", "object_height_above_initial"}: + return _object_lifted(env, spec) + raise ValueError(f"Unsupported success term type: {term_type!r}.") + + +def _evaluate_all(env, terms: Sequence[Mapping[str, Any]]) -> torch.Tensor: + success = _constant(env, True) + for term in terms: + success = success & _evaluate_spec(env, term) + return success + + +def _evaluate_any(env, terms: Sequence[Mapping[str, Any]]) -> torch.Tensor: + success = _constant(env, False) + for term in terms: + success = success | _evaluate_spec(env, term) + return success + + +def _constant(env, value: bool) -> torch.Tensor: + return torch.full((env.num_envs,), value, dtype=torch.bool, device=env.device) + + +def _pose(env, uid: str) -> torch.Tensor: + return env.sim.get_rigid_object(uid).get_local_pose(to_matrix=True) + + +def _position(env, uid: str) -> torch.Tensor: + return _pose(env, uid)[:, :3, 3] + + +def _tensor(value: Any, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + return torch.as_tensor(value, dtype=dtype, device=device) + + +def _object_name(spec: Mapping[str, Any]) -> str: + return str(spec.get("object", spec.get("object_uid"))) + + +def _object_position_near(env, spec: Mapping[str, Any]) -> torch.Tensor: + position = _position(env, _object_name(spec)) + target = _tensor( + spec.get("target_position", spec.get("position", spec.get("target"))), + dtype=position.dtype, + device=position.device, + ).flatten() + if target.numel() == 2: + return _object_xy_near(env, {**spec, "target_xy": target}) + target = target.reshape(1, 3) + return torch.linalg.norm(position - target, dim=-1) <= float( + spec.get("tolerance", 0.05) + ) + + +def _object_xy_near(env, spec: Mapping[str, Any]) -> torch.Tensor: + position = _position(env, _object_name(spec)) + target_xy = _tensor( + spec.get("target_xy", spec.get("xy", spec.get("target"))), + dtype=position.dtype, + device=position.device, + ).flatten()[:2] + tolerance = float(spec.get("tolerance", spec.get("xy_tolerance", 0.05))) + return ( + torch.linalg.norm(position[:, :2] - target_xy.reshape(1, 2), dim=-1) + <= tolerance + ) + + +def _object_in_container(env, spec: Mapping[str, Any]) -> torch.Tensor: + object_position = _position(env, _object_name(spec)) + container_position = _position( + env, + str(spec.get("container", spec.get("container_uid"))), + ) + xy_distance = torch.linalg.norm( + object_position[:, :2] - container_position[:, :2], + dim=-1, + ) + z_offset = object_position[:, 2] - container_position[:, 2] + return ( + (xy_distance <= float(spec.get("xy_radius", spec.get("radius", 0.1)))) + & (z_offset >= float(spec.get("min_z_offset", -0.03))) + & (z_offset <= float(spec.get("max_z_offset", 0.25))) + ) + + +def _object_on_object(env, spec: Mapping[str, Any]) -> torch.Tensor: + object_position = _position(env, _object_name(spec)) + support_position = _position( + env, + str( + spec.get( + "support", + spec.get("support_uid", spec.get("reference", spec.get("container"))), + ) + ), + ) + xy_distance = torch.linalg.norm( + object_position[:, :2] - support_position[:, :2], + dim=-1, + ) + z_offset = object_position[:, 2] - support_position[:, 2] + return ( + (xy_distance <= float(spec.get("xy_radius", spec.get("radius", 0.08)))) + & (z_offset >= float(spec.get("min_z_offset", 0.02))) + & (z_offset <= float(spec.get("max_z_offset", 0.35))) + ) + + +def _object_not_fallen(env, spec: Mapping[str, Any]) -> torch.Tensor: + pose = _pose(env, _object_name(spec)) + pose_z_axis = pose[:, :3, 2] + world_z_axis = torch.tensor([0, 0, 1], dtype=pose.dtype, device=pose.device) + dot_product = torch.sum(pose_z_axis * world_z_axis, dim=-1).clamp(-1.0, 1.0) + return torch.arccos(dot_product) < float(spec.get("max_tilt", torch.pi / 4)) + + +def _object_axis_offset_near(env, spec: Mapping[str, Any]) -> torch.Tensor: + object_position = _position(env, _object_name(spec)) + reference_position = _position( + env, + str(spec.get("reference", spec.get("reference_uid"))), + ) + axis = _axis_index(str(spec.get("axis", "y"))) + target_value = reference_position[:, axis] + float(spec.get("offset", 0.0)) + return torch.abs(object_position[:, axis] - target_value) <= float( + spec.get("tolerance", 0.02) + ) + + +def _object_axis_near(env, spec: Mapping[str, Any]) -> torch.Tensor: + object_position = _position(env, _object_name(spec)) + axis = _axis_index(str(spec.get("axis", "y"))) + target_value = float(spec.get("target", spec.get("value"))) + return torch.abs(object_position[:, axis] - target_value) <= float( + spec.get("tolerance", 0.02) + ) + + +def _object_lifted(env, spec: Mapping[str, Any]) -> torch.Tensor: + object_name = _object_name(spec) + position = _position(env, object_name) + initial_height = spec.get("initial_height") + if initial_height is None: + initial_height = getattr(env, "obj_info", {}).get(object_name, {}).get("height") + if initial_height is None: + initial_height = position[:, 2] + initial_height = _tensor( + initial_height, + dtype=position.dtype, + device=position.device, + ) + return position[:, 2] >= initial_height + float(spec.get("min_height", 0.1)) + + +def _axis_index(axis: str) -> int: + axes = {"x": 0, "y": 1, "z": 2} + if axis not in axes: + raise ValueError(f"Unsupported axis {axis!r}; expected one of x, y, z.") + return axes[axis] diff --git a/embodichain/gen_sim/action_agent_pipeline/generation/__init__.py b/embodichain/gen_sim/action_agent_pipeline/generation/__init__.py new file mode 100644 index 00000000..f89f7a8b --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/generation/__init__.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +"""Config generation helpers for the action-agent pipeline.""" + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/action_agent_pipeline/generation/coacd_cache.py b/embodichain/gen_sim/action_agent_pipeline/generation/coacd_cache.py new file mode 100644 index 00000000..0787ca72 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/generation/coacd_cache.py @@ -0,0 +1,186 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections.abc import Mapping +from pathlib import Path +from typing import Any +import hashlib + +from embodichain.utils.logger import log_info + +__all__ = [ + "coacd_cache_path_for_mesh", + "dexsim_coacd_cache_key_for_mesh", + "prewarm_coacd_cache_for_gym_config", +] + +_DEFAULT_CONVEX_DECOMP_DIR = ( + Path.home() / ".cache" / "embodichain_cache" / "convex_decomposition" +) + + +def coacd_cache_path_for_mesh( + mesh_path: str | Path, + max_convex_hull_num: int, + cache_dir: str | Path | None = None, + *, + mesh_count: int = 1, +) -> Path: + """Return the DexSim environment-side CoACD cache path for a mesh.""" + + if cache_dir is None: + cache_dir = _DEFAULT_CONVEX_DECOMP_DIR + + mesh_md5_key = dexsim_coacd_cache_key_for_mesh(mesh_path, mesh_count=mesh_count) + return Path(cache_dir).expanduser().resolve() / ( + f"{mesh_md5_key}_{int(max_convex_hull_num)}.obj" + ) + + +def dexsim_coacd_cache_key_for_mesh( + mesh_path: str | Path, + *, + mesh_count: int = 1, +) -> str: + """Return the cache key used by DexSim ``load_actor_with_coacd``.""" + + resolved_mesh_path = Path(mesh_path).expanduser().resolve(strict=False) + mesh_key_data = f"{resolved_mesh_path}|mesh_count={int(mesh_count)}" + return hashlib.md5(mesh_key_data.encode("utf-8")).hexdigest() + + +def prewarm_coacd_cache_for_gym_config( + gym_config: Mapping[str, Any], + *, + cache_dir: str | Path | None = None, + repo_root: str | Path | None = None, +) -> list[dict[str, Any]]: + """Precompute DexSim environment-side CoACD cache files for mesh objects.""" + + entries = [] + for obj in _iter_mesh_object_configs(gym_config): + max_convex_hull_num = int(obj.get("max_convex_hull_num", 1)) + if max_convex_hull_num <= 1: + continue + entries.append((obj, max_convex_hull_num)) + if not entries: + return [] + + if cache_dir is None: + cache_dir = _DEFAULT_CONVEX_DECOMP_DIR + + cache_dir = Path(cache_dir).expanduser().resolve() + cache_dir.mkdir(parents=True, exist_ok=True) + repo_root = Path(repo_root).expanduser().resolve() if repo_root else _repo_root() + + reports: list[dict[str, Any]] = [] + seen_cache_paths: set[Path] = set() + for obj, max_convex_hull_num in entries: + uid = str(obj.get("uid", "")) + raw_fpath = str(obj.get("shape", {}).get("fpath", "")) + mesh_path = _resolve_mesh_path(raw_fpath, repo_root) + cache_path = coacd_cache_path_for_mesh( + mesh_path, + max_convex_hull_num, + cache_dir, + ) + report = { + "uid": uid, + "mesh_path": mesh_path.as_posix(), + "mesh_count": 1, + "max_convex_hull_num": max_convex_hull_num, + "cache_path": cache_path.as_posix(), + } + if cache_path in seen_cache_paths: + report["status"] = "duplicate" + elif cache_path.is_file(): + report["status"] = "hit" + else: + try: + _generate_coacd_cache(mesh_path, cache_path, max_convex_hull_num) + except Exception as exc: + report["status"] = "skipped" + report["reason"] = str(exc) + else: + report["status"] = "generated" + seen_cache_paths.add(cache_path) + reports.append(report) + return reports + + +def _iter_mesh_object_configs( + gym_config: Mapping[str, Any], +) -> list[Mapping[str, Any]]: + objects = [] + for section in ("background", "rigid_object"): + value = gym_config.get(section, []) + if isinstance(value, Mapping): + value = [value] + if not isinstance(value, list): + continue + for obj in value: + if not isinstance(obj, Mapping): + continue + shape = obj.get("shape", {}) + if isinstance(shape, Mapping) and shape.get("shape_type") == "Mesh": + objects.append(obj) + return objects + + +def _resolve_mesh_path(raw_fpath: str, repo_root: Path) -> Path: + path = Path(raw_fpath).expanduser() + if path.is_absolute(): + candidate = path.resolve() + else: + candidate = (repo_root / path).resolve() + if not candidate.is_file(): + cwd_candidate = (Path.cwd() / path).resolve() + if cwd_candidate.is_file(): + candidate = cwd_candidate + if not candidate.is_file(): + raise FileNotFoundError(f"Mesh path for CoACD prewarm not found: {raw_fpath}") + return candidate + + +def _generate_coacd_cache( + mesh_path: Path, + cache_path: Path, + max_convex_hull_num: int, +) -> None: + import open3d as o3d + from dexsim.kit.meshproc import convex_decomposition_coacd + from dexsim.kit.meshproc.utility import mesh_list_to_file + + log_info( + "Prewarming environment CoACD cache: " + f"mesh={mesh_path.as_posix()}, hulls={max_convex_hull_num}" + ) + in_mesh = o3d.t.io.read_triangle_mesh(mesh_path.as_posix()) + _, out_mesh_list = convex_decomposition_coacd( + in_mesh, + max_convex_hull_num=int(max_convex_hull_num), + ) + mesh_list_to_file(cache_path.as_posix(), out_mesh_list) + + +def _repo_root() -> Path: + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "setup.py").is_file() and (parent / "embodichain").is_dir(): + return parent + return Path.cwd().resolve() diff --git a/embodichain/gen_sim/action_agent_pipeline/generation/mesh_frame_normalization.py b/embodichain/gen_sim/action_agent_pipeline/generation/mesh_frame_normalization.py new file mode 100644 index 00000000..e9576cd0 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/generation/mesh_frame_normalization.py @@ -0,0 +1,530 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from importlib import metadata +from pathlib import Path +from typing import Any +import hashlib +import json +import math +import re +import struct + +__all__ = [ + "GLB_TO_OBJ_BAKED_X_ROTATION_DEGREES", + "GLB_LOCAL_X_CORRECTION_DEGREES", + "MESH_FRAME_NORMALIZATION_POLICY_VERSION", + "MeshFrameNormalizer", + "NormalizedMeshResult", +] + + +MESH_FRAME_NORMALIZATION_POLICY_VERSION = "action_agent_glb_scene_texture_obj_v3" +GLB_TO_OBJ_BAKED_X_ROTATION_DEGREES = 0.0 +GLB_LOCAL_X_CORRECTION_DEGREES = GLB_TO_OBJ_BAKED_X_ROTATION_DEGREES + +_SAFE_STEM_RE = re.compile(r"[^0-9a-zA-Z_.-]+") +_GLB_JSON_CHUNK_TYPE = 0x4E4F534A +_GLB_BINARY_CHUNK_TYPE = 0x004E4942 +_TEXTURE_EXTENSION_BY_MIME_TYPE = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/webp": ".webp", +} + + +@dataclass(frozen=True) +class _MaterialSpec: + name: str + texture_path: str | None = None + + +@dataclass(frozen=True) +class _TextureAsset: + data: bytes + extension: str + + +@dataclass(frozen=True) +class NormalizedMeshResult: + """A normalized mesh path and metadata for generation summaries.""" + + source_path: Path + normalized_path: Path + source_sha256: str + status: str + transform: list[list[float]] + dexsim_engine_version: str + + def to_summary(self) -> dict[str, Any]: + return { + "source_path": self.source_path.as_posix(), + "normalized_path": self.normalized_path.as_posix(), + "source_sha256": self.source_sha256, + "status": self.status, + "policy_version": MESH_FRAME_NORMALIZATION_POLICY_VERSION, + "dexsim_engine_version": self.dexsim_engine_version, + "transform": self.transform, + } + + +class MeshFrameNormalizer: + """Normalize GLB meshes to OBJ so visual and collision share one frame.""" + + def __init__( + self, + *, + output_dir: str | Path, + local_x_correction_degrees: float = GLB_TO_OBJ_BAKED_X_ROTATION_DEGREES, + ) -> None: + self.output_dir = Path(output_dir).expanduser().resolve() + self.local_x_correction_degrees = float(local_x_correction_degrees) + self.transform = _rotation_x_matrix4(self.local_x_correction_degrees) + self.dexsim_engine_version = _dexsim_engine_version() + self._results_by_source: dict[Path, NormalizedMeshResult] = {} + self._reports: list[dict[str, Any]] = [] + + @property + def reports(self) -> list[dict[str, Any]]: + return list(self._reports) + + def normalize_path(self, mesh_path: str | Path) -> Path: + """Return a runtime mesh path, normalizing GLB/GLTF inputs to OBJ.""" + + path = Path(mesh_path).expanduser().resolve() + if path.suffix.lower() not in {".glb", ".gltf"}: + return path + + cached = self._results_by_source.get(path) + if cached is not None: + if cached.normalized_path.is_file(): + material_spec = self._material_spec_for( + path, + cached.normalized_path, + cached.source_sha256, + ) + _repair_obj_material_reference( + cached.normalized_path, + material_spec.name, + ) + self._ensure_material_library({material_spec.name: material_spec}) + return cached.normalized_path + + source_sha256 = _file_sha256(path) + normalized_path = self._normalized_path_for(path, source_sha256) + material_spec = self._material_spec_for(path, normalized_path, source_sha256) + status = "reused" if normalized_path.is_file() else "generated" + if status == "generated": + self._write_normalized_obj( + path, + normalized_path, + source_sha256, + material_spec, + ) + else: + _repair_obj_material_reference(normalized_path, material_spec.name) + self._ensure_material_library({material_spec.name: material_spec}) + + result = NormalizedMeshResult( + source_path=path, + normalized_path=normalized_path, + source_sha256=source_sha256, + status=status, + transform=self.transform, + dexsim_engine_version=self.dexsim_engine_version, + ) + self._results_by_source[path] = result + self._reports.append(result.to_summary()) + return normalized_path + + def _normalized_path_for(self, mesh_path: Path, source_sha256: str) -> Path: + stem = _SAFE_STEM_RE.sub("_", mesh_path.stem).strip("._") or "mesh" + stem = stem[:32].strip("._") or "mesh" + runtime_hash = hashlib.sha256( + json.dumps( + { + "source_sha256": source_sha256, + "policy_version": MESH_FRAME_NORMALIZATION_POLICY_VERSION, + "dexsim_engine_version": self.dexsim_engine_version, + "transform": self.transform, + }, + sort_keys=True, + separators=(",", ":"), + ).encode("utf-8") + ).hexdigest() + return self.output_dir / f"{stem}_{runtime_hash[:16]}.obj" + + def _material_path(self) -> Path: + return self.output_dir / "material.mtl" + + def _texture_dir(self) -> Path: + return self.output_dir / "textures" + + def _material_spec_for( + self, + source_path: Path, + normalized_path: Path, + source_sha256: str, + ) -> _MaterialSpec: + material_hash = _material_hash_for(normalized_path) + material_name = f"material_{material_hash}" + texture_path = self._write_base_color_texture( + source_path, + material_hash, + source_sha256, + ) + return _MaterialSpec(name=material_name, texture_path=texture_path) + + def _write_base_color_texture( + self, + source_path: Path, + material_hash: str, + source_sha256: str, + ) -> str | None: + try: + texture = _extract_glb_base_color_texture(source_path) + except (IndexError, KeyError, TypeError, ValueError, json.JSONDecodeError): + return None + if texture is None: + return None + + texture_dir = self._texture_dir() + texture_dir.mkdir(parents=True, exist_ok=True) + texture_name = ( + f"{material_hash}_{source_sha256[:12]}_basecolor{texture.extension}" + ) + texture_path = texture_dir / texture_name + texture_path.write_bytes(texture.data) + return f"textures/{texture_name}" + + def _ensure_material_library( + self, material_specs: dict[str, _MaterialSpec] + ) -> None: + if not material_specs: + return + + material_path = self._material_path() + all_specs = { + **_read_material_specs(material_path), + **material_specs, + } + material_path.write_text( + "\n".join( + [ + "# EmbodiChain action-agent normalized mesh materials", + *[ + _format_material_spec(spec) + for spec in sorted( + all_specs.values(), key=lambda item: item.name + ) + ], + "", + ] + ), + encoding="utf-8", + ) + + def _write_normalized_obj( + self, + source_path: Path, + normalized_path: Path, + source_sha256: str, + material_spec: _MaterialSpec, + ) -> None: + trimesh = _require_trimesh() + scene = trimesh.load(str(source_path), force="scene") + mesh = _scene_to_world_mesh(scene) + if self.local_x_correction_degrees: + mesh.apply_transform(self.transform) + + normalized_path.parent.mkdir(parents=True, exist_ok=True) + obj_payload = mesh.export(file_type="obj") + if isinstance(obj_payload, bytes): + obj_text = obj_payload.decode("utf-8") + else: + obj_text = str(obj_payload) + obj_text = _ensure_obj_material_reference(obj_text, material_spec.name) + + header = "\n".join( + [ + "# EmbodiChain action-agent normalized mesh", + f"# policy_version: {MESH_FRAME_NORMALIZATION_POLICY_VERSION}", + f"# dexsim_engine_version: {self.dexsim_engine_version}", + f"# source_path: {source_path.as_posix()}", + f"# source_sha256: {source_sha256}", + f"# transform: {json.dumps(self.transform, separators=(',', ':'))}", + "", + ] + ) + normalized_path.write_text(header + obj_text, encoding="utf-8") + self._ensure_material_library({material_spec.name: material_spec}) + + +def _scene_to_world_mesh(scene: Any) -> Any: + if hasattr(scene, "to_geometry"): + mesh = scene.to_geometry() + elif hasattr(scene, "dump"): + mesh = scene.dump(concatenate=True) + else: + mesh = scene + if not hasattr(mesh, "vertices") or len(mesh.vertices) == 0: + raise ValueError("Mesh contains no vertices.") + return mesh + + +def _material_hash_for(normalized_path: Path) -> str: + hash_part = normalized_path.stem.rsplit("_", maxsplit=1)[-1] + if re.fullmatch(r"[0-9a-fA-F]{8,}", hash_part): + return hash_part.lower() + return hashlib.sha256(normalized_path.stem.encode("utf-8")).hexdigest()[:16] + + +def _repair_obj_material_reference(obj_path: Path, material_name: str) -> str: + obj_text = obj_path.read_text(encoding="utf-8") + repaired = _ensure_obj_material_reference(obj_text, material_name) + if repaired != obj_text: + obj_path.write_text(repaired, encoding="utf-8") + return repaired + + +def _ensure_obj_material_reference(obj_text: str, material_name: str) -> str: + lines = obj_text.splitlines() + header_lines: list[str] = [] + body_start = 0 + for line in lines: + if not line.startswith("#"): + break + header_lines.append(line) + body_start += 1 + + body_lines: list[str] = [] + has_usemtl = False + for line in lines[body_start:]: + if line.startswith("mtllib "): + continue + if line.startswith("usemtl "): + body_lines.append(f"usemtl {material_name}") + has_usemtl = True + continue + body_lines.append(line) + + prefix = ["mtllib material.mtl"] + if not has_usemtl: + prefix.append(f"usemtl {material_name}") + return "\n".join(header_lines + prefix + body_lines) + "\n" + + +def _read_material_specs(material_path: Path) -> dict[str, _MaterialSpec]: + if not material_path.is_file(): + return {} + + specs: dict[str, _MaterialSpec] = {} + current_name: str | None = None + current_texture_path: str | None = None + for line in material_path.read_text(encoding="utf-8").splitlines(): + if line.startswith("newmtl "): + if current_name is not None: + specs[current_name] = _MaterialSpec( + name=current_name, + texture_path=current_texture_path, + ) + current_name = line.split(maxsplit=1)[1].strip() + current_texture_path = None + continue + if current_name is not None and line.startswith("map_Kd "): + current_texture_path = line.split(maxsplit=1)[1].strip() + if current_name is not None: + specs[current_name] = _MaterialSpec( + name=current_name, + texture_path=current_texture_path, + ) + return specs + + +def _format_material_spec(spec: _MaterialSpec) -> str: + ambient = "1.0 1.0 1.0" if spec.texture_path else "0.8 0.8 0.8" + diffuse = "1.0 1.0 1.0" if spec.texture_path else "0.8 0.8 0.8" + lines = [ + f"newmtl {spec.name}", + f"Ka {ambient}", + f"Kd {diffuse}", + "Ks 0.0 0.0 0.0", + "Ns 1.0", + "d 1.0", + "illum 2", + ] + if spec.texture_path: + lines.append(f"map_Kd {spec.texture_path}") + return "\n".join(lines) + + +def _extract_glb_base_color_texture(source_path: Path) -> _TextureAsset | None: + if source_path.suffix.lower() != ".glb": + return None + + doc, binary_chunk = _read_glb(source_path) + material = _first_textured_material(doc) + if material is None: + return None + + texture_index = int(material["pbrMetallicRoughness"]["baseColorTexture"]["index"]) + textures = doc.get("textures", []) + if not isinstance(textures, list) or texture_index >= len(textures): + return None + + texture = textures[texture_index] + if not isinstance(texture, dict): + return None + image_index = texture.get("source") + if image_index is None: + return None + + images = doc.get("images", []) + if not isinstance(images, list) or int(image_index) >= len(images): + return None + + image = images[int(image_index)] + if not isinstance(image, dict): + return None + + mime_type = str(image.get("mimeType", "")) + extension = _TEXTURE_EXTENSION_BY_MIME_TYPE.get(mime_type) + if extension is None: + return None + + buffer_view_index = image.get("bufferView") + if buffer_view_index is None: + return None + + image_data = _buffer_view_bytes(doc, binary_chunk, int(buffer_view_index)) + if not image_data: + return None + return _TextureAsset(data=image_data, extension=extension) + + +def _first_textured_material(doc: dict[str, Any]) -> dict[str, Any] | None: + materials = doc.get("materials", []) + if not isinstance(materials, list): + return None + for material in materials: + if not isinstance(material, dict): + continue + pbr = material.get("pbrMetallicRoughness", {}) + if not isinstance(pbr, dict): + continue + base_color_texture = pbr.get("baseColorTexture", {}) + if not isinstance(base_color_texture, dict): + continue + if "index" in base_color_texture: + return material + return None + + +def _read_glb(source_path: Path) -> tuple[dict[str, Any], bytes]: + data = source_path.read_bytes() + if len(data) < 12: + raise ValueError(f"GLB file is too small: {source_path}") + magic, version, declared_length = struct.unpack_from("<4sII", data, 0) + if magic != b"glTF" or version != 2: + raise ValueError(f"Only GLB version 2 files are supported: {source_path}") + if declared_length > len(data): + raise ValueError(f"GLB length header exceeds file size: {source_path}") + + offset = 12 + doc: dict[str, Any] | None = None + binary_chunk = b"" + while offset + 8 <= declared_length: + chunk_length, chunk_type = struct.unpack_from(" declared_length: + raise ValueError(f"GLB chunk exceeds file size: {source_path}") + chunk = data[offset:chunk_end] + offset = chunk_end + if chunk_type == _GLB_JSON_CHUNK_TYPE: + doc = json.loads(chunk.decode("utf-8")) + elif chunk_type == _GLB_BINARY_CHUNK_TYPE: + binary_chunk = chunk + if doc is None: + raise ValueError(f"GLB file does not contain a JSON chunk: {source_path}") + return doc, binary_chunk + + +def _buffer_view_bytes( + doc: dict[str, Any], + binary_chunk: bytes, + buffer_view_index: int, +) -> bytes: + buffer_views = doc.get("bufferViews", []) + if not isinstance(buffer_views, list) or buffer_view_index >= len(buffer_views): + return b"" + buffer_view = buffer_views[buffer_view_index] + if not isinstance(buffer_view, dict): + return b"" + if int(buffer_view.get("buffer", 0)) != 0: + return b"" + byte_offset = int(buffer_view.get("byteOffset", 0)) + byte_length = int(buffer_view.get("byteLength", 0)) + if byte_length <= 0: + return b"" + return binary_chunk[byte_offset : byte_offset + byte_length] + + +def _rotation_x_matrix4(degrees: float) -> list[list[float]]: + if degrees == 0.0: + return [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + radians = math.radians(degrees) + cos_value = math.cos(radians) + sin_value = math.sin(radians) + return [ + [1.0, 0.0, 0.0, 0.0], + [0.0, cos_value, -sin_value, 0.0], + [0.0, sin_value, cos_value, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + + +def _file_sha256(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as file: + for chunk in iter(lambda: file.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _dexsim_engine_version() -> str: + for package_name in ("dexsim-engine", "dexsim_engine"): + try: + return metadata.version(package_name) + except metadata.PackageNotFoundError: + continue + return "unknown" + + +def _require_trimesh() -> Any: + try: + import trimesh + except ImportError as exc: + raise ImportError("trimesh is required to normalize GLB meshes.") from exc + return trimesh diff --git a/embodichain/gen_sim/action_agent_pipeline/generation/prompt_builders.py b/embodichain/gen_sim/action_agent_pipeline/generation/prompt_builders.py new file mode 100644 index 00000000..fa45fd03 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/generation/prompt_builders.py @@ -0,0 +1,1084 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Prompt and agent-config builders for generated action-agent tasks.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping, Sequence +from typing import Any, Protocol + +__all__ = [ + "make_agent_config", + "make_basket_atom_actions_prompt", + "make_basket_basic_background", + "make_basket_task_prompt", + "make_relative_atom_actions_prompt", + "make_relative_basic_background", + "make_relative_task_prompt", +] + +_BASKET_LEFT_RELEASE_OFFSET_Y = -0.04 +_BASKET_RIGHT_RELEASE_OFFSET_Y = 0.04 +_PLACE_LIFT_HEIGHT = 0.10 +_RELATIVE_COORDINATE_CONVENTION = """Coordinate convention for relative placement: +- `left_of` means negative world y relative to the reference object. +- `right_of` means positive world y relative to the reference object. +- `front_of` means negative world x relative to the reference object. +- `behind` means positive world x relative to the reference object. +- `front_left_of` combines negative world x and negative world y. +- `back_left_of` combines positive world x and negative world y. +- `front_right_of` combines negative world x and positive world y. +- `back_right_of` combines positive world x and positive world y. +- `inside` and `on` use the reference object's xy center.""" + + +class _BasketRolesLike(Protocol): + left_target_runtime_uid: str + right_target_runtime_uid: str + container_runtime_uid: str + left_target_source_uid: str + right_target_source_uid: str + container_source_uid: str + left_target_noun: str + right_target_noun: str + + +class _RelativePlacementLike(Protocol): + active_side: str + moved_runtime_uid: str + moved_source_uid: str + reference_runtime_uid: str + reference_source_uid: str + relation: str + high_offset: tuple[float, float, float] + release_offset: tuple[float, float, float] + reference_is_initial_pose: bool + high_position: Sequence[float] | None + release_position: Sequence[float] | None + + +class _RelativeSpecLike(_RelativePlacementLike, Protocol): + placements: Sequence[_RelativePlacementLike] + task_prompt_summary: str + task_description: str + action_sketch: Sequence[str] + basic_background_notes: str + + +def make_agent_config() -> dict[str, Any]: + return { + "TaskAgent": { + "prompt_name": "generate_task_graph", + }, + "CompileAgent": {}, + "Agent": { + "prompt_kwargs": { + "task_prompt": { + "type": "text", + "name": "task_prompt.txt", + }, + "basic_background": { + "type": "text", + "name": "basic_background.txt", + }, + "atom_actions": { + "type": "text", + "name": "atom_actions.txt", + }, + } + }, + } + + +def make_relative_task_prompt( + task_name: str, + project_name: str, + spec: _RelativeSpecLike, +) -> str: + if len(spec.placements) > 1: + return _make_dual_relative_task_prompt(task_name, project_name, spec) + + active_arm = f"{spec.active_side}_arm" + inactive_slot = ( + "right_arm_action" if spec.active_side == "left" else "left_arm_action" + ) + active_slot = f"{spec.active_side}_arm_action" + action_sketch = _format_action_sketch(spec.action_sketch) + pick_spec = _format_pick_up_spec(active_arm, spec.moved_runtime_uid) + high_spec = _format_relative_pose_spec( + active_arm, + spec, + pose_kind="high", + sample_interval=45, + ) + place_spec = _format_relative_place_spec( + active_arm, + spec, + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + initial_spec = _format_initial_qpos_spec(active_arm, sample_interval=30) + reference_line = _relative_reference_line(spec) + final_planning_rule = _relative_final_planning_rule(project_name, spec) + high_step_label = _relative_pose_step_label(spec, "high staging") + release_step_label = _relative_pose_step_label(spec, "release") + return f"""Task: +{task_name}: {spec.task_prompt_summary} + +This config was generated from a simple task description by the config-stage +LLM. The execution-stage LLM must now generate the graph JSON from this prompt. + +Original simple task description: +{spec.task_description} + +Config-stage LLM interpretation: +{action_sketch} + +Object and arm mapping: +- Move `{spec.moved_runtime_uid}`. Source object: `{spec.moved_source_uid}`. +- {reference_line} +- Goal relation: `{spec.relation}` ({_relative_relation_phrase(spec.relation)}). +- Active arm: `{active_arm}`. +- Keep every `{inactive_slot}` as null. + +{_RELATIVE_COORDINATE_CONVENTION} + +Generate one deterministic nominal graph with exactly 4 nominal edges. Use only +the atomic action class JSON specs shown below. Do not add recovery, monitor, search, +alignment, or extra lift edges. Use `PlaceAction` for the release-place step so +lowering, gripper opening, and upward retreat remain one atomic action. The +inactive arm must remain null in every edge. + +1. Pick up the moved object: + - {active_slot}: {pick_spec} + - {inactive_slot}: null + +2. Move the held object to the {high_step_label} pose: + - {active_slot}: {high_spec} + - {inactive_slot}: null + +3. Place the held object at the {release_step_label} pose: + - {active_slot}: {place_spec} + - {inactive_slot}: null + +4. Return the active arm to its initial pose: + - {active_slot}: {initial_spec} + - {inactive_slot}: null + +Final state: `{spec.moved_runtime_uid}` must be +{_relative_relation_phrase(spec.relation)} `{spec.reference_runtime_uid}`. +{final_planning_rule} +""" + + +def _make_dual_relative_task_prompt( + task_name: str, + project_name: str, + spec: _RelativeSpecLike, +) -> str: + first, second = spec.placements + first_arm = f"{first.active_side}_arm" + second_arm = f"{second.active_side}_arm" + first_slot = f"{first.active_side}_arm_action" + second_slot = f"{second.active_side}_arm_action" + action_sketch = _format_action_sketch(spec.action_sketch) + first_pick_spec = _format_pick_up_spec(first_arm, first.moved_runtime_uid) + second_pick_spec = _format_pick_up_spec(second_arm, second.moved_runtime_uid) + first_high_spec = _format_relative_pose_spec( + first_arm, + first, + pose_kind="high", + sample_interval=45, + ) + second_high_spec = _format_relative_pose_spec( + second_arm, + second, + pose_kind="high", + sample_interval=45, + ) + first_place_spec = _format_relative_place_spec( + first_arm, + first, + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + second_place_spec = _format_relative_place_spec( + second_arm, + second, + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + first_close_spec = _format_gripper_spec( + first_arm, + "close", + sample_interval=10, + ) + second_close_spec = _format_gripper_spec( + second_arm, + "close", + sample_interval=10, + ) + first_initial_spec = _format_initial_qpos_spec( + first_arm, + sample_interval=30, + ) + second_initial_spec = _format_initial_qpos_spec( + second_arm, + sample_interval=30, + ) + first_reference_line = _relative_reference_line(first) + second_reference_line = _relative_reference_line(second) + final_planning_rule = _dual_relative_final_planning_rule(project_name, spec) + return f"""Task: +{task_name}: {spec.task_prompt_summary} + +This config was generated from a simple task description by the config-stage +LLM. The execution-stage LLM must now generate the graph JSON from this prompt. + +Original simple task description: +{spec.task_description} + +Config-stage LLM interpretation: +{action_sketch} + +Object and arm mapping: +- {first_slot} must manipulate `{first.moved_runtime_uid}`. Source object: + `{first.moved_source_uid}`. +- {second_slot} must manipulate `{second.moved_runtime_uid}`. Source object: + `{second.moved_source_uid}`. +- {first_reference_line} Goal relation for `{first.moved_runtime_uid}`: + `{first.relation}` ({_relative_relation_phrase(first.relation)}). +- {second_reference_line} Goal relation for `{second.moved_runtime_uid}`: + `{second.relation}` ({_relative_relation_phrase(second.relation)}). + +{_RELATIVE_COORDINATE_CONVENTION} + +Generate one deterministic nominal graph with exactly 6 nominal edges. Use only +the atomic action class JSON specs shown below. Do not add recovery, monitor, search, +alignment, or extra lift edges. Use `PlaceAction` for each release-place step so +lowering, gripper opening, and upward retreat remain one atomic action. + +1. Pick up both moved objects simultaneously: + - {first_slot}: {first_pick_spec} + - {second_slot}: {second_pick_spec} + +2. Move `{first.moved_runtime_uid}` to the high staging pose while the other arm + keeps holding `{second.moved_runtime_uid}`: + - {first_slot}: {first_high_spec} + - {second_slot}: {second_close_spec} + +3. Place `{first.moved_runtime_uid}` at the release pose: + - {first_slot}: {first_place_spec} + - {second_slot}: {second_close_spec} + +4. Return `{first_arm}` to its initial pose while moving `{second.moved_runtime_uid}` + to the high staging pose: + - {first_slot}: {first_initial_spec} + - {second_slot}: {second_high_spec} + +5. Place `{second.moved_runtime_uid}` at the release pose: + - {first_slot}: null + - {second_slot}: {second_place_spec} + +6. Return `{second_arm}` to its initial pose: + - {first_slot}: null + - {second_slot}: {second_initial_spec} + +Final state: `{first.moved_runtime_uid}` must be +{_relative_relation_phrase(first.relation)} `{first.reference_runtime_uid}`, and +`{second.moved_runtime_uid}` must be {_relative_relation_phrase(second.relation)} +`{second.reference_runtime_uid}`. +{final_planning_rule} +""" + + +def make_relative_basic_background( + project_name: str, + spec: _RelativeSpecLike, +) -> str: + if len(spec.placements) > 1: + return _make_dual_relative_basic_background(project_name, spec) + + active_arm = f"{spec.active_side}_arm" + inactive_arm = "right_arm" if spec.active_side == "left" else "left_arm" + notes = spec.basic_background_notes or ( + "No extra scene notes were provided by the config-stage LLM." + ) + return f"""The scene comes from the exported {project_name} mesh environment. + +This configuration directory is for a Dual-UR5 relative-placement task generated +from a simple natural-language task description. + +The robot is a dual-UR5 composite robot with DH_PGI_140_80 parallel grippers: +- left_arm is the semantic robot-view left slot, mapped to the physical + right_arm control part. +- right_arm is the semantic robot-view right slot, mapped to the physical + left_arm control part. + +The active arm for this task is `{active_arm}`. The inactive arm +`{inactive_arm}` must stay null in the nominal graph. + +Interactive task objects: +- {spec.moved_runtime_uid}: moved object from source `{spec.moved_source_uid}`. +- {_relative_reference_line(spec)} + +Config-stage LLM notes: +{notes} + +The execution-stage LLM should generate graph JSON that grasps the moved object, +moves it to the configured high staging pose, places it at the release pose with +one `PlaceAction`, and returns the active arm to its initial pose. +""" + + +def _make_dual_relative_basic_background( + project_name: str, + spec: _RelativeSpecLike, +) -> str: + notes = spec.basic_background_notes or ( + "No extra scene notes were provided by the config-stage LLM." + ) + placement_lines = "\n".join( + f"- {placement.active_side}_arm moves `{placement.moved_runtime_uid}` " + f"{_relative_relation_phrase(placement.relation)} " + f"`{placement.reference_runtime_uid}`." + for placement in spec.placements + ) + return f"""The scene comes from the exported {project_name} mesh environment. + +This configuration directory is for a Dual-UR5 dual-arm relative-placement task +generated from a simple natural-language task description. + +The robot is a dual-UR5 composite robot with DH_PGI_140_80 parallel grippers: +- left_arm is the semantic robot-view left slot, mapped to the physical + right_arm control part. +- right_arm is the semantic robot-view right slot, mapped to the physical + left_arm control part. + +Both arms participate in the nominal graph: +{placement_lines} + +Config-stage LLM notes: +{notes} + +The execution-stage LLM should generate graph JSON that grasps both moved +objects, stages and places the first moved object with one `PlaceAction`, then +stages and places the second moved object while the first arm returns to its +initial pose. Each arm must release its moved object before returning to its +initial pose. +""" + + +def make_relative_atom_actions_prompt(spec: _RelativeSpecLike) -> str: + if len(spec.placements) > 1: + return _make_dual_relative_atom_actions_prompt(spec) + + active_arm = f"{spec.active_side}_arm" + inactive_arm = "right_arm" if spec.active_side == "left" else "left_arm" + high_spec = _format_relative_pose_spec( + active_arm, + spec, + pose_kind="high", + sample_interval=45, + ) + place_spec = _format_relative_place_spec( + active_arm, + spec, + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + return f"""### Atomic Action Class JSON Specs for Dual-UR5 Relative Placement + +Use only atomic action class JSON specs backed by `PickUpAction`, `MoveAction`, and +`PlaceAction`. The active arm is `{active_arm}`. Keep `{inactive_arm}` null in +the nominal graph. + +Use exactly these action patterns: +- Pick up `{spec.moved_runtime_uid}`: + {_format_pick_up_spec(active_arm, spec.moved_runtime_uid)} +- {_relative_pose_step_label(spec, "High staging")}: + {high_spec} +- Place at the release pose: + {place_spec} +- Return to initial qpos: + {_format_initial_qpos_spec(active_arm, sample_interval=30)} +""" + + +def _make_dual_relative_atom_actions_prompt(spec: _RelativeSpecLike) -> str: + first, second = spec.placements + first_arm = f"{first.active_side}_arm" + second_arm = f"{second.active_side}_arm" + first_high_spec = _format_relative_pose_spec( + first_arm, + first, + pose_kind="high", + sample_interval=45, + ) + second_high_spec = _format_relative_pose_spec( + second_arm, + second, + pose_kind="high", + sample_interval=45, + ) + first_place_spec = _format_relative_place_spec( + first_arm, + first, + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + second_place_spec = _format_relative_place_spec( + second_arm, + second, + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + return f"""### Atomic Action Class JSON Specs for Dual-UR5 Dual-Arm Relative Placement + +Use only atomic action class JSON specs backed by `PickUpAction`, `MoveAction`, and +`PlaceAction`. +- `{first_arm}` manipulates `{first.moved_runtime_uid}`. +- `{second_arm}` manipulates `{second.moved_runtime_uid}`. + +Use these action patterns: +- First arm pick-up: + {_format_pick_up_spec(first_arm, first.moved_runtime_uid)} +- Second arm pick-up: + {_format_pick_up_spec(second_arm, second.moved_runtime_uid)} +- First high staging: + {first_high_spec} +- First place action: + {first_place_spec} +- Second high staging: + {second_high_spec} +- Second place action: + {second_place_spec} +- Keep a holding arm closed: + {_format_gripper_spec("", "close", sample_interval=10)} +- Return to initial qpos: + {_format_initial_qpos_spec("", sample_interval=30)} +""" + + +def make_basket_task_prompt( + task_name: str, + project_name: str, + roles: _BasketRolesLike, +) -> str: + left_target_text = _left_target_text(roles) + right_target_text = _right_target_text(roles) + target_pair_text = _target_pair_text(roles) + target_plural = _target_plural_text(roles) + left_pick_spec = _format_pick_up_spec( + "left_arm", + roles.left_target_runtime_uid, + ) + right_pick_spec = _format_pick_up_spec( + "right_arm", + roles.right_target_runtime_uid, + ) + left_high_spec = _format_pose_object_spec( + "left_arm", + roles.container_runtime_uid, + (0.0, _BASKET_LEFT_RELEASE_OFFSET_Y, 0.22), + sample_interval=45, + ) + right_high_spec = _format_pose_object_spec( + "right_arm", + roles.container_runtime_uid, + (0.0, _BASKET_RIGHT_RELEASE_OFFSET_Y, 0.22), + sample_interval=45, + ) + left_place_spec = _format_place_object_spec( + "left_arm", + roles.container_runtime_uid, + (0.0, _BASKET_LEFT_RELEASE_OFFSET_Y, 0.12), + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + right_place_spec = _format_place_object_spec( + "right_arm", + roles.container_runtime_uid, + (0.0, _BASKET_RIGHT_RELEASE_OFFSET_Y, 0.12), + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + right_close_spec = _format_gripper_spec( + "right_arm", + "close", + sample_interval=10, + ) + left_initial_spec = _format_initial_qpos_spec( + "left_arm", + sample_interval=30, + ) + right_initial_spec = _format_initial_qpos_spec( + "right_arm", + sample_interval=30, + ) + return f"""Task: +{task_name}: use the current two-UR5 configuration to place +{target_pair_text} into the {roles.container_runtime_uid}. + +The task starts with both arms acting simultaneously: +the left UR5 grasps the left {left_target_text} while the right UR5 grasps the +right {right_target_text} in the same nominal graph edge. After both +{target_plural} are grasped, the left UR5 places its {left_target_text} into the +{roles.container_runtime_uid} and retreats upward. While the left UR5 returns +to its initial pose, the right UR5 must simultaneously begin placing its +already-grasped {right_target_text} by moving it to the high staging pose above +the {roles.container_runtime_uid}. The right UR5 then completes its placement +and returns to its initial pose. + +Object and arm mapping: +- left_arm must only manipulate `{roles.left_target_runtime_uid}`. +- right_arm must only manipulate `{roles.right_target_runtime_uid}`. +- Both target objects must be released into `{roles.container_runtime_uid}`. + +Generate one deterministic nominal graph with the following semantic sequence. +Do not add extra alignment, search, recovery, or monitor steps. Use `PlaceAction` +for each release-place step so lowering, gripper opening, and upward retreat +remain one atomic action. The left arm must finish its `PlaceAction` retreat +before the right arm enters the shared container workspace, but the left +return-to-initial action and the right high-staging action must execute +simultaneously in one graph edge. Generate exactly 6 +nominal edges, one edge for each numbered step below. Do not split the +simultaneous grasp or the simultaneous left-return/right-staging action into +separate edges. Do not split a `PlaceAction` into separate lower-to-release, +open-gripper, or upward-retreat edges. + +A target object is not considered placed when it is only above the +{roles.container_runtime_uid}. For each arm, the placement order must be: move +to a high staging pose above the container, then execute one `PlaceAction` at +the release pose inside the container, then return the arm to its initial pose. +Never use `target_qpos` source `initial` for an arm that has not already +released its held target object. + +1. Pick up both target objects simultaneously: + - left_arm_action: {left_pick_spec} + - right_arm_action: {right_pick_spec} + +2. Move the held left target object directly above the left half of the + {roles.container_runtime_uid} while the right arm keeps holding its target: + - left_arm_action: {left_high_spec} + - right_arm_action: {right_close_spec} + +3. Place the held left target object at the left release pose inside the + {roles.container_runtime_uid}: + - left_arm_action: {left_place_spec} + - right_arm_action: {right_close_spec} + +4. After the left gripper has retreated upward, return the left UR5 to its + initial pose while simultaneously moving the held right target object + directly above the right half of the {roles.container_runtime_uid}. This + parallel handoff must remain one graph edge: + - left_arm_action: {left_initial_spec} + - right_arm_action: {right_high_spec} + +5. Place the held right target object at the right release pose inside the + {roles.container_runtime_uid}: + - left_arm_action: null + - right_arm_action: {right_place_spec} + +6. Return the right UR5 to its initial pose after releasing the target object: + - left_arm_action: null + - right_arm_action: {right_initial_spec} + +The final state is both `{roles.left_target_runtime_uid}` and +`{roles.right_target_runtime_uid}` resting inside `{roles.container_runtime_uid}`, +with both arms moved away from the container workspace. Always plan to the +current `{roles.container_runtime_uid}` object pose from the exported +{project_name} environment config. +""" + + +def make_basket_basic_background( + project_name: str, + roles: _BasketRolesLike, +) -> str: + left_target_text = _left_target_text(roles) + right_target_text = _right_target_text(roles) + target_plural = _target_plural_text(roles) + return f"""The scene comes from the exported {project_name} mesh environment. + +This configuration directory is for the UR5BreadBasket task template. The +current robot is a dual-UR5 composite robot with DH_PGI_140_80 parallel +grippers. + +The robot is a dual-UR5 composite robot with two parallel grippers: +- left_arm is the semantic robot-view left slot, mapped to the physical + right_arm control part. +- right_arm is the semantic robot-view right slot, mapped to the physical + left_arm control part. + +Both UR5 bases are on the same long side of the table and face inward toward +the central {roles.container_runtime_uid}. The bases are intentionally kept +outside the table edge to avoid initial robot-table contact. + +The interactive objects are: +- {roles.left_target_runtime_uid}: the {left_target_text} mesh initially on the + negative-y side (source object {roles.left_target_source_uid}). +- {roles.right_target_runtime_uid}: the {right_target_text} mesh initially on the + positive-y side (source object {roles.right_target_source_uid}). +- {roles.container_runtime_uid}: the target container near the center of the + table (source object {roles.container_source_uid}). + +The nominal task starts with simultaneous dual-arm grasping. The left UR5 must +grasp {roles.left_target_runtime_uid} while the right UR5 grasps +{roles.right_target_runtime_uid} in the same graph edge. After both +{target_plural} are held, the left UR5 places +{roles.left_target_runtime_uid} into {roles.container_runtime_uid} with one +`PlaceAction`. The next graph edge is a parallel handoff: the left UR5 returns +to its initial pose while the right UR5 simultaneously moves its +already-grasped {roles.right_target_runtime_uid} to the high staging pose above +{roles.container_runtime_uid}. The right UR5 then places +{roles.right_target_runtime_uid} with one `PlaceAction` and returns to its +initial pose. To change the insertion order later, edit the task prompt sequence +and keep the same atomic action API. + +The {roles.container_runtime_uid} area is a shared workspace. A UR5 should +complete its `PlaceAction` retreat before the other UR5 moves to the container, +otherwise the two arms may collide near the container. The right UR5 should keep +holding {roles.right_target_runtime_uid} while the left UR5 performs its +placement. Once that `PlaceAction` is complete, the right UR5 may move toward +the container while the left UR5 simultaneously returns to its initial pose; it +must not wait for the left return-to-initial motion to finish. + +A target object at a high pose above `{roles.container_runtime_uid}` is only +staged, not placed. Each arm must execute a `PlaceAction` at the container +release pose before any return-to-initial motion. + +Always plan to the current `{roles.container_runtime_uid}` object pose from the +environment config. Do not hard-code container coordinates in generated graph +actions. +""" + + +def make_basket_atom_actions_prompt(roles: _BasketRolesLike) -> str: + left_high_spec = _format_pose_object_spec( + "left_arm", + roles.container_runtime_uid, + (0.0, _BASKET_LEFT_RELEASE_OFFSET_Y, 0.22), + sample_interval=45, + ) + right_high_spec = _format_pose_object_spec( + "right_arm", + roles.container_runtime_uid, + (0.0, _BASKET_RIGHT_RELEASE_OFFSET_Y, 0.22), + sample_interval=45, + ) + left_place_spec = _format_place_object_spec( + "left_arm", + roles.container_runtime_uid, + (0.0, _BASKET_LEFT_RELEASE_OFFSET_Y, 0.12), + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + right_place_spec = _format_place_object_spec( + "right_arm", + roles.container_runtime_uid, + (0.0, _BASKET_RIGHT_RELEASE_OFFSET_Y, 0.12), + sample_interval=80, + lift_height=_PLACE_LIFT_HEIGHT, + ) + return f"""### Atomic Action Class JSON Specs for UR5BreadBasket Dual-UR5 Placement + +Use only atomic action class JSON specs backed by `PickUpAction`, `MoveAction`, and +`PlaceAction`. Use `robot_name="left_arm"` only for +`{roles.left_target_runtime_uid}` and `robot_name="right_arm"` only for +`{roles.right_target_runtime_uid}`. + +The nominal task starts with simultaneous dual-arm pick-up, followed by a +left-first placement with an overlapped handoff to the right arm: +- The first nominal edge must use `atomic_action_class:"PickUpAction"` for both arms. +- While the left arm places its target, keep the right hand closed with a + `target_qpos` whose source is `gripper_state` and state is `close`. +- After the left arm releases `{roles.left_target_runtime_uid}`, first move it + upward to clear the container as part of the same `PlaceAction`. +- The next nominal edge must pair the left arm's initial `target_qpos` move with + the right arm's object-referenced `target_pose` high-staging move. Do not split this + parallel handoff into separate edges. +- After the parallel handoff edge, the remaining right-side placement steps put + the actual action in `right_arm_action` and set `left_arm_action` to null. +- Never use initial `target_qpos` for an arm that is still holding a target object. + +Use these action patterns: +- Left pick-up: + {_format_pick_up_spec("left_arm", roles.left_target_runtime_uid)} +- Right pick-up: + {_format_pick_up_spec("right_arm", roles.right_target_runtime_uid)} +- Left high staging: + {left_high_spec} +- Left place action: + {left_place_spec} +- Right high staging: + {right_high_spec} +- Right place action: + {right_place_spec} +- Keep a holding arm closed: + {_format_gripper_spec("", "close", sample_interval=10)} +- Return to initial qpos: + {_format_initial_qpos_spec("", sample_interval=30)} +""" + + +def _format_pick_up_spec( + robot_name: str, + obj_name: str, + *, + sample_interval: int = 45, +) -> str: + return _compact_json( + { + "atomic_action_class": "PickUpAction", + "robot_name": robot_name, + "control": "arm", + "target_object": { + "obj_name": obj_name, + "affordance": "antipodal", + }, + "cfg": { + "pre_grasp_distance": 0.08, + "sample_interval": sample_interval, + }, + } + ) + + +def _format_pose_object_spec( + robot_name: str, + obj_name: str, + offset: tuple[float, float, float] | list[float], + *, + sample_interval: int, +) -> str: + x, y, z = offset + return _compact_json( + { + "atomic_action_class": "MoveAction", + "robot_name": robot_name, + "control": "arm", + "target_pose": { + "reference": "object", + "obj_name": obj_name, + "offset": [float(x), float(y), float(z)], + }, + "cfg": {"sample_interval": sample_interval}, + } + ) + + +def _format_place_object_spec( + robot_name: str, + obj_name: str, + offset: tuple[float, float, float] | list[float], + *, + sample_interval: int, + lift_height: float, +) -> str: + x, y, z = offset + return _format_place_spec( + robot_name, + { + "reference": "object", + "obj_name": obj_name, + "offset": [float(x), float(y), float(z)], + }, + sample_interval=sample_interval, + lift_height=lift_height, + ) + + +def _format_relative_pose_spec( + robot_name: str, + placement: _RelativePlacementLike, + *, + pose_kind: str, + sample_interval: int, +) -> str: + if getattr(placement, "reference_is_initial_pose", False): + position = ( + placement.high_position + if pose_kind == "high" + else placement.release_position + ) + if position is None: + raise ValueError( + "Self-relative placement requires absolute high/release positions." + ) + return _format_pose_absolute_spec( + robot_name, + position, + sample_interval=sample_interval, + ) + + offset = placement.high_offset if pose_kind == "high" else placement.release_offset + return _format_pose_object_spec( + robot_name, + placement.reference_runtime_uid, + offset, + sample_interval=sample_interval, + ) + + +def _format_relative_place_spec( + robot_name: str, + placement: _RelativePlacementLike, + *, + sample_interval: int, + lift_height: float, +) -> str: + if getattr(placement, "reference_is_initial_pose", False): + if placement.release_position is None: + raise ValueError("Self-relative placement requires release position.") + return _format_place_absolute_spec( + robot_name, + placement.release_position, + sample_interval=sample_interval, + lift_height=lift_height, + ) + + return _format_place_object_spec( + robot_name, + placement.reference_runtime_uid, + placement.release_offset, + sample_interval=sample_interval, + lift_height=lift_height, + ) + + +def _format_pose_absolute_spec( + robot_name: str, + position: Sequence[float], + *, + sample_interval: int, +) -> str: + return _compact_json( + { + "atomic_action_class": "MoveAction", + "robot_name": robot_name, + "control": "arm", + "target_pose": { + "reference": "absolute", + "position": [float(value) for value in position], + }, + "cfg": {"sample_interval": sample_interval}, + } + ) + + +def _format_place_absolute_spec( + robot_name: str, + position: Sequence[float], + *, + sample_interval: int, + lift_height: float, +) -> str: + return _format_place_spec( + robot_name, + { + "reference": "absolute", + "position": [float(value) for value in position], + }, + sample_interval=sample_interval, + lift_height=lift_height, + ) + + +def _format_place_spec( + robot_name: str, + target_pose: Mapping[str, Any], + *, + sample_interval: int, + lift_height: float, +) -> str: + return _compact_json( + { + "atomic_action_class": "PlaceAction", + "robot_name": robot_name, + "control": "arm", + "target_pose": dict(target_pose), + "cfg": { + "sample_interval": sample_interval, + "lift_height": float(lift_height), + }, + } + ) + + +def _format_gripper_spec( + robot_name: str, + state: str, + *, + sample_interval: int, + post_hold_steps: int = 0, +) -> str: + cfg = {"sample_interval": sample_interval} + if post_hold_steps: + cfg["post_hold_steps"] = post_hold_steps + return _compact_json( + { + "atomic_action_class": "MoveAction", + "robot_name": robot_name, + "control": "hand", + "target_qpos": {"source": "gripper_state", "state": state}, + "cfg": cfg, + } + ) + + +def _format_initial_qpos_spec( + robot_name: str, + *, + sample_interval: int, +) -> str: + return _compact_json( + { + "atomic_action_class": "MoveAction", + "robot_name": robot_name, + "control": "arm", + "target_qpos": {"source": "initial"}, + "cfg": {"sample_interval": sample_interval}, + } + ) + + +def _compact_json(value: Mapping[str, Any]) -> str: + return json.dumps(value, ensure_ascii=False, separators=(",", ":")) + + +def _format_action_sketch(action_sketch: list[str]) -> str: + return "\n".join(f"- {item}" for item in action_sketch) + + +def _relative_reference_line(spec: _RelativePlacementLike) -> str: + if getattr(spec, "reference_is_initial_pose", False): + return ( + f"Use the initial position of `{spec.moved_runtime_uid}` as the fixed " + f"spatial anchor. Source object: `{spec.moved_source_uid}`." + ) + return ( + f"Use `{spec.reference_runtime_uid}` as the spatial reference. Source " + f"object: `{spec.reference_source_uid}`." + ) + + +def _relative_pose_step_label( + spec: _RelativePlacementLike, + label: str, +) -> str: + if getattr(spec, "reference_is_initial_pose", False): + return f"{label} at the absolute initial-position offset" + return f"{label} relative to `{spec.reference_runtime_uid}`" + + +def _relative_final_planning_rule( + project_name: str, + spec: _RelativePlacementLike, +) -> str: + if getattr(spec, "reference_is_initial_pose", False): + return ( + "Use the exact absolute target_pose JSON specs shown above. Do not " + "rewrite this self-relative task as an object-referenced pose, because " + "the moved object would become a moving reference after pickup." + ) + return ( + f"Always plan to the current object poses from the exported {project_name} " + "environment config. Do not hard-code absolute object coordinates in the " + "generated graph." + ) + + +def _dual_relative_final_planning_rule( + project_name: str, + spec: _RelativeSpecLike, +) -> str: + if any( + getattr(placement, "reference_is_initial_pose", False) + for placement in spec.placements + ): + return ( + "Use the exact absolute target_pose JSON specs shown above for any " + "initial-position placement. Do not rewrite those self-relative " + "steps as object-referenced poses." + ) + return ( + f"Always plan to the current object poses from the exported {project_name} " + "environment config. Do not hard-code absolute object coordinates in the " + "generated graph." + ) + + +def _relative_relation_phrase(relation: str) -> str: + if relation == "inside": + return "inside" + if relation == "on": + return "on top of" + if relation == "left_of": + return "to the left of" + if relation == "right_of": + return "to the right of" + if relation == "front_of": + return "in front of" + if relation == "behind": + return "behind" + if relation == "front_left_of": + return "to the front-left of" + if relation == "back_left_of": + return "to the back-left of" + if relation == "front_right_of": + return "to the front-right of" + if relation == "back_right_of": + return "to the back-right of" + raise ValueError(f"Unsupported relative placement relation: {relation!r}.") + + +def _left_target_text(roles: _BasketRolesLike) -> str: + return _display_noun(roles.left_target_noun) + + +def _right_target_text(roles: _BasketRolesLike) -> str: + return _display_noun(roles.right_target_noun) + + +def _target_pair_text(roles: _BasketRolesLike) -> str: + left_text = _left_target_text(roles) + right_text = _right_target_text(roles) + if left_text == right_text: + return f"two {left_text} objects" + return f"the left {left_text} and right {right_text}" + + +def _target_plural_text(roles: _BasketRolesLike) -> str: + left_text = _left_target_text(roles) + right_text = _right_target_text(roles) + if left_text == right_text: + return _plural(left_text) + return "target objects" + + +def _display_noun(uid: str) -> str: + return uid.replace("_", " ") + + +def _plural(noun: str) -> str: + if noun.endswith("s"): + return noun + if noun.endswith(("ch", "sh", "x")): + return f"{noun}es" + return f"{noun}s" diff --git a/embodichain/gen_sim/action_agent_pipeline/generation/ur5_basket_config.py b/embodichain/gen_sim/action_agent_pipeline/generation/ur5_basket_config.py new file mode 100644 index 00000000..e0290c34 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/generation/ur5_basket_config.py @@ -0,0 +1,3665 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any +import copy +import json +import math +import re +import struct + +from embodichain.gen_sim.action_agent_pipeline.generation.mesh_frame_normalization import ( + MeshFrameNormalizer, +) +from embodichain.gen_sim.action_agent_pipeline.generation.prompt_builders import ( + make_agent_config, + make_basket_atom_actions_prompt, + make_basket_basic_background, + make_basket_task_prompt, + make_relative_atom_actions_prompt, + make_relative_basic_background, + make_relative_task_prompt, +) + +__all__ = [ + "GeneratedUR5BasketConfigPaths", + "TargetReplacementSpec", + "generate_ur5_basket_config_from_project", +] + +_DIGIT_SUFFIX_RE = re.compile(r"_[0-9]+$") +_INVALID_UID_CHARS_RE = re.compile(r"[^0-9a-zA-Z_]+") +_PROJECT_NAME_RE = re.compile(r"^[0-9]+_gym_project$") +_GYM_CONFIG_FILENAMES = frozenset({"gym_config.json", "gym_config_merged.json"}) +_GYM_CONFIG_PREFERENCE = ("gym_config_merged.json", "gym_config.json") +_TARGET_REPLACEMENT_MANIFEST_FILENAME = ".embodichain_replacement_manifest.json" + +_CONTAINER_KEYWORDS = ( + "basket", + "container", + "bowl", + "box", + "bin", + "tray", + "crate", +) + +_RELATIVE_RELATIONS = { + "inside", + "on", + "left_of", + "right_of", + "front_of", + "behind", + "front_left_of", + "back_left_of", + "front_right_of", + "back_right_of", +} + +_SIDE_RELATIONS = _RELATIVE_RELATIONS - {"inside", "on"} + +_SELF_REFERENCE_VALUES = { + "self", + "initial_self", + "initial_position", + "initial_pose", + "origin", + "itself", + "自身", + "自己", + "原位", + "初始位置", +} + +_RELATION_ALIASES = { + "in": "inside", + "into": "inside", + "inside": "inside", + "放入": "inside", + "放进": "inside", + "里面": "inside", + "on": "on", + "onto": "on", + "on_top": "on", + "on_top_of": "on", + "above": "on", + "top": "on", + "上": "on", + "上方": "on", + "上面": "on", + "叠放": "on", + "left": "left_of", + "left_of": "left_of", + "to_the_left_of": "left_of", + "左": "left_of", + "左边": "left_of", + "front_left": "front_left_of", + "front_left_of": "front_left_of", + "left_front": "front_left_of", + "left_front_of": "front_left_of", + "to_the_front_left_of": "front_left_of", + "左前": "front_left_of", + "左前方": "front_left_of", + "左前面": "front_left_of", + "back_left": "back_left_of", + "back_left_of": "back_left_of", + "behind_left": "back_left_of", + "left_back": "back_left_of", + "left_behind": "back_left_of", + "left_back_of": "back_left_of", + "to_the_back_left_of": "back_left_of", + "左后": "back_left_of", + "左后方": "back_left_of", + "左后面": "back_left_of", + "右": "right_of", + "右边": "right_of", + "right": "right_of", + "right_of": "right_of", + "to_the_right_of": "right_of", + "front_right": "front_right_of", + "front_right_of": "front_right_of", + "right_front": "front_right_of", + "right_front_of": "front_right_of", + "to_the_front_right_of": "front_right_of", + "右前": "front_right_of", + "右前方": "front_right_of", + "右前面": "front_right_of", + "back_right": "back_right_of", + "back_right_of": "back_right_of", + "behind_right": "back_right_of", + "right_back": "back_right_of", + "right_behind": "back_right_of", + "right_back_of": "back_right_of", + "to_the_back_right_of": "back_right_of", + "右后": "back_right_of", + "右后方": "back_right_of", + "右后面": "back_right_of", + "front": "front_of", + "front_of": "front_of", + "in_front_of": "front_of", + "前": "front_of", + "前方": "front_of", + "前面": "front_of", + "back": "behind", + "behind": "behind", + "back_of": "behind", + "后": "behind", + "后方": "behind", + "后面": "behind", +} + +_SIDE_RELATION_DISTANCE = 0.16 +_SIDE_RELEASE_Z_OFFSET = 0.12 +_STAGING_Z_DELTA = 0.10 +_ON_RELEASE_Z_OFFSET = 0.2 +_DUAL_UR5_LEGACY_INIT_Z = 0.5 +_DUAL_UR5_HIGH_TABLETOP_THRESHOLD = 1.0 +_DUAL_UR5_HIGH_TABLETOP_INIT_Z = 0.8 +_DUAL_UR5_ARM_COMPONENT_Z = 0.4 +_DUAL_UR5_TABLETOP_CLEARANCE = 0.25 +_DUAL_UR5_SIDE_AXIS_INDEX = 1 +_DUAL_UR5_ROTATED_INIT_X = 2.0 +_DUAL_UR5_ROTATED_INIT_YAW_DEGREES = -90.0 +_ROBOT_VIEW_LEFT_WORLD_Y_SIGN = -1.0 +_ROBOT_VIEW_FRONT_WORLD_X_SIGN = -1.0 +_BACKGROUND_MAX_CONVEX_HULL_NUM = 1 +_TARGET_MAX_CONVEX_HULL_NUM = 16 +_CONTAINER_MAX_CONVEX_HULL_NUM = 8 +_EXTRA_RIGID_MAX_CONVEX_HULL_NUM = 1 +_TABLETOP_OBJECT_CLEARANCE = 0.003 +_GLB_JSON_CHUNK_TYPE = 0x4E4F534A +_GLB_BINARY_CHUNK_TYPE = 0x004E4942 +_GLTF_COMPONENT_FORMATS = { + 5120: ("b", 1), + 5121: ("B", 1), + 5122: ("h", 2), + 5123: ("H", 2), + 5125: ("I", 4), + 5126: ("f", 4), +} +_GLTF_TYPE_COMPONENT_COUNTS = { + "SCALAR": 1, + "VEC2": 2, + "VEC3": 3, + "VEC4": 4, + "MAT4": 16, +} + +_BACKGROUND_ATTRS = { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01, +} + +_RIGID_OBJECT_ATTRS = { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 10.0, + "min_position_iters": 32, + "min_velocity_iters": 8, +} + + +@dataclass(frozen=True) +class GeneratedUR5BasketConfigPaths: + """Paths written by the UR5 basket config generator.""" + + output_dir: Path + gym_config: Path + agent_config: Path + task_prompt: Path + basic_background: Path + atom_actions: Path + summary: dict[str, Any] + + +@dataclass(frozen=True) +class TargetReplacementSpec: + """Prompt-to-geometry replacement for one source target object.""" + + source_uid: str + prompt: str + output_dir_name: str + + +@dataclass(frozen=True) +class _SceneObject: + source_uid: str + source_role: str + config: dict[str, Any] + + +@dataclass(frozen=True) +class _BasketTaskRoles: + table_source_uid: str + container_source_uid: str + left_target_source_uid: str + right_target_source_uid: str + container_runtime_uid: str + left_target_runtime_uid: str + right_target_runtime_uid: str + target_noun: str + left_target_noun: str + right_target_noun: str + container_noun: str + + +@dataclass(frozen=True) +class _ResolvedTargetReplacement: + source_uid: str + prompt: str + output_dir_name: str + mesh_path: Path + runtime_noun: str + reused: bool = False + + +@dataclass(frozen=True) +class _RelativePlacementStepSpec: + moved_source_uid: str + reference_source_uid: str + moved_runtime_uid: str + reference_runtime_uid: str + relation: str + active_side: str + release_offset: list[float] + high_offset: list[float] + reference_is_initial_pose: bool = False + release_position: list[float] | None = None + high_position: list[float] | None = None + + +@dataclass(frozen=True) +class _RelativePlacementSpec: + table_source_uid: str + moved_source_uid: str + reference_source_uid: str + moved_runtime_uid: str + reference_runtime_uid: str + relation: str + active_side: str + task_description: str + task_prompt_summary: str + basic_background_notes: str + action_sketch: list[str] + release_offset: list[float] + high_offset: list[float] + placements: tuple[_RelativePlacementStepSpec, ...] + reference_is_initial_pose: bool = False + release_position: list[float] | None = None + high_position: list[float] | None = None + + +def generate_ur5_basket_config_from_project( + gym_project: str | Path, + output_dir: str | Path, + *, + task_name: str = "UR5BreadBasket", + task_description: str | None = None, + use_llm_roles: bool = False, + llm_model: str | None = None, + target_body_scale: float | list[float] | tuple[float, float, float] = 0.7, + target_replacements: Sequence[TargetReplacementSpec] | None = None, + sync_replacement_names: bool = False, + reuse_target_replacements: bool = True, + prewarm_coacd_cache: bool = True, + overwrite: bool = False, + max_episodes: int = 1, + max_episode_steps: int = 1000, +) -> GeneratedUR5BasketConfigPaths: + """Generate Dual-UR5 basket placement configs from an exported gym project. + + This first-stage generator intentionally keeps the UR5BreadBasket task + structure fixed: the left arm grasps the left target object, the right arm + grasps the right target object, and both objects are placed into one + basket-like container. + + Args: + gym_project: Project root, formatted scene folder, ``gym_config.json``, + or ``gym_config_merged.json``. + output_dir: Destination config directory. + task_name: Name passed to ``run_agent``. + task_description: Optional natural-language relative-placement task. + When provided, the generator asks the shared LLM for a constrained + config-level task spec and generates prompts from that spec. + use_llm_roles: If true, use an LLM only to refine object role mapping. + llm_model: Optional model override for role refinement. + target_body_scale: Uniform or xyz scale applied to generated target + objects. Basket-like containers keep their source ``body_scale``. + target_replacements: Optional prompt-generated GLB replacements for + selected default basket target objects. Each replacement writes to + ``/mesh_assets/`` and only affects the + generated config, not the original source mesh file. + sync_replacement_names: If true, update runtime target UIDs and prompts + from the replacement prompts. If false, only mesh paths are replaced. + reuse_target_replacements: If true, reuse an existing replacement GLB + at the expected output path when it matches the requested prompt. + prewarm_coacd_cache: If true, precompute environment-side CoACD cache + files referenced by the generated gym config before writing it. + overwrite: If false, fail when generated files already exist. + max_episodes: Value written to ``fast_gym_config.json``. + max_episode_steps: Value written to ``fast_gym_config.json``. + + Returns: + Paths of generated config files. + """ + + output_dir_path = Path(output_dir).expanduser().resolve() + _raise_if_generated_files_exist(output_dir_path, overwrite) + + input_path = Path(gym_project).expanduser().resolve() + gym_config_path = _resolve_gym_config_path(input_path) + scene_dir = gym_config_path.parent + source_config = _read_json(gym_config_path) + project_name = _infer_project_name(input_path, scene_dir) + replacement_specs = _normalize_target_replacements(target_replacements) + mesh_normalizer = MeshFrameNormalizer( + output_dir=output_dir_path / "mesh_assets" / "normalized" + ) + + scene_objects = _collect_scene_objects(source_config) + if task_description: + if replacement_specs: + raise ValueError( + "target_replacements are only supported by the default basket " + "template. Do not combine them with task_description." + ) + spec = _build_relative_placement_spec_with_llm( + scene_objects=scene_objects, + project_name=project_name, + task_description=task_description, + model=llm_model, + ) + bundle = _build_relative_placement_bundle( + scene_dir=scene_dir, + source_config=source_config, + spec=spec, + project_name=project_name, + task_name=task_name, + target_body_scale=target_body_scale, + max_episodes=max_episodes, + max_episode_steps=max_episode_steps, + mesh_normalizer=mesh_normalizer, + ) + _validate_relative_bundle(bundle, spec) + _attach_mesh_normalization_summary(bundle, mesh_normalizer) + if prewarm_coacd_cache: + _attach_coacd_cache_summary(bundle) + return _write_config_bundle( + output_dir=output_dir_path, + bundle=bundle, + overwrite=overwrite, + ) + + roles = _infer_basket_task_roles(scene_objects) + if use_llm_roles: + roles = _refine_roles_with_llm( + roles=roles, + scene_objects=scene_objects, + project_name=project_name, + model=llm_model, + ) + + _validate_target_replacement_sources(roles, replacement_specs) + resolved_replacements = _run_target_replacements( + scene_dir=scene_dir, + replacement_specs=replacement_specs, + reuse_target_replacements=reuse_target_replacements, + ) + if sync_replacement_names: + roles = _apply_replacement_names( + roles, + resolved_replacements, + ) + + bundle = _build_ur5_basket_bundle( + scene_dir=scene_dir, + source_config=source_config, + roles=roles, + project_name=project_name, + task_name=task_name, + target_body_scale=target_body_scale, + target_replacements=resolved_replacements, + max_episodes=max_episodes, + max_episode_steps=max_episode_steps, + mesh_normalizer=mesh_normalizer, + ) + _validate_bundle(bundle, roles) + _attach_mesh_normalization_summary(bundle, mesh_normalizer) + if prewarm_coacd_cache: + _attach_coacd_cache_summary(bundle) + return _write_config_bundle( + output_dir=output_dir_path, + bundle=bundle, + overwrite=overwrite, + ) + + +def _resolve_gym_config_path(input_path: Path) -> Path: + if input_path.is_file(): + if input_path.name not in _GYM_CONFIG_FILENAMES: + expected = ", ".join(sorted(_GYM_CONFIG_FILENAMES)) + raise ValueError(f"Expected one of {expected}, got: {input_path}") + return input_path + + direct = _preferred_gym_config_in_dir(input_path) + if direct is not None: + return direct + + formatted_scene_dirs = sorted( + { + path.parent + for filename in _GYM_CONFIG_FILENAMES + for path in input_path.glob(f"formatted_tabletop_scene/*/{filename}") + } + ) + formatted_matches = [ + path + for scene_dir in formatted_scene_dirs + if (path := _preferred_gym_config_in_dir(scene_dir)) is not None + ] + if len(formatted_matches) == 1: + return formatted_matches[0] + if len(formatted_matches) > 1: + matches = ", ".join(path.as_posix() for path in formatted_matches) + raise ValueError(f"Multiple formatted gym config files found: {matches}") + + recursive_scene_dirs = sorted( + { + path.parent + for filename in _GYM_CONFIG_FILENAMES + for path in input_path.rglob(filename) + } + ) + recursive_matches = [ + path + for scene_dir in recursive_scene_dirs + if (path := _preferred_gym_config_in_dir(scene_dir)) is not None + ] + if len(recursive_matches) == 1: + return recursive_matches[0] + if not recursive_matches: + expected = " or ".join(_GYM_CONFIG_PREFERENCE) + raise FileNotFoundError(f"{expected} not found under: {input_path}") + matches = ", ".join(path.as_posix() for path in recursive_matches) + raise ValueError(f"Multiple gym config files found: {matches}") + + +def _preferred_gym_config_in_dir(scene_dir: Path) -> Path | None: + for filename in _GYM_CONFIG_PREFERENCE: + path = scene_dir / filename + if path.is_file(): + return path + return None + + +def _infer_project_name(input_path: Path, scene_dir: Path) -> str: + for part in input_path.parts: + if _PROJECT_NAME_RE.match(part): + return part + for part in scene_dir.parts: + if _PROJECT_NAME_RE.match(part): + return part + return scene_dir.name + + +def _collect_scene_objects(scene_config: Mapping[str, Any]) -> list[_SceneObject]: + scene_objects = [] + for source_role in ("background", "rigid_object"): + for obj_config in scene_config.get(source_role, []) or []: + source_uid = str(obj_config.get("uid", "")).strip() + if not source_uid: + raise ValueError(f"Scene object without uid in {source_role}.") + scene_objects.append( + _SceneObject( + source_uid=source_uid, + source_role=source_role, + config=copy.deepcopy(dict(obj_config)), + ) + ) + + if not scene_objects: + raise ValueError("No background or rigid_object entries found in gym config.") + return scene_objects + + +def _infer_basket_task_roles(scene_objects: list[_SceneObject]) -> _BasketTaskRoles: + background_objects = [ + obj for obj in scene_objects if obj.source_role == "background" + ] + rigid_objects = [obj for obj in scene_objects if obj.source_role == "rigid_object"] + if not background_objects: + raise ValueError("UR5 basket generation requires a table/background object.") + if len(rigid_objects) < 3: + raise ValueError( + "UR5 basket generation requires at least two target objects and one " + "basket-like container." + ) + + table = _pick_table(background_objects) + container = _pick_container(rigid_objects) + target_candidates = [ + obj for obj in rigid_objects if obj.source_uid != container.source_uid + ] + if len(target_candidates) < 2: + raise ValueError("Expected at least two non-container target objects.") + + left_target, right_target = _pick_left_right_targets(target_candidates) + target_noun = _target_noun(left_target, right_target) + container_noun = _display_noun(_base_name(container)) + return _BasketTaskRoles( + table_source_uid=table.source_uid, + container_source_uid=container.source_uid, + left_target_source_uid=left_target.source_uid, + right_target_source_uid=right_target.source_uid, + container_runtime_uid=_container_runtime_uid(container), + left_target_runtime_uid=f"left_{target_noun}", + right_target_runtime_uid=f"right_{target_noun}", + target_noun=target_noun, + left_target_noun=target_noun, + right_target_noun=target_noun, + container_noun=container_noun, + ) + + +def _pick_table(background_objects: list[_SceneObject]) -> _SceneObject: + for obj in background_objects: + text = _object_text(obj) + if "table" in text: + return obj + return background_objects[0] + + +def _pick_container(rigid_objects: list[_SceneObject]) -> _SceneObject: + candidates = [ + obj + for obj in rigid_objects + if any(keyword in _object_text(obj) for keyword in _CONTAINER_KEYWORDS) + ] + if not candidates: + names = ", ".join(obj.source_uid for obj in rigid_objects) + raise ValueError(f"No basket-like container object found among: {names}") + + def score(obj: _SceneObject) -> tuple[int, float]: + text = _object_text(obj) + keyword_score = 0 if "basket" in text else 1 + pos = _vector3(obj.config.get("init_pos", [0.0, 0.0, 0.0])) + center_distance = abs(pos[0]) + abs(pos[1]) + return keyword_score, center_distance + + return sorted(candidates, key=score)[0] + + +def _pick_left_right_targets( + target_candidates: list[_SceneObject], +) -> tuple[_SceneObject, _SceneObject]: + if len(target_candidates) == 2: + picked = target_candidates + else: + grouped: dict[str, list[_SceneObject]] = {} + for obj in target_candidates: + grouped.setdefault(_base_name(obj), []).append(obj) + repeated_groups = [group for group in grouped.values() if len(group) >= 2] + if repeated_groups: + picked = sorted( + repeated_groups, + key=_target_group_sort_key, + )[0] + if len(picked) > 2: + picked = sorted( + picked, + key=lambda obj: abs(_side_axis_value(obj)), + reverse=True, + )[:2] + else: + picked = sorted( + target_candidates, + key=lambda obj: abs(_side_axis_value(obj)), + reverse=True, + )[:2] + left, right = sorted(picked, key=_side_axis_value) + return left, right + + +def _target_group_sort_key(group: list[_SceneObject]) -> tuple[float, int]: + side_values = [_side_axis_value(obj) for obj in group] + side_spread = max(side_values) - min(side_values) + return -side_spread, -len(group) + + +def _side_axis_value(obj: _SceneObject) -> float: + return _position_side_axis_value( + _vector3(obj.config.get("init_pos", [0.0, 0.0, 0.0])) + ) + + +def _position_side_axis_value(position: list[float]) -> float: + return float(position[_DUAL_UR5_SIDE_AXIS_INDEX]) + + +def _arm_side_for_position(position: list[float]) -> str: + return "left" if _position_side_axis_value(position) < 0.0 else "right" + + +def _target_noun(left_target: _SceneObject, right_target: _SceneObject) -> str: + left_base = _base_name(left_target) + right_base = _base_name(right_target) + if left_base == right_base: + return _target_runtime_suffix(left_base) + return "target_object" + + +def _object_text(obj: _SceneObject) -> str: + shape = obj.config.get("shape", {}) or {} + return f"{obj.source_uid} {shape.get('fpath', '')}".lower() + + +def _base_name(obj: _SceneObject) -> str: + base = _DIGIT_SUFFIX_RE.sub("", obj.source_uid) + if base == obj.source_uid: + fpath = str(obj.config.get("shape", {}).get("fpath", "")) + path = Path(fpath) + if len(path.parts) >= 2: + base = path.parts[-2] + return _normalize_runtime_uid(base) + + +def _target_runtime_suffix(base: str) -> str: + if base == "bread": + return "bread_roll" + return base + + +def _container_runtime_uid(container: _SceneObject) -> str: + base = _base_name(container) + if "basket" in base: + return "wicker_basket" + return f"target_{base}" + + +def _display_noun(uid: str) -> str: + return uid.replace("_", " ") + + +def _plural(noun: str) -> str: + if noun.endswith("s"): + return noun + if noun.endswith(("ch", "sh", "x")): + return f"{noun}es" + return f"{noun}s" + + +def _left_target_text(roles: _BasketTaskRoles) -> str: + return _display_noun(roles.left_target_noun) + + +def _right_target_text(roles: _BasketTaskRoles) -> str: + return _display_noun(roles.right_target_noun) + + +def _target_pair_text(roles: _BasketTaskRoles) -> str: + left_text = _left_target_text(roles) + right_text = _right_target_text(roles) + if left_text == right_text: + return f"two {left_text} objects" + return f"the left {left_text} and right {right_text}" + + +def _target_plural_text(roles: _BasketTaskRoles) -> str: + left_text = _left_target_text(roles) + right_text = _right_target_text(roles) + if left_text == right_text: + return _plural(left_text) + return "target objects" + + +def _generic_target_text(roles: _BasketTaskRoles) -> str: + left_text = _left_target_text(roles) + right_text = _right_target_text(roles) + if left_text == right_text: + return left_text + return "target object" + + +def _target_task_description_text(roles: _BasketTaskRoles) -> str: + left_text = _left_target_text(roles) + right_text = _right_target_text(roles) + if left_text == right_text: + return _plural(left_text) + return f"{left_text}-and-{right_text}" + + +def _normalize_runtime_uid(value: str) -> str: + uid = _INVALID_UID_CHARS_RE.sub("_", value.strip()).strip("_").lower() + if not uid: + raise ValueError(f"Invalid runtime uid: {value!r}") + return uid + + +def _normalize_target_replacements( + target_replacements: Sequence[TargetReplacementSpec] | None, +) -> tuple[TargetReplacementSpec, ...]: + if not target_replacements: + return () + + normalized = [] + seen_source_uids = set() + seen_output_dirs = set() + for replacement in target_replacements: + if not isinstance(replacement, TargetReplacementSpec): + raise TypeError( + "target_replacements must contain TargetReplacementSpec values." + ) + source_uid = str(replacement.source_uid).strip() + prompt = str(replacement.prompt).strip() + output_dir_name = str(replacement.output_dir_name).strip() + if not source_uid: + raise ValueError("target replacement source_uid must be non-empty.") + if not prompt: + raise ValueError("target replacement prompt must be non-empty.") + if not output_dir_name: + raise ValueError("target replacement output_dir_name must be non-empty.") + output_dir_path = Path(output_dir_name) + if ( + output_dir_path.is_absolute() + or len(output_dir_path.parts) != 1 + or output_dir_name in {".", ".."} + ): + raise ValueError( + "target replacement output_dir_name must be a single relative " + f"directory name, got: {output_dir_name!r}" + ) + if source_uid in seen_source_uids: + raise ValueError(f"Duplicate target replacement source uid: {source_uid}") + if output_dir_name in seen_output_dirs: + raise ValueError( + f"Duplicate target replacement output dir: {output_dir_name}" + ) + seen_source_uids.add(source_uid) + seen_output_dirs.add(output_dir_name) + normalized.append( + TargetReplacementSpec( + source_uid=source_uid, + prompt=prompt, + output_dir_name=output_dir_name, + ) + ) + return tuple(normalized) + + +def _validate_target_replacement_sources( + roles: _BasketTaskRoles, + replacement_specs: Sequence[TargetReplacementSpec], +) -> None: + if not replacement_specs: + return + + target_source_uids = { + roles.left_target_source_uid, + roles.right_target_source_uid, + } + unknown = [ + replacement.source_uid + for replacement in replacement_specs + if replacement.source_uid not in target_source_uids + ] + if unknown: + raise ValueError( + "target_replacements must reference the selected basket target " + f"source uid(s) {sorted(target_source_uids)}, got: {unknown}" + ) + + +def _run_target_replacements( + *, + scene_dir: Path, + replacement_specs: Sequence[TargetReplacementSpec], + reuse_target_replacements: bool, +) -> tuple[_ResolvedTargetReplacement, ...]: + resolved = [] + for replacement in replacement_specs: + runtime_noun = _replacement_runtime_noun(replacement.prompt) + output_root = scene_dir / "mesh_assets" / replacement.output_dir_name + output_name = f"{runtime_noun}.glb" + mesh_path = None + reused = False + if reuse_target_replacements: + mesh_path = _resolve_reusable_target_replacement_mesh_path( + output_root=output_root, + prompt=replacement.prompt, + output_name=output_name, + ) + reused = mesh_path is not None + if mesh_path is None: + result = _run_prompt2geometry_replacement( + prompt=replacement.prompt, + output_root=output_root, + output_name=output_name, + ) + mesh_path = _resolve_prompt2geometry_mesh_path(result, output_root) + _write_target_replacement_manifest( + output_root=output_root, + prompt=replacement.prompt, + output_name=output_name, + mesh_path=mesh_path, + ) + elif reused: + _write_target_replacement_manifest( + output_root=output_root, + prompt=replacement.prompt, + output_name=output_name, + mesh_path=mesh_path, + ) + resolved.append( + _ResolvedTargetReplacement( + source_uid=replacement.source_uid, + prompt=replacement.prompt, + output_dir_name=replacement.output_dir_name, + mesh_path=mesh_path, + runtime_noun=runtime_noun, + reused=reused, + ) + ) + return tuple(resolved) + + +def _resolve_reusable_target_replacement_mesh_path( + *, + output_root: Path, + prompt: str, + output_name: str, +) -> Path | None: + expected_mesh_path = (output_root / output_name).expanduser().resolve() + if not expected_mesh_path.is_file(): + return None + + manifest_path = _target_replacement_manifest_path(output_root) + if not manifest_path.is_file(): + return expected_mesh_path + + try: + manifest = _read_json(manifest_path) + except (OSError, json.JSONDecodeError): + return None + + if manifest.get("prompt") != prompt or manifest.get("output_name") != output_name: + return None + + manifest_mesh_path = Path( + str(manifest.get("mesh_path", expected_mesh_path)) + ).expanduser() + if not manifest_mesh_path.is_absolute(): + manifest_mesh_path = (output_root / manifest_mesh_path).resolve() + else: + manifest_mesh_path = manifest_mesh_path.resolve() + if manifest_mesh_path.is_file(): + return manifest_mesh_path + return expected_mesh_path + + +def _write_target_replacement_manifest( + *, + output_root: Path, + prompt: str, + output_name: str, + mesh_path: Path, +) -> None: + _write_json( + _target_replacement_manifest_path(output_root), + { + "prompt": prompt, + "output_name": output_name, + "mesh_path": mesh_path.expanduser().resolve().as_posix(), + }, + ) + + +def _target_replacement_manifest_path(output_root: Path) -> Path: + return output_root / _TARGET_REPLACEMENT_MANIFEST_FILENAME + + +def _run_prompt2geometry_replacement( + *, + prompt: str, + output_root: Path, + output_name: str, +) -> dict[str, Any]: + from embodichain.gen_sim.action_agent_pipeline.gym_project_api.prompt2geometry import ( + Prompt2GeometryRequest, + load_prompt2geometry_config, + run_prompt2geometry, + ) + + cfg = load_prompt2geometry_config() + return run_prompt2geometry( + Prompt2GeometryRequest( + prompt=prompt, + output_root=output_root, + output_name=output_name, + zimage_base_url=cfg.zimage_base_url, + sam3_base_url=cfg.sam3_base_url, + sam3d_base_url=cfg.sam3d_base_url, + llm_api_key=cfg.llm_api_key, + llm_model=cfg.llm_model, + llm_base_url=cfg.llm_base_url, + llm_timeout_s=cfg.llm_timeout_s, + ) + ) + + +def _resolve_prompt2geometry_mesh_path( + result: Mapping[str, Any], + output_root: Path, +) -> Path: + raw_path = result.get("scaled_mesh_path") or result.get("mesh_path") + if not raw_path: + raise ValueError("prompt2geometry result did not include a GLB mesh path.") + + mesh_path = Path(str(raw_path)).expanduser() + if not mesh_path.is_absolute(): + mesh_path = (output_root / mesh_path).resolve() + else: + mesh_path = mesh_path.resolve() + + if not mesh_path.is_file(): + raise FileNotFoundError(f"Generated replacement GLB not found: {mesh_path}") + return mesh_path + + +def _replacement_runtime_noun(prompt: str) -> str: + tokens = re.findall(r"[a-z0-9]+", prompt.lower()) + while tokens and tokens[0] in {"a", "an", "the"}: + tokens.pop(0) + stem = "_".join(tokens) + if not stem: + stem = "replacement_object" + return _normalize_runtime_uid(stem) + + +def _apply_replacement_names( + roles: _BasketTaskRoles, + resolved_replacements: Sequence[_ResolvedTargetReplacement], +) -> _BasketTaskRoles: + replacement_by_uid = { + replacement.source_uid: replacement for replacement in resolved_replacements + } + left_replacement = replacement_by_uid.get(roles.left_target_source_uid) + right_replacement = replacement_by_uid.get(roles.right_target_source_uid) + left_target_noun = ( + left_replacement.runtime_noun + if left_replacement is not None + else roles.left_target_noun + ) + right_target_noun = ( + right_replacement.runtime_noun + if right_replacement is not None + else roles.right_target_noun + ) + target_noun = ( + left_target_noun if left_target_noun == right_target_noun else "target_object" + ) + return _BasketTaskRoles( + table_source_uid=roles.table_source_uid, + container_source_uid=roles.container_source_uid, + left_target_source_uid=roles.left_target_source_uid, + right_target_source_uid=roles.right_target_source_uid, + container_runtime_uid=roles.container_runtime_uid, + left_target_runtime_uid=f"left_{left_target_noun}", + right_target_runtime_uid=f"right_{right_target_noun}", + target_noun=target_noun, + left_target_noun=left_target_noun, + right_target_noun=right_target_noun, + container_noun=roles.container_noun, + ) + + +def _refine_roles_with_llm( + *, + roles: _BasketTaskRoles, + scene_objects: list[_SceneObject], + project_name: str, + model: str | None, +) -> _BasketTaskRoles: + response = _call_role_llm( + project_name=project_name, + scene_summary=[ + { + "source_uid": obj.source_uid, + "role": obj.source_role, + "mesh": obj.config.get("shape", {}).get("fpath"), + "init_pos": obj.config.get("init_pos"), + } + for obj in scene_objects + ], + default_roles={ + "container_object": roles.container_source_uid, + "left_target_object": roles.left_target_source_uid, + "right_target_object": roles.right_target_source_uid, + "target_noun": roles.target_noun, + "container_runtime_uid": roles.container_runtime_uid, + }, + model=model, + ) + source_uids = {obj.source_uid for obj in scene_objects} + left_target = str(response.get("left_target_object", roles.left_target_source_uid)) + right_target = str( + response.get("right_target_object", roles.right_target_source_uid) + ) + container = str(response.get("container_object", roles.container_source_uid)) + for uid in (left_target, right_target, container): + if uid not in source_uids: + raise ValueError(f"LLM returned unknown source uid: {uid!r}") + if len({left_target, right_target, container}) != 3: + raise ValueError("LLM role mapping must use three distinct source objects.") + + target_noun = _normalize_runtime_uid( + str(response.get("target_noun", roles.target_noun)) + ) + container_runtime_uid = _normalize_runtime_uid( + str(response.get("container_runtime_uid", roles.container_runtime_uid)) + ) + return _BasketTaskRoles( + table_source_uid=roles.table_source_uid, + container_source_uid=container, + left_target_source_uid=left_target, + right_target_source_uid=right_target, + container_runtime_uid=container_runtime_uid, + left_target_runtime_uid=f"left_{target_noun}", + right_target_runtime_uid=f"right_{target_noun}", + target_noun=target_noun, + left_target_noun=target_noun, + right_target_noun=target_noun, + container_noun=_display_noun(container_runtime_uid), + ) + + +def _call_role_llm( + *, + project_name: str, + scene_summary: list[dict[str, Any]], + default_roles: dict[str, Any], + model: str | None, +) -> dict[str, Any]: + from langchain_core.messages import HumanMessage, SystemMessage + + from embodichain.gen_sim.action_agent_pipeline.utils.llm_json import ( + extract_json_object, + ) + from embodichain.gen_sim.action_agent_pipeline.utils.mllm import ( + create_chat_openai, + ) + + prompt = ( + "Identify roles for a fixed Dual-UR5 basket-placement simulation task. " + "Return only one JSON object with keys: container_object, " + "left_target_object, right_target_object, target_noun, " + "container_runtime_uid. Use only source_uid values from the scene. The " + "rotated robot-view left target starts on the negative-y side, and the " + "rotated robot-view right target starts on the positive-y side.\n\n" + f"Project: {project_name}\n" + f"Scene objects:\n{json.dumps(scene_summary, ensure_ascii=False, indent=2)}\n" + f"Default roles:\n{json.dumps(default_roles, ensure_ascii=False, indent=2)}" + ) + llm = create_chat_openai( + temperature=0.0, + model=model, + usage_stage="config_generation.role_refinement", + ) + response = llm.invoke( + [ + SystemMessage( + content=( + "You produce strict JSON role mappings for simulation config " + "generation. Do not include markdown." + ) + ), + HumanMessage(content=prompt), + ] + ) + content = getattr(response, "content", response) + return extract_json_object(content) + + +def _build_relative_placement_spec_with_llm( + *, + scene_objects: list[_SceneObject], + project_name: str, + task_description: str, + model: str | None, +) -> _RelativePlacementSpec: + background_objects = [ + obj for obj in scene_objects if obj.source_role == "background" + ] + rigid_objects = [obj for obj in scene_objects if obj.source_role == "rigid_object"] + if not background_objects: + raise ValueError("Relative placement generation requires a background table.") + if not rigid_objects: + raise ValueError( + "Relative placement generation requires a movable rigid object." + ) + + table = _pick_table(background_objects) + response = _call_relative_task_llm( + project_name=project_name, + task_description=task_description, + scene_summary=[ + { + "source_uid": obj.source_uid, + "role": obj.source_role, + "object_type": _base_name(obj), + "is_container_like": _is_container_like(obj), + "mesh": obj.config.get("shape", {}).get("fpath"), + "init_pos": obj.config.get("init_pos"), + } + for obj in scene_objects + ], + model=model, + ) + return _apply_relative_task_response( + response=response, + table_source_uid=table.source_uid, + scene_objects=scene_objects, + rigid_objects=rigid_objects, + task_description=task_description, + ) + + +def _call_relative_task_llm( + *, + project_name: str, + task_description: str, + scene_summary: list[dict[str, Any]], + model: str | None, +) -> dict[str, Any]: + from langchain_core.messages import HumanMessage, SystemMessage + + from embodichain.gen_sim.action_agent_pipeline.utils.llm_json import ( + extract_json_object, + ) + from embodichain.gen_sim.action_agent_pipeline.utils.mllm import ( + create_chat_openai, + ) + + prompt = ( + "Parse a simple Dual-UR5 tabletop relative-placement task and produce " + "a constrained config-level JSON spec. This JSON is used to generate " + "task_prompt.txt, basic_background.txt, atom_actions.txt, and " + "agent_success; a second LLM will later read those prompts to generate " + "the executable graph JSON.\n\n" + "Return exactly one JSON object with this schema:\n" + "{\n" + ' "placements": [\n' + " {\n" + ' "moved_object": "",\n' + ' "reference_object": "",\n' + ' "goal_relation": ' + '"inside|on|left_of|right_of|front_of|behind|front_left_of|back_left_of|front_right_of|back_right_of",\n' + ' "arm": "left|right|auto"\n' + " }\n" + " ],\n" + ' "task_prompt_summary": "",\n' + ' "basic_background_notes": "",\n' + ' "action_sketch": [\n' + ' "grasp moved_object",\n' + ' "move above the relation target pose",\n' + ' "place at the release pose with PlaceAction"\n' + " ]\n" + "}\n\n" + "Rules:\n" + "- Use only source_uid values from the scene objects listed below.\n" + "- Return one placement for a single-arm task and exactly two placements " + "for a dual-arm task.\n" + "- Treat the task as dual-arm when it explicitly says 双臂, 两臂, both " + "arms, two arms, or when it describes separate work for the left arm and " + "the right arm even if it does not literally say 双臂.\n" + "- Do not invent a second placement when the task only moves one object.\n" + "- moved_object is the object to grasp and move.\n" + "- reference_object is the object used as the spatial reference, " + "container, or support.\n" + "- reference_object may be a rigid_object or a background object such as " + "a pad, tray, basket, or container.\n" + "- For single-object directional tasks such as moving the only object " + "forward, left, front-left, or back-right from its initial position, set " + "reference_object to the same source_uid as moved_object (or 'self'). " + "This means the generator will use the object's initial position as a " + "fixed anchor, not the object's moving runtime pose.\n" + "- Within each placement, moved_object and reference_object must be " + "different unless the task is an initial-position directional move.\n" + "- For dual-arm tasks, the placements must use two different moved_object " + "values and one left arm plus one right arm. Use arm='auto' only when " + "the user did not specify which arm handles that placement.\n" + "- arm selects the single UR5 arm that should manipulate moved_object. " + "Use arm='left' for explicit left-arm instructions such as 左臂, 左机械臂, " + "left arm, or left UR5; use arm='right' for explicit right-arm " + "instructions such as 右臂, 右机械臂, right arm, or right UR5; use " + "arm='auto' when the task does not specify an arm.\n" + "- For Chinese/English left/right/front/back, use the relation enums " + "from the rotated robot-view perspective. front_of means negative " + "world-x; behind means positive world-x; left_of means negative " + "world-y; right_of means positive world-y. Diagonal relations combine " + "both axes: front_left_of, back_left_of, front_right_of, back_right_of.\n" + "- If the task says to release an object above a basket/container so it " + "falls into it, use goal_relation='inside'.\n" + "- If the task says to stack/place one object on another non-container " + "support, use goal_relation='on'.\n" + "- Do not return numeric offsets, object poses, scales, success JSON, " + "robot config, or full prompt files. The generator computes those " + "deterministically.\n\n" + f"Project: {project_name}\n" + f"Task description:\n{task_description}\n" + f"Scene objects:\n{json.dumps(scene_summary, ensure_ascii=False, indent=2)}" + ) + llm = create_chat_openai( + temperature=0.0, + model=model, + usage_stage="config_generation.relative_task", + ) + response = llm.invoke( + [ + SystemMessage( + content=( + "You produce strict JSON specs for simulation config " + "generation. Do not include markdown." + ) + ), + HumanMessage(content=prompt), + ] + ) + content = getattr(response, "content", response) + return extract_json_object(content) + + +def _apply_relative_task_response( + *, + response: Mapping[str, Any], + table_source_uid: str, + scene_objects: list[_SceneObject], + rigid_objects: list[_SceneObject], + task_description: str, +) -> _RelativePlacementSpec: + by_uid = {obj.source_uid: obj for obj in scene_objects} + runtime_uids = _relative_scene_runtime_uid_mapping( + scene_objects, + table_source_uid=table_source_uid, + ) + + placement_entries = _relative_placement_entries(response) + if len(placement_entries) > 2: + raise ValueError("Relative placement supports at most two arm placements.") + + forced_arm_sides = _relative_forced_arm_sides( + placement_entries, + by_uid=by_uid, + rigid_objects=rigid_objects, + ) + placements = tuple( + _build_relative_placement_step( + entry=entry, + by_uid=by_uid, + scene_objects=scene_objects, + rigid_objects=rigid_objects, + runtime_uids=runtime_uids, + forced_side=forced_side, + ) + for entry, forced_side in zip(placement_entries, forced_arm_sides) + ) + _validate_relative_placements(placements) + + summary = str(response.get("task_prompt_summary", "")).strip() + if not summary: + summary = _default_relative_plan_summary(placements) + background_notes = str(response.get("basic_background_notes", "")).strip() + action_sketch = _string_list(response.get("action_sketch")) + if not action_sketch: + action_sketch = _default_relative_action_sketch(placements) + + primary = placements[0] + + return _RelativePlacementSpec( + table_source_uid=table_source_uid, + moved_source_uid=primary.moved_source_uid, + reference_source_uid=primary.reference_source_uid, + moved_runtime_uid=primary.moved_runtime_uid, + reference_runtime_uid=primary.reference_runtime_uid, + relation=primary.relation, + active_side=primary.active_side, + task_description=task_description, + task_prompt_summary=summary, + basic_background_notes=background_notes, + action_sketch=action_sketch, + release_offset=primary.release_offset, + high_offset=primary.high_offset, + placements=placements, + reference_is_initial_pose=primary.reference_is_initial_pose, + release_position=primary.release_position, + high_position=primary.high_position, + ) + + +def _relative_placement_entries(response: Mapping[str, Any]) -> list[Mapping[str, Any]]: + placements = response.get("placements") + if placements is None: + return [response] + if not isinstance(placements, list) or not placements: + raise ValueError("LLM response placements must be a non-empty list.") + entries: list[Mapping[str, Any]] = [] + for index, placement in enumerate(placements): + if not isinstance(placement, Mapping): + raise ValueError(f"Placement {index} must be a JSON object.") + entries.append(placement) + return entries + + +def _relative_forced_arm_sides( + placement_entries: list[Mapping[str, Any]], + *, + by_uid: Mapping[str, _SceneObject], + rigid_objects: list[_SceneObject], +) -> list[str | None]: + if len(placement_entries) != 2: + return [None for _ in placement_entries] + + requested_sides = [ + _normalize_relative_arm(entry.get("arm")) for entry in placement_entries + ] + explicit_sides = [side for side in requested_sides if side != "auto"] + if len(explicit_sides) == 2: + return [None, None] + if len(explicit_sides) == 1: + complement = "right" if explicit_sides[0] == "left" else "left" + return [ + requested_side if requested_side != "auto" else complement + for requested_side in requested_sides + ] + + moved_source_uids = [ + _resolve_rigid_source_uid( + entry.get("moved_object"), + rigid_objects, + field_name="moved_object", + ) + for entry in placement_entries + ] + positions = [ + _vector3(by_uid[source_uid].config.get("init_pos", [0.0, 0.0, 0.0])) + for source_uid in moved_source_uids + ] + inferred_sides = [_arm_side_for_position(position) for position in positions] + if set(inferred_sides) == {"left", "right"}: + return inferred_sides + + side_values = [_position_side_axis_value(position) for position in positions] + if side_values[0] <= side_values[1]: + return ["left", "right"] + return ["right", "left"] + + +def _build_relative_placement_step( + *, + entry: Mapping[str, Any], + by_uid: Mapping[str, _SceneObject], + scene_objects: list[_SceneObject], + rigid_objects: list[_SceneObject], + runtime_uids: Mapping[str, str], + forced_side: str | None, +) -> _RelativePlacementStepSpec: + moved_source_uid = _resolve_rigid_source_uid( + entry.get("moved_object"), + rigid_objects, + field_name="moved_object", + ) + relation = _normalize_relative_relation(entry.get("goal_relation")) + reference_source_uid = _resolve_relative_reference_source_uid( + entry.get("reference_object"), + moved_source_uid=moved_source_uid, + scene_objects=scene_objects, + ) + reference_is_initial_pose = moved_source_uid == reference_source_uid + if reference_is_initial_pose and relation not in _SIDE_RELATIONS: + raise ValueError( + "Initial-position self-relative placement only supports directional " + "relations, not inside/on." + ) + + reference_obj = by_uid[reference_source_uid] + if relation == "on" and _is_container_like(reference_obj): + relation = "inside" + + moved_runtime_uid = runtime_uids[moved_source_uid] + reference_runtime_uid = runtime_uids[reference_source_uid] + if moved_runtime_uid == reference_runtime_uid and not reference_is_initial_pose: + raise ValueError( + f"Relative placement produced duplicate runtime uid {moved_runtime_uid!r}." + ) + + release_offset = _relative_release_offset(relation) + high_offset = list(release_offset) + high_offset[2] += _STAGING_Z_DELTA + moved_position = _vector3( + by_uid[moved_source_uid].config.get("init_pos", [0, 0, 0]) + ) + requested_side = _normalize_relative_arm(entry.get("arm")) + active_side = ( + forced_side + if forced_side is not None + else ( + _arm_side_for_position(moved_position) + if requested_side == "auto" + else requested_side + ) + ) + + return _RelativePlacementStepSpec( + moved_source_uid=moved_source_uid, + reference_source_uid=reference_source_uid, + moved_runtime_uid=moved_runtime_uid, + reference_runtime_uid=reference_runtime_uid, + relation=relation, + active_side=active_side, + release_offset=release_offset, + high_offset=high_offset, + reference_is_initial_pose=reference_is_initial_pose, + ) + + +def _validate_relative_placements( + placements: tuple[_RelativePlacementStepSpec, ...], +) -> None: + if not placements: + raise ValueError("Relative placement requires at least one placement.") + moved_source_uids = [placement.moved_source_uid for placement in placements] + if len(moved_source_uids) != len(set(moved_source_uids)): + raise ValueError("Relative placements must use distinct moved_object values.") + if len(placements) == 2: + active_sides = {placement.active_side for placement in placements} + if active_sides != {"left", "right"}: + raise ValueError( + "Dual-arm relative placement requires one left arm and one right arm." + ) + + +def _resolve_rigid_source_uid( + value: Any, + rigid_objects: list[_SceneObject], + *, + field_name: str, +) -> str: + return _resolve_scene_source_uid( + value, + rigid_objects, + field_name=field_name, + ) + + +def _resolve_relative_reference_source_uid( + value: Any, + *, + moved_source_uid: str, + scene_objects: list[_SceneObject], +) -> str: + if value is not None: + text = str(value).strip() + normalized = text.lower().replace("-", "_").replace(" ", "_") + if normalized in _SELF_REFERENCE_VALUES: + return moved_source_uid + return _resolve_scene_source_uid( + value, + scene_objects, + field_name="reference_object", + ) + + +def _resolve_scene_source_uid( + value: Any, + scene_objects: list[_SceneObject], + *, + field_name: str, +) -> str: + if value is None: + raise ValueError(f"LLM response missing required {field_name}.") + text = str(value).strip() + by_uid = {obj.source_uid: obj for obj in scene_objects} + if text in by_uid: + return text + + normalized = _normalize_runtime_uid(text) + matches = [ + obj.source_uid + for obj in scene_objects + if _normalize_runtime_uid(obj.source_uid) == normalized + or _base_name(obj) == normalized + or _candidate_relative_runtime_uid(obj) == normalized + ] + if len(matches) == 1: + return matches[0] + if not matches: + raise ValueError(f"LLM returned unknown {field_name}: {text!r}.") + raise ValueError( + f"LLM returned ambiguous {field_name}: {text!r}; candidates: {matches}." + ) + + +def _normalize_relative_relation(value: Any) -> str: + relation = str(value or "").strip().lower().replace("-", "_").replace(" ", "_") + relation = _RELATION_ALIASES.get(relation, relation) + if relation not in _RELATIVE_RELATIONS: + raise ValueError( + f"Unsupported relative placement relation {value!r}; expected one " + f"of {sorted(_RELATIVE_RELATIONS)}." + ) + return relation + + +def _normalize_relative_arm(value: Any) -> str: + if value is None: + return "auto" + text = str(value).strip().lower().replace("-", "_").replace(" ", "_") + if text in { + "", + "auto", + "automatic", + "unspecified", + "none", + "null", + "default", + "自动", + "默认", + "未指定", + "不指定", + }: + return "auto" + if text in { + "left", + "left_arm", + "left_ur5", + "左", + "左臂", + "左机械臂", + "左手", + "左手臂", + }: + return "left" + if text in { + "right", + "right_arm", + "right_ur5", + "右", + "右臂", + "右机械臂", + "右手", + "右手臂", + }: + return "right" + raise ValueError( + f"Unsupported relative placement arm {value!r}; expected 'left', " + "'right', or 'auto'." + ) + + +def _relative_release_offset(relation: str) -> list[float]: + relation = _normalize_relative_relation(relation) + if relation == "inside": + return [0.0, 0.0, _SIDE_RELEASE_Z_OFFSET] + if relation == "on": + return [0.0, 0.0, _ON_RELEASE_Z_OFFSET] + if relation in _SIDE_RELATIONS: + x_offset, y_offset = _side_relation_xy_offsets(relation) + return [x_offset, y_offset, _SIDE_RELEASE_Z_OFFSET] + raise ValueError(f"Unsupported relative placement relation: {relation!r}.") + + +def _side_relation_xy_offsets(relation: str) -> tuple[float, float]: + relation = _normalize_relative_relation(relation) + left_y = _ROBOT_VIEW_LEFT_WORLD_Y_SIGN * _SIDE_RELATION_DISTANCE + right_y = -_ROBOT_VIEW_LEFT_WORLD_Y_SIGN * _SIDE_RELATION_DISTANCE + front_x = _ROBOT_VIEW_FRONT_WORLD_X_SIGN * _SIDE_RELATION_DISTANCE + behind_x = -_ROBOT_VIEW_FRONT_WORLD_X_SIGN * _SIDE_RELATION_DISTANCE + if relation == "left_of": + return 0.0, left_y + if relation == "right_of": + return 0.0, right_y + if relation == "front_of": + return front_x, 0.0 + if relation == "behind": + return behind_x, 0.0 + if relation == "front_left_of": + return front_x, left_y + if relation == "back_left_of": + return behind_x, left_y + if relation == "front_right_of": + return front_x, right_y + if relation == "back_right_of": + return behind_x, right_y + raise ValueError(f"Unsupported side relation: {relation!r}.") + + +def _relative_runtime_uid_mapping( + rigid_objects: list[_SceneObject], +) -> dict[str, str]: + candidates: dict[str, str] = {} + for obj in rigid_objects: + if _is_container_like(obj): + candidates[obj.source_uid] = _container_runtime_uid(obj) + continue + + base = _target_runtime_suffix(_base_name(obj)) + base_count = sum( + 1 for other in rigid_objects if _base_name(other) == _base_name(obj) + ) + candidates[obj.source_uid] = ( + base if base_count == 1 else _normalize_runtime_uid(obj.source_uid) + ) + + counts: dict[str, int] = {} + for runtime_uid in candidates.values(): + counts[runtime_uid] = counts.get(runtime_uid, 0) + 1 + return { + source_uid: ( + runtime_uid + if counts[runtime_uid] == 1 + else _normalize_runtime_uid(source_uid) + ) + for source_uid, runtime_uid in candidates.items() + } + + +def _relative_scene_runtime_uid_mapping( + scene_objects: list[_SceneObject], + *, + table_source_uid: str, +) -> dict[str, str]: + candidates: dict[str, str] = {} + rigid_runtime_uids = _relative_runtime_uid_mapping( + [obj for obj in scene_objects if obj.source_role == "rigid_object"] + ) + for obj in scene_objects: + if obj.source_uid == table_source_uid: + candidates[obj.source_uid] = "table" + elif obj.source_role == "rigid_object": + candidates[obj.source_uid] = rigid_runtime_uids[obj.source_uid] + else: + candidates[obj.source_uid] = _candidate_relative_runtime_uid(obj) + + counts: dict[str, int] = {} + for runtime_uid in candidates.values(): + counts[runtime_uid] = counts.get(runtime_uid, 0) + 1 + return { + source_uid: ( + runtime_uid + if source_uid == table_source_uid or counts[runtime_uid] == 1 + else _normalize_runtime_uid(source_uid) + ) + for source_uid, runtime_uid in candidates.items() + } + + +def _candidate_relative_runtime_uid(obj: _SceneObject) -> str: + if _is_container_like(obj): + return _container_runtime_uid(obj) + return _target_runtime_suffix(_base_name(obj)) + + +def _is_container_like(obj: _SceneObject) -> bool: + return any(keyword in _object_text(obj) for keyword in _CONTAINER_KEYWORDS) + + +def _string_list(value: Any) -> list[str]: + if not isinstance(value, list): + return [] + return [str(item).strip() for item in value if str(item).strip()] + + +def _default_relative_task_summary( + moved_uid: str, + reference_uid: str, + relation: str, +) -> str: + return ( + f"Move `{moved_uid}` so its final state is " + f"{_relative_relation_phrase(relation)} `{reference_uid}`." + ) + + +def _default_relative_plan_summary( + placements: Sequence[_RelativePlacementStepSpec], +) -> str: + if len(placements) == 1: + placement = placements[0] + return _default_relative_task_summary( + placement.moved_runtime_uid, + placement.reference_runtime_uid, + placement.relation, + ) + placement_text = "; ".join( + f"use the {placement.active_side} UR5 to move " + f"`{placement.moved_runtime_uid}` " + f"{_relative_relation_phrase(placement.relation)} " + f"`{placement.reference_runtime_uid}`" + for placement in placements + ) + return f"Use both UR5 arms for a dual-arm relative placement: {placement_text}." + + +def _default_relative_action_sketch( + placements: Sequence[_RelativePlacementStepSpec], +) -> list[str]: + if len(placements) == 1: + placement = placements[0] + return [ + f"grasp {placement.moved_runtime_uid}", + ( + f"move above the {placement.relation} release pose relative to " + f"{placement.reference_runtime_uid}" + ), + "place at the release pose with PlaceAction", + ] + sketch = ["grasp both moved objects with their assigned arms"] + for placement in placements: + sketch.extend( + [ + ( + f"use {placement.active_side}_arm to move " + f"{placement.moved_runtime_uid} above the release pose relative " + f"to {placement.reference_runtime_uid}" + ), + f"place {placement.moved_runtime_uid} with PlaceAction", + ] + ) + return sketch + + +def _relative_relation_phrase(relation: str) -> str: + relation = _normalize_relative_relation(relation) + if relation == "inside": + return "inside" + if relation == "on": + return "on top of" + if relation == "left_of": + return "to the left of" + if relation == "right_of": + return "to the right of" + if relation == "front_of": + return "in front of" + if relation == "behind": + return "behind" + if relation == "front_left_of": + return "to the front-left of" + if relation == "back_left_of": + return "to the back-left of" + if relation == "front_right_of": + return "to the front-right of" + if relation == "back_right_of": + return "to the back-right of" + raise ValueError(f"Unsupported relative placement relation: {relation!r}.") + + +def _build_ur5_basket_bundle( + *, + scene_dir: Path, + source_config: Mapping[str, Any], + roles: _BasketTaskRoles, + project_name: str, + task_name: str, + target_body_scale: float | list[float] | tuple[float, float, float], + target_replacements: Sequence[_ResolvedTargetReplacement], + max_episodes: int, + max_episode_steps: int, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + scene_objects = _collect_scene_objects(source_config) + by_uid = {obj.source_uid: obj for obj in scene_objects} + replacement_by_source_uid = { + replacement.source_uid: replacement for replacement in target_replacements + } + object_scale = _target_body_scale_vector(target_body_scale) + container_scale = _source_body_scale(by_uid[roles.container_source_uid]) + task_source_uids = { + roles.container_source_uid, + roles.left_target_source_uid, + roles.right_target_source_uid, + } + extra_rigid_objects = [ + obj + for obj in scene_objects + if obj.source_role == "rigid_object" and obj.source_uid not in task_source_uids + ] + extra_background_objects = [ + obj + for obj in scene_objects + if obj.source_role == "background" and obj.source_uid != roles.table_source_uid + ] + table_config = _make_background_config( + scene_dir, + by_uid[roles.table_source_uid], + mesh_normalizer, + ) + table_top_z = _mesh_config_world_zmax(table_config) + robot_init_z = _dual_ur5_init_z_from_table_top(table_top_z) + + gym_config = { + "id": "AtomicActionsAgent-v3", + "max_episodes": int(max_episodes), + "max_episode_steps": int(max_episode_steps), + "env": { + "extensions": _make_extensions_config(roles), + "events": _make_events_config(roles), + "observations": _make_observations_config(), + "dataset": _make_dataset_config(project_name, roles), + }, + "robot": _make_dual_ur5_robot_config(robot_init_z=robot_init_z), + "sensor": _make_sensor_config(), + "light": _make_light_config(), + "background": [ + table_config, + _make_container_background_config( + scene_dir, + by_uid[roles.container_source_uid], + roles.container_runtime_uid, + container_scale, + mesh_normalizer, + ), + *[ + _make_extra_background_config(scene_dir, obj, mesh_normalizer) + for obj in extra_background_objects + ], + ], + "rigid_object": [ + _make_target_object_config( + scene_dir, + by_uid[roles.right_target_source_uid], + roles.right_target_runtime_uid, + object_scale, + mesh_normalizer, + replacement_by_source_uid.get(roles.right_target_source_uid), + ), + _make_target_object_config( + scene_dir, + by_uid[roles.left_target_source_uid], + roles.left_target_runtime_uid, + object_scale, + mesh_normalizer, + replacement_by_source_uid.get(roles.left_target_source_uid), + ), + *[ + _make_extra_rigid_object_config( + scene_dir, + obj, + _source_body_scale(obj), + mesh_normalizer, + ) + for obj in extra_rigid_objects + ], + ], + } + _apply_tabletop_z_placement(gym_config, table_top_z) + return { + "gym_config": gym_config, + "agent_config": make_agent_config(), + "task_prompt": make_basket_task_prompt(task_name, project_name, roles), + "basic_background": make_basket_basic_background(project_name, roles), + "atom_actions": make_basket_atom_actions_prompt(roles), + "summary": { + "mode": "basket_template", + "left_target": roles.left_target_runtime_uid, + "right_target": roles.right_target_runtime_uid, + "container": roles.container_runtime_uid, + "target_replacements": [ + { + "source_uid": replacement.source_uid, + "prompt": replacement.prompt, + "output_dir_name": replacement.output_dir_name, + "mesh_path": replacement.mesh_path.as_posix(), + "runtime_noun": replacement.runtime_noun, + "reused": replacement.reused, + } + for replacement in target_replacements + ], + }, + } + + +def _attach_coacd_cache_summary(bundle: dict[str, Any]) -> None: + from embodichain.gen_sim.action_agent_pipeline.generation.coacd_cache import ( + prewarm_coacd_cache_for_gym_config, + ) + + bundle.setdefault("summary", {})["coacd_cache"] = ( + prewarm_coacd_cache_for_gym_config(bundle["gym_config"]) + ) + + +def _attach_mesh_normalization_summary( + bundle: dict[str, Any], + mesh_normalizer: MeshFrameNormalizer, +) -> None: + reports = mesh_normalizer.reports + if reports: + bundle.setdefault("summary", {})["normalized_meshes"] = reports + + +def _build_relative_placement_bundle( + *, + scene_dir: Path, + source_config: Mapping[str, Any], + spec: _RelativePlacementSpec, + project_name: str, + task_name: str, + target_body_scale: float | list[float] | tuple[float, float, float], + max_episodes: int, + max_episode_steps: int, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + scene_objects = _collect_scene_objects(source_config) + background_objects = [ + obj for obj in scene_objects if obj.source_role == "background" + ] + rigid_objects = [obj for obj in scene_objects if obj.source_role == "rigid_object"] + by_uid = {obj.source_uid: obj for obj in scene_objects} + runtime_uids = _relative_scene_runtime_uid_mapping( + scene_objects, + table_source_uid=spec.table_source_uid, + ) + moved_source_uids = {placement.moved_source_uid for placement in spec.placements} + reference_runtime_uids = { + placement.reference_runtime_uid for placement in spec.placements + } + registered_runtime_uids = sorted( + {runtime_uids[obj.source_uid] for obj in rigid_objects} | reference_runtime_uids + ) + dynamic_rigid_objects = [ + obj for obj in rigid_objects if obj.source_uid in moved_source_uids + ] + static_scene_objects = [ + obj for obj in rigid_objects if obj.source_uid not in moved_source_uids + ] + object_scale = _target_body_scale_vector(target_body_scale) + table_config = _make_background_config( + scene_dir, + by_uid[spec.table_source_uid], + mesh_normalizer, + ) + table_top_z = _mesh_config_world_zmax(table_config) + robot_init_z = _dual_ur5_init_z_from_table_top(table_top_z) + + gym_config = { + "id": "AtomicActionsAgent-v3", + "max_episodes": int(max_episodes), + "max_episode_steps": int(max_episode_steps), + "env": { + "extensions": {}, + "events": _make_relative_events_config(spec, registered_runtime_uids), + "observations": _make_observations_config(), + "dataset": {}, + }, + "robot": _make_dual_ur5_robot_config(robot_init_z=robot_init_z), + "sensor": _make_sensor_config(), + "light": _make_light_config(), + "background": [ + table_config, + *[ + _make_relative_background_object_config( + scene_dir, + obj, + runtime_uids[obj.source_uid], + max_convex_hull_num=_relative_static_background_max_convex_hull_num( + runtime_uids[obj.source_uid], + spec, + ), + mesh_normalizer=mesh_normalizer, + ) + for obj in static_scene_objects + ], + *[ + _make_extra_background_config( + scene_dir, + obj, + mesh_normalizer, + runtime_uid=runtime_uids[obj.source_uid], + ) + for obj in background_objects + if obj.source_uid != spec.table_source_uid + ], + ], + "rigid_object": [ + _make_relative_rigid_object_config( + scene_dir=scene_dir, + obj=obj, + runtime_uid=runtime_uids[obj.source_uid], + body_scale=object_scale, + max_convex_hull_num=_relative_rigid_object_max_convex_hull_num( + runtime_uids[obj.source_uid], + spec, + ), + mesh_normalizer=mesh_normalizer, + ) + for obj in dynamic_rigid_objects + ], + } + _apply_tabletop_z_placement(gym_config, table_top_z) + spec = _with_self_relative_absolute_targets(spec, gym_config) + gym_config["env"]["extensions"] = _make_relative_extensions_config(spec) + gym_config["env"]["dataset"] = _make_relative_dataset_config(project_name, spec) + return { + "gym_config": gym_config, + "agent_config": make_agent_config(), + "task_prompt": make_relative_task_prompt(task_name, project_name, spec), + "basic_background": make_relative_basic_background(project_name, spec), + "atom_actions": make_relative_atom_actions_prompt(spec), + "summary": _make_relative_summary(spec), + } + + +def _with_self_relative_absolute_targets( + spec: _RelativePlacementSpec, + gym_config: Mapping[str, Any], +) -> _RelativePlacementSpec: + if not any(placement.reference_is_initial_pose for placement in spec.placements): + return spec + + generated_positions = { + str(obj.get("uid")): _clean_vector3(obj.get("init_pos", [0.0, 0.0, 0.0])) + for obj in gym_config.get("rigid_object", []) + } + placements = tuple( + _with_self_relative_absolute_target(placement, generated_positions) + for placement in spec.placements + ) + primary = placements[0] + return _RelativePlacementSpec( + table_source_uid=spec.table_source_uid, + moved_source_uid=primary.moved_source_uid, + reference_source_uid=primary.reference_source_uid, + moved_runtime_uid=primary.moved_runtime_uid, + reference_runtime_uid=primary.reference_runtime_uid, + relation=primary.relation, + active_side=primary.active_side, + task_description=spec.task_description, + task_prompt_summary=spec.task_prompt_summary, + basic_background_notes=spec.basic_background_notes, + action_sketch=spec.action_sketch, + release_offset=primary.release_offset, + high_offset=primary.high_offset, + placements=placements, + reference_is_initial_pose=primary.reference_is_initial_pose, + release_position=primary.release_position, + high_position=primary.high_position, + ) + + +def _with_self_relative_absolute_target( + placement: _RelativePlacementStepSpec, + generated_positions: Mapping[str, list[float]], +) -> _RelativePlacementStepSpec: + if not placement.reference_is_initial_pose: + return placement + initial_position = generated_positions.get(placement.moved_runtime_uid) + if initial_position is None: + raise ValueError( + "Generated relative config missing self-relative moved object " + f"{placement.moved_runtime_uid!r}." + ) + release_position = _offset_position(initial_position, placement.release_offset) + high_position = _offset_position(initial_position, placement.high_offset) + return _RelativePlacementStepSpec( + moved_source_uid=placement.moved_source_uid, + reference_source_uid=placement.reference_source_uid, + moved_runtime_uid=placement.moved_runtime_uid, + reference_runtime_uid=placement.reference_runtime_uid, + relation=placement.relation, + active_side=placement.active_side, + release_offset=placement.release_offset, + high_offset=placement.high_offset, + reference_is_initial_pose=True, + release_position=release_position, + high_position=high_position, + ) + + +def _offset_position( + position: Sequence[float], + offset: Sequence[float], +) -> list[float]: + return [ + round(float(position[index]) + float(offset[index]), 6) for index in range(3) + ] + + +def _target_body_scale_vector( + target_body_scale: float | list[float] | tuple[float, float, float], +) -> list[float]: + if isinstance(target_body_scale, (int, float)): + value = float(target_body_scale) + return [value, value, value] + return _clean_vector3(target_body_scale) + + +def _source_body_scale(obj: _SceneObject) -> list[float]: + return _clean_vector3(obj.config.get("body_scale", [1.0, 1.0, 1.0])) + + +def _make_relative_summary(spec: _RelativePlacementSpec) -> dict[str, Any]: + if len(spec.placements) == 1: + return { + "mode": "relative_placement", + "moved_object": spec.moved_runtime_uid, + "reference_object": spec.reference_runtime_uid, + "relation": spec.relation, + "active_arm": f"{spec.active_side}_arm", + "release_offset": spec.release_offset, + } + return { + "mode": "dual_arm_relative_placement", + "placements": [ + { + "moved_object": placement.moved_runtime_uid, + "reference_object": placement.reference_runtime_uid, + "relation": placement.relation, + "active_arm": f"{placement.active_side}_arm", + "release_offset": placement.release_offset, + } + for placement in spec.placements + ], + } + + +def _dual_ur5_init_z_from_table_top(table_top_z: float | None) -> float: + if table_top_z is None: + return _DUAL_UR5_LEGACY_INIT_Z + + init_z = table_top_z + _DUAL_UR5_TABLETOP_CLEARANCE - _DUAL_UR5_ARM_COMPONENT_Z + return round(init_z, 6) + + +def _apply_tabletop_z_placement( + gym_config: dict[str, Any], + table_top_z: float | None, +) -> None: + if table_top_z is None: + return + target_bottom_z = float(table_top_z) + _TABLETOP_OBJECT_CLEARANCE + for obj in _iter_generated_scene_object_configs(gym_config): + if obj.get("uid") == "table": + continue + mesh_min_z = _mesh_config_local_zmin_after_rotation(obj) + if mesh_min_z is None: + continue + init_pos = _clean_vector3(obj.get("init_pos", [0.0, 0.0, 0.0])) + init_pos[2] = round(target_bottom_z - mesh_min_z, 6) + obj["init_pos"] = init_pos + + +def _iter_generated_scene_object_configs( + gym_config: Mapping[str, Any], +) -> list[dict[str, Any]]: + objects: list[dict[str, Any]] = [] + for section in ("background", "rigid_object"): + value = gym_config.get(section, []) + if isinstance(value, Mapping): + value = [value] + if not isinstance(value, list): + continue + objects.extend(obj for obj in value if isinstance(obj, dict)) + return objects + + +def _mesh_config_world_zmax(obj_config: Mapping[str, Any]) -> float | None: + bounds = _mesh_config_world_z_bounds(obj_config) + if bounds is None: + return None + return bounds[1] + + +def _mesh_config_local_zmin_after_rotation( + obj_config: Mapping[str, Any], +) -> float | None: + shape = obj_config.get("shape", {}) + if not isinstance(shape, Mapping): + return None + mesh_path = shape.get("fpath") + if not isinstance(mesh_path, str): + return None + vertices = _load_mesh_vertices(Path(mesh_path).expanduser().resolve()) + if not vertices: + return None + + matrix = _mesh_config_transform_matrix( + obj_config, + translation=[0.0, 0.0, 0.0], + ) + return min(_transform_point(matrix, vertex)[2] for vertex in vertices) + + +def _mesh_config_world_z_bounds( + obj_config: Mapping[str, Any], +) -> tuple[float, float] | None: + shape = obj_config.get("shape", {}) + if not isinstance(shape, Mapping): + return None + mesh_path = shape.get("fpath") + if not isinstance(mesh_path, str): + return None + vertices = _load_mesh_vertices(Path(mesh_path).expanduser().resolve()) + if not vertices: + return None + + matrix = _mesh_config_transform_matrix(obj_config) + z_values = [_transform_point(matrix, vertex)[2] for vertex in vertices] + return (min(z_values), max(z_values)) + + +def _mesh_config_transform_matrix( + obj_config: Mapping[str, Any], + *, + translation: list[float] | None = None, +) -> list[list[float]]: + scale = _vector3(obj_config.get("body_scale", [1.0, 1.0, 1.0])) + init_local_pose = obj_config.get("init_local_pose") + if init_local_pose is not None and translation is None: + root_matrix = _matrix4(init_local_pose) + else: + root_matrix = _euler_xyz_degrees_matrix( + _vector3(obj_config.get("init_rot", [0.0, 0.0, 0.0])), + ( + _vector3(obj_config.get("init_pos", [0.0, 0.0, 0.0])) + if translation is None + else translation + ), + ) + return _matrix_multiply(root_matrix, _scale_matrix4(scale)) + + +def _resolve_table_mesh_world_zmax( + scene_dir: Path, + table_obj: _SceneObject, +) -> float | None: + shape = table_obj.config.get("shape", {}) + if not isinstance(shape, Mapping): + return None + if shape.get("shape_type") != "Mesh" or not shape.get("fpath"): + return None + + mesh_path = _source_asset_path(scene_dir, str(shape["fpath"])) + try: + vertices = _load_mesh_vertices(mesh_path) + except ( + OSError, + ValueError, + json.JSONDecodeError, + UnicodeDecodeError, + struct.error, + ): + return None + if not vertices: + return None + + world_matrix = _table_mesh_world_matrix(table_obj.config) + return max(_transform_point(world_matrix, vertex)[2] for vertex in vertices) + + +def _source_asset_path(scene_dir: Path, fpath: str) -> Path: + raw_path = Path(fpath) + if raw_path.is_absolute(): + return raw_path.resolve() + + scene_candidate = (scene_dir / raw_path).resolve() + if scene_candidate.exists(): + return scene_candidate + + repo_candidate = (_repo_root() / raw_path).resolve() + if repo_candidate.exists(): + return repo_candidate + return scene_candidate + + +def _load_mesh_vertices(mesh_path: Path) -> list[tuple[float, float, float]] | None: + if mesh_path.suffix.lower() == ".glb": + try: + return list(_iter_glb_world_position_vertices(mesh_path)) + except ( + OSError, + ValueError, + json.JSONDecodeError, + UnicodeDecodeError, + struct.error, + ): + return _load_mesh_vertices_with_trimesh(mesh_path) + return _load_mesh_vertices_with_trimesh(mesh_path) + + +def _load_mesh_vertices_with_trimesh( + mesh_path: Path, +) -> list[tuple[float, float, float]] | None: + try: + import trimesh + except ImportError: + return None + + try: + scene_or_mesh = trimesh.load(str(mesh_path), force="scene") + if hasattr(scene_or_mesh, "to_geometry"): + mesh = scene_or_mesh.to_geometry() + elif hasattr(scene_or_mesh, "dump"): + mesh = scene_or_mesh.dump(concatenate=True) + else: + mesh = scene_or_mesh + except Exception: + return None + vertices = getattr(mesh, "vertices", None) + if vertices is None or len(vertices) == 0: + return None + return [ + (float(vertex[0]), float(vertex[1]), float(vertex[2])) for vertex in vertices + ] + + +def _iter_glb_world_position_vertices( + mesh_path: Path, +): + doc, binary_chunk = _read_glb(mesh_path) + nodes = doc.get("nodes", []) + if not isinstance(nodes, list): + raise ValueError("GLB nodes must be a list.") + + scenes = doc.get("scenes", []) + if scenes: + scene_index = int(doc.get("scene", 0)) + root_node_ids = scenes[scene_index].get("nodes", []) + else: + root_node_ids = list(range(len(nodes))) + + stack = [(int(node_id), _identity_matrix4()) for node_id in root_node_ids] + while stack: + node_id, parent_matrix = stack.pop() + node = nodes[node_id] + node_matrix = _matrix_multiply(parent_matrix, _gltf_node_matrix(node)) + mesh_index = node.get("mesh") + if mesh_index is not None: + for vertex in _iter_gltf_mesh_position_vertices( + doc, + binary_chunk, + int(mesh_index), + ): + yield _transform_point(node_matrix, vertex) + for child_id in node.get("children", []) or []: + stack.append((int(child_id), node_matrix)) + + +def _read_glb(mesh_path: Path) -> tuple[dict[str, Any], bytes]: + data = mesh_path.read_bytes() + if len(data) < 20: + raise ValueError("GLB file is too small.") + + magic, version, total_length = struct.unpack_from("<4sII", data, 0) + if magic != b"glTF" or version != 2: + raise ValueError("Only GLB version 2 files are supported.") + if total_length > len(data): + raise ValueError("GLB length header exceeds file size.") + + doc: dict[str, Any] | None = None + binary_chunk = b"" + offset = 12 + while offset + 8 <= total_length: + chunk_length, chunk_type = struct.unpack_from(" total_length: + raise ValueError("GLB chunk exceeds file size.") + chunk = data[offset:chunk_end] + offset = chunk_end + if chunk_type == _GLB_JSON_CHUNK_TYPE: + doc = json.loads(chunk.decode("utf-8").rstrip("\x00 ")) + elif chunk_type == _GLB_BINARY_CHUNK_TYPE: + binary_chunk = chunk + + if doc is None: + raise ValueError("GLB file does not contain a JSON chunk.") + return doc, binary_chunk + + +def _iter_gltf_mesh_position_vertices( + doc: Mapping[str, Any], + binary_chunk: bytes, + mesh_index: int, +): + meshes = doc.get("meshes", []) + accessors = doc.get("accessors", []) + mesh = meshes[mesh_index] + for primitive in mesh.get("primitives", []) or []: + attributes = primitive.get("attributes", {}) + position_accessor = attributes.get("POSITION") + if position_accessor is None: + continue + if int(position_accessor) >= len(accessors): + raise ValueError("POSITION accessor index is out of range.") + yield from _iter_gltf_accessor_vec3(doc, binary_chunk, int(position_accessor)) + + +def _iter_gltf_accessor_vec3( + doc: Mapping[str, Any], + binary_chunk: bytes, + accessor_index: int, +): + accessor = doc["accessors"][accessor_index] + if accessor.get("sparse"): + raise ValueError("Sparse GLB accessors are not supported.") + if accessor.get("type") != "VEC3": + raise ValueError("POSITION accessor must be VEC3.") + if "bufferView" not in accessor: + raise ValueError("POSITION accessor must reference a bufferView.") + + component_type = int(accessor["componentType"]) + if component_type not in _GLTF_COMPONENT_FORMATS: + raise ValueError(f"Unsupported GLB component type: {component_type}.") + component_format, component_size = _GLTF_COMPONENT_FORMATS[component_type] + component_count = _GLTF_TYPE_COMPONENT_COUNTS[accessor["type"]] + buffer_view = doc["bufferViews"][int(accessor["bufferView"])] + if int(buffer_view.get("buffer", 0)) != 0: + raise ValueError("Only GLB embedded binary buffers are supported.") + + stride = int(buffer_view.get("byteStride", component_size * component_count)) + offset = int(buffer_view.get("byteOffset", 0)) + int(accessor.get("byteOffset", 0)) + element_format = "<" + component_format * component_count + for index in range(int(accessor["count"])): + values = struct.unpack_from( + element_format, + binary_chunk, + offset + index * stride, + ) + yield (float(values[0]), float(values[1]), float(values[2])) + + +def _table_mesh_world_matrix(table_config: Mapping[str, Any]) -> list[list[float]]: + scale = _vector3(table_config.get("body_scale", [1.0, 1.0, 1.0])) + init_local_pose = table_config.get("init_local_pose") + if init_local_pose is not None: + root_matrix = _matrix4(init_local_pose) + else: + root_matrix = _euler_xyz_degrees_matrix( + _vector3(table_config.get("init_rot", [0.0, 0.0, 0.0])), + _vector3(table_config.get("init_pos", [0.0, 0.0, 0.0])), + ) + return _matrix_multiply(root_matrix, _scale_matrix4(scale)) + + +def _gltf_node_matrix(node: Mapping[str, Any]) -> list[list[float]]: + if "matrix" in node: + values = [float(value) for value in node["matrix"]] + if len(values) != 16: + raise ValueError("GLB node matrix must contain 16 values.") + return [[values[column * 4 + row] for column in range(4)] for row in range(4)] + + translation = [float(value) for value in node.get("translation", [0.0, 0.0, 0.0])] + scale = [float(value) for value in node.get("scale", [1.0, 1.0, 1.0])] + rotation = [float(value) for value in node.get("rotation", [0.0, 0.0, 0.0, 1.0])] + if len(translation) != 3 or len(scale) != 3 or len(rotation) != 4: + raise ValueError("Invalid GLB node TRS transform.") + + x, y, z, w = rotation + xx, yy, zz = x * x, y * y, z * z + xy, xz, yz = x * y, x * z, y * z + wx, wy, wz = w * x, w * y, w * z + matrix = [ + [ + (1.0 - 2.0 * (yy + zz)) * scale[0], + (2.0 * (xy - wz)) * scale[1], + (2.0 * (xz + wy)) * scale[2], + translation[0], + ], + [ + (2.0 * (xy + wz)) * scale[0], + (1.0 - 2.0 * (xx + zz)) * scale[1], + (2.0 * (yz - wx)) * scale[2], + translation[1], + ], + [ + (2.0 * (xz - wy)) * scale[0], + (2.0 * (yz + wx)) * scale[1], + (1.0 - 2.0 * (xx + yy)) * scale[2], + translation[2], + ], + [0.0, 0.0, 0.0, 1.0], + ] + return matrix + + +def _euler_xyz_degrees_matrix( + rotation_deg: Sequence[float], + translation: Sequence[float], +) -> list[list[float]]: + rx, ry, rz = (math.radians(float(value)) for value in rotation_deg) + cx, sx = math.cos(rx), math.sin(rx) + cy, sy = math.cos(ry), math.sin(ry) + cz, sz = math.cos(rz), math.sin(rz) + rot_x = [ + [1.0, 0.0, 0.0, 0.0], + [0.0, cx, -sx, 0.0], + [0.0, sx, cx, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + rot_y = [ + [cy, 0.0, sy, 0.0], + [0.0, 1.0, 0.0, 0.0], + [-sy, 0.0, cy, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + rot_z = [ + [cz, -sz, 0.0, 0.0], + [sz, cz, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + matrix = _matrix_multiply(_matrix_multiply(rot_z, rot_y), rot_x) + matrix[0][3] = float(translation[0]) + matrix[1][3] = float(translation[1]) + matrix[2][3] = float(translation[2]) + return matrix + + +def _identity_matrix4() -> list[list[float]]: + return [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + + +def _scale_matrix4(scale: Sequence[float]) -> list[list[float]]: + return [ + [float(scale[0]), 0.0, 0.0, 0.0], + [0.0, float(scale[1]), 0.0, 0.0], + [0.0, 0.0, float(scale[2]), 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + + +def _matrix4(value: Any) -> list[list[float]]: + if not isinstance(value, (list, tuple)) or len(value) != 4: + raise ValueError(f"Expected a 4x4 matrix, got {value!r}.") + matrix = [] + for row in value: + if not isinstance(row, (list, tuple)) or len(row) != 4: + raise ValueError(f"Expected a 4x4 matrix, got {value!r}.") + matrix.append([float(item) for item in row]) + return matrix + + +def _matrix_multiply( + left: Sequence[Sequence[float]], + right: Sequence[Sequence[float]], +) -> list[list[float]]: + return [ + [ + sum( + float(left[row][inner]) * float(right[inner][column]) + for inner in range(4) + ) + for column in range(4) + ] + for row in range(4) + ] + + +def _transform_point( + matrix: Sequence[Sequence[float]], + point: Sequence[float], +) -> tuple[float, float, float]: + x, y, z = (float(point[0]), float(point[1]), float(point[2])) + return ( + float(matrix[0][0]) * x + + float(matrix[0][1]) * y + + float(matrix[0][2]) * z + + float(matrix[0][3]), + float(matrix[1][0]) * x + + float(matrix[1][1]) * y + + float(matrix[1][2]) * z + + float(matrix[1][3]), + float(matrix[2][0]) * x + + float(matrix[2][1]) * y + + float(matrix[2][2]) * z + + float(matrix[2][3]), + ) + + +def _make_extensions_config(roles: _BasketTaskRoles) -> dict[str, Any]: + return { + "agent_arm_slots": { + "left": { + "arm": "right_arm", + "eef": "right_eef", + }, + "right": { + "arm": "left_arm", + "eef": "left_eef", + }, + }, + "arm_aim_yaw_offset": { + "left": 3.141592653589793, + "right": 0.0, + }, + "gripper_open_state": [0.0], + "gripper_close_state": [0.04], + "ignore_terminations_during_agent": True, + "viewer_camera_uid": "cam_high", + "agent_success": { + "op": "all", + "terms": [ + _object_in_container_success( + roles.left_target_runtime_uid, + roles.container_runtime_uid, + ), + _object_in_container_success( + roles.right_target_runtime_uid, + roles.container_runtime_uid, + ), + ], + }, + } + + +def _object_in_container_success(object_uid: str, container_uid: str) -> dict[str, Any]: + return { + "type": "object_in_container", + "object": object_uid, + "container": container_uid, + "radius": 0.2, + "min_z_offset": -0.05, + "max_z_offset": 0.35, + } + + +def _make_relative_extensions_config(spec: _RelativePlacementSpec) -> dict[str, Any]: + return { + "agent_arm_slots": { + "left": { + "arm": "right_arm", + "eef": "right_eef", + }, + "right": { + "arm": "left_arm", + "eef": "left_eef", + }, + }, + "arm_aim_yaw_offset": { + "left": 3.141592653589793, + "right": 0.0, + }, + "gripper_open_state": [0.0], + "gripper_close_state": [0.04], + "ignore_terminations_during_agent": True, + "viewer_camera_uid": "cam_high", + "agent_success": _make_relative_success_spec(spec), + } + + +def _make_relative_success_spec(spec: _RelativePlacementSpec) -> dict[str, Any]: + if len(spec.placements) == 1: + return _make_relative_placement_success_spec(spec.placements[0]) + return { + "op": "all", + "terms": [ + _make_relative_placement_success_spec(placement) + for placement in spec.placements + ], + } + + +def _make_relative_placement_success_spec( + placement: _RelativePlacementStepSpec, +) -> dict[str, Any]: + if placement.relation == "inside": + return _object_in_container_success( + placement.moved_runtime_uid, + placement.reference_runtime_uid, + ) + if placement.relation == "on": + return { + "type": "object_on_object", + "object": placement.moved_runtime_uid, + "support": placement.reference_runtime_uid, + "xy_radius": 0.08, + "min_z_offset": 0.02, + "max_z_offset": 0.35, + } + + if placement.reference_is_initial_pose: + if placement.release_position is None: + raise ValueError( + "Self-relative success requires an absolute release position." + ) + return { + "op": "all", + "terms": [ + *_absolute_xy_success_terms( + placement.moved_runtime_uid, + placement.release_position, + ), + { + "type": "object_not_fallen", + "object": placement.moved_runtime_uid, + "max_tilt": 0.9, + }, + ], + } + + return { + "op": "all", + "terms": [ + *_relative_xy_success_terms(placement), + { + "type": "object_not_fallen", + "object": placement.moved_runtime_uid, + "max_tilt": 0.9, + }, + ], + } + + +def _absolute_xy_success_terms( + object_uid: str, + position: Sequence[float], +) -> list[dict[str, Any]]: + return [ + { + "type": "object_axis_near", + "object": object_uid, + "axis": axis, + "target": float(position[index]), + "tolerance": 0.05, + } + for index, axis in enumerate(("x", "y")) + ] + + +def _relative_xy_success_terms( + placement: _RelativePlacementStepSpec, +) -> list[dict[str, Any]]: + x_offset, y_offset = _side_relation_xy_offsets(placement.relation) + return [ + { + "type": "object_axis_offset_near", + "object": placement.moved_runtime_uid, + "reference": placement.reference_runtime_uid, + "axis": axis, + "offset": offset, + "tolerance": 0.05 if offset else 0.06, + } + for axis, offset in (("x", x_offset), ("y", y_offset)) + ] + + +def _make_relative_events_config( + spec: _RelativePlacementSpec, + registered_runtime_uids: list[str], +) -> dict[str, Any]: + return { + "record_camera": _record_camera_event_config(), + "validation_cameras": _validation_cameras_event_config(), + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": True, + "sample_points": 5000, + }, + }, + ] + }, + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + _object_registry_entry(uid) + for uid in sorted(registered_runtime_uids) + ], + "registration": "affordance_datas", + "sim_update": True, + }, + }, + } + + +def _make_events_config(roles: _BasketTaskRoles) -> dict[str, Any]: + return { + "record_camera": _record_camera_event_config(), + "validation_cameras": _validation_cameras_event_config(), + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": True, + "sample_points": 5000, + }, + }, + ] + }, + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + _object_registry_entry(roles.left_target_runtime_uid), + _object_registry_entry(roles.right_target_runtime_uid), + _object_registry_entry(roles.container_runtime_uid), + ], + "registration": "affordance_datas", + "sim_update": True, + }, + }, + } + + +def _record_camera_event_config() -> dict[str, Any]: + camera = _make_sensor_config()[0] + extrinsics = camera["extrinsics"] + return { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "record_cam_high", + "resolution": [camera["width"], camera["height"]], + "intrinsics": camera["intrinsics"], + "eye": extrinsics["eye"], + "target": extrinsics["target"], + "up": extrinsics["up"], + }, + } + + +def _validation_cameras_event_config() -> dict[str, Any]: + return { + "func": "validation_cameras", + "mode": "trigger", + "params": {}, + } + + +def _object_registry_entry(uid: str) -> dict[str, Any]: + return { + "entity_cfg": { + "uid": uid, + }, + "pose_register_params": { + "compute_relative": False, + "compute_pose_object_to_arena": True, + "to_matrix": True, + }, + } + + +def _make_observations_config() -> dict[str, Any]: + return { + "norm_robot_eef_joint": { + "func": "normalize_robot_joint_data", + "mode": "modify", + "name": "robot/qpos", + "params": { + "joint_ids": [12, 13, 14, 15], + }, + } + } + + +def _make_dataset_config( + project_name: str, + roles: _BasketTaskRoles, +) -> dict[str, Any]: + left_target_text = _left_target_text(roles) + right_target_text = _right_target_text(roles) + target_description = _target_task_description_text(roles) + return { + "lerobot": { + "func": "LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "robot_type": "DualUR5", + "control_freq": 25, + }, + "instruction": { + "lang": ( + f"Use the left UR5 to place the left {left_target_text} into " + f"the {roles.container_runtime_uid}, then use the right " + f"UR5 to place the right {right_target_text} into the " + f"{roles.container_runtime_uid}." + ), + }, + "extra": { + "scene_type": project_name, + "task_description": ( + f"Dual UR5 {target_description}-to-container placement" + ), + "data_type": "sim", + }, + "use_videos": True, + }, + } + } + + +def _make_relative_dataset_config( + project_name: str, + spec: _RelativePlacementSpec, +) -> dict[str, Any]: + return { + "lerobot": { + "func": "LeRobotRecorder", + "mode": "save", + "params": { + "robot_meta": { + "robot_type": "DualUR5", + "control_freq": 25, + }, + "instruction": { + "lang": _relative_dataset_instruction(spec), + }, + "extra": { + "scene_type": project_name, + "task_description": spec.task_description, + "data_type": "sim", + }, + "use_videos": True, + }, + } + } + + +def _relative_dataset_instruction(spec: _RelativePlacementSpec) -> str: + if len(spec.placements) == 1: + placement = spec.placements[0] + return ( + f"Use the {placement.active_side} UR5 to move " + f"{placement.moved_runtime_uid} " + f"{_relative_relation_phrase(placement.relation)} " + f"{placement.reference_runtime_uid}." + ) + return " ".join( + f"Use the {placement.active_side} UR5 to move " + f"{placement.moved_runtime_uid} " + f"{_relative_relation_phrase(placement.relation)} " + f"{placement.reference_runtime_uid}." + for placement in spec.placements + ) + + +def _make_dual_ur5_robot_config(*, robot_init_z: float) -> dict[str, Any]: + return { + "uid": "DualUR5", + "urdf_cfg": { + "fname": "dual_ur5_dh_pgi_basket", + "components": [ + { + "component_type": "left_arm", + "urdf_path": "UniversalRobots/UR5/UR5.urdf", + "transform": [ + [0.0, -1.0, 0.0, -0.3], + [1.0, 0.0, 0.0, -1.45], + [0.0, 0.0, 1.0, 0.4], + [0.0, 0.0, 0.0, 1.0], + ], + }, + { + "component_type": "left_hand", + "urdf_path": "DH_PGI_140_80/DH_PGI_140_80.urdf", + }, + { + "component_type": "right_arm", + "urdf_path": "UniversalRobots/UR5/UR5.urdf", + "transform": [ + [0.0, -1.0, 0.0, 0.3], + [1.0, 0.0, 0.0, -1.45], + [0.0, 0.0, 1.0, 0.4], + [0.0, 0.0, 0.0, 1.0], + ], + }, + { + "component_type": "right_hand", + "urdf_path": "DH_PGI_140_80/DH_PGI_140_80.urdf", + }, + ], + }, + "init_pos": [_DUAL_UR5_ROTATED_INIT_X, 0.0, float(robot_init_z)], + "init_rot": [0.0, 0.0, _DUAL_UR5_ROTATED_INIT_YAW_DEGREES], + "init_qpos": [ + 0, + 0, + -1.57, + -1.57, + 1.57, + 1.57, + -1.57, + -1.57, + -1.57, + -1.57, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + "drive_pros": { + "stiffness": { + "LEFT_JOINT[1-6]": 10000.0, + "RIGHT_JOINT[1-6]": 10000.0, + "LEFT_GRIPPER_FINGER[1-2]_JOINT_1": 100.0, + "RIGHT_GRIPPER_FINGER[1-2]_JOINT_1": 100.0, + }, + "damping": { + "LEFT_JOINT[1-6]": 1000.0, + "RIGHT_JOINT[1-6]": 1000.0, + "LEFT_GRIPPER_FINGER[1-2]_JOINT_1": 10.0, + "RIGHT_GRIPPER_FINGER[1-2]_JOINT_1": 10.0, + }, + "max_effort": { + "LEFT_JOINT[1-6]": 100000.0, + "RIGHT_JOINT[1-6]": 100000.0, + "LEFT_GRIPPER_FINGER[1-2]_JOINT_1": 1000.0, + "RIGHT_GRIPPER_FINGER[1-2]_JOINT_1": 1000.0, + }, + }, + "control_parts": { + "left_arm": ["LEFT_JOINT[1-6]"], + "left_eef": ["LEFT_GRIPPER_FINGER[1-2]_JOINT_1"], + "right_arm": ["RIGHT_JOINT[1-6]"], + "right_eef": ["RIGHT_GRIPPER_FINGER[1-2]_JOINT_1"], + }, + "solver_cfg": { + "left_arm": _ur5_solver_config("left"), + "right_arm": _ur5_solver_config("right"), + }, + } + + +def _ur5_solver_config(side: str) -> dict[str, Any]: + return { + "class_type": "PytorchSolver", + "end_link_name": f"{side}_ee_link", + "root_link_name": f"{side}_base_link", + "tcp": [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.16], + [0.0, 0.0, 0.0, 1.0], + ], + } + + +def _make_sensor_config() -> list[dict[str, Any]]: + return [ + { + "sensor_type": "Camera", + "uid": "cam_high", + "width": 960, + "height": 540, + "intrinsics": [420, 420, 480, 270], + "extrinsics": { + "pos": [0.4, 0.0, 2.2], + "eye": [0.6, 0.0, 3.3], + "target": [0.0, 0.0, 0.75], + "up": [1.0, 0.0, 0.0], + }, + }, + { + "sensor_type": "Camera", + "uid": "cam_wrist_left", + "width": 640, + "height": 480, + "intrinsics": [600, 600, 320, 240], + "extrinsics": { + "parent": "left_ee_link", + "pos": [0.0, 0.12, 0.08], + "quat": [ + -0.0012598701, + -0.029051816664441618998, + 0.9094039177564813, + 0.41489627504330695, + ], + }, + }, + { + "sensor_type": "Camera", + "uid": "cam_wrist_right", + "width": 640, + "height": 480, + "intrinsics": [600, 600, 320, 240], + "extrinsics": { + "parent": "right_ee_link", + "pos": [0.0, 0.12, 0.08], + "quat": [ + -0.0012598701, + -0.029051816664441618998, + 0.9094039177564813, + 0.41489627504330695, + ], + }, + }, + ] + + +def _make_light_config() -> dict[str, Any]: + return { + "direct": [ + { + "uid": "main_light", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 40.0, + "init_pos": [0.0, -0.4, 2.2], + "radius": 10.0, + } + ] + } + + +def _make_background_config( + scene_dir: Path, + obj: _SceneObject, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + shape = _make_shape_config(scene_dir, obj.config, mesh_normalizer=mesh_normalizer) + return { + "uid": "table", + "shape": shape, + "attrs": dict(_BACKGROUND_ATTRS), + "body_scale": _clean_vector3(obj.config.get("body_scale", [1.0, 1.0, 1.0])), + "body_type": "kinematic", + "init_pos": _clean_vector3(obj.config.get("init_pos", [0.0, 0.0, 0.0])), + "init_rot": _clean_vector3(obj.config.get("init_rot", [0.0, 0.0, 0.0])), + "max_convex_hull_num": _role_limited_max_convex_hull_num( + obj, + _BACKGROUND_MAX_CONVEX_HULL_NUM, + ), + } + + +def _make_extra_background_config( + scene_dir: Path, + obj: _SceneObject, + mesh_normalizer: MeshFrameNormalizer, + body_scale: Any | None = None, + runtime_uid: str | None = None, +) -> dict[str, Any]: + shape = _make_shape_config(scene_dir, obj.config, mesh_normalizer=mesh_normalizer) + config = { + "uid": runtime_uid or _normalize_runtime_uid(obj.source_uid), + "shape": shape, + "attrs": copy.deepcopy(dict(obj.config.get("attrs", _BACKGROUND_ATTRS))), + "body_scale": _clean_vector3( + obj.config.get("body_scale", [1.0, 1.0, 1.0]) + if body_scale is None + else body_scale + ), + "body_type": str(obj.config.get("body_type", "static")), + "init_pos": _clean_vector3(obj.config.get("init_pos", [0.0, 0.0, 0.0])), + "init_rot": _clean_vector3(obj.config.get("init_rot", [0.0, 0.0, 0.0])), + "max_convex_hull_num": _role_limited_max_convex_hull_num( + obj, + _BACKGROUND_MAX_CONVEX_HULL_NUM, + ), + } + return config + + +def _make_target_object_config( + scene_dir: Path, + obj: _SceneObject, + runtime_uid: str, + target_scale: list[float], + mesh_normalizer: MeshFrameNormalizer, + replacement: _ResolvedTargetReplacement | None = None, +) -> dict[str, Any]: + config = _make_rigid_object_config( + scene_dir, + obj, + runtime_uid, + target_scale, + max_convex_hull_num=_TARGET_MAX_CONVEX_HULL_NUM, + mesh_fpath=replacement.mesh_path if replacement else None, + mesh_normalizer=mesh_normalizer, + ) + config["body_type"] = "dynamic" + return config + + +def _make_container_object_config( + scene_dir: Path, + obj: _SceneObject, + runtime_uid: str, + body_scale: Any, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + return _make_rigid_object_config( + scene_dir, + obj, + runtime_uid, + body_scale, + max_convex_hull_num=_role_limited_max_convex_hull_num( + obj, + _CONTAINER_MAX_CONVEX_HULL_NUM, + ), + mesh_normalizer=mesh_normalizer, + ) + + +def _make_container_background_config( + scene_dir: Path, + obj: _SceneObject, + runtime_uid: str, + body_scale: Any, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + config = _make_container_object_config( + scene_dir, + obj, + runtime_uid, + body_scale, + mesh_normalizer, + ) + config["body_type"] = "kinematic" + return config + + +def _make_relative_background_object_config( + scene_dir: Path, + obj: _SceneObject, + runtime_uid: str, + *, + max_convex_hull_num: int, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + config = _make_rigid_object_config( + scene_dir, + obj, + runtime_uid, + _source_body_scale(obj), + max_convex_hull_num=max_convex_hull_num, + mesh_normalizer=mesh_normalizer, + ) + config["body_type"] = "kinematic" + return config + + +def _make_extra_rigid_object_config( + scene_dir: Path, + obj: _SceneObject, + body_scale: Any, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + return _make_rigid_object_config( + scene_dir, + obj, + _normalize_runtime_uid(obj.source_uid), + body_scale, + max_convex_hull_num=_role_limited_max_convex_hull_num( + obj, + _EXTRA_RIGID_MAX_CONVEX_HULL_NUM, + ), + mesh_normalizer=mesh_normalizer, + ) + + +def _make_relative_rigid_object_config( + *, + scene_dir: Path, + obj: _SceneObject, + runtime_uid: str, + body_scale: Any, + max_convex_hull_num: int, + mesh_normalizer: MeshFrameNormalizer, +) -> dict[str, Any]: + if max_convex_hull_num == _TARGET_MAX_CONVEX_HULL_NUM: + resolved_max_convex_hull_num = max_convex_hull_num + else: + resolved_max_convex_hull_num = _role_limited_max_convex_hull_num( + obj, + max_convex_hull_num, + ) + config = _make_rigid_object_config( + scene_dir, + obj, + runtime_uid, + body_scale, + max_convex_hull_num=resolved_max_convex_hull_num, + mesh_normalizer=mesh_normalizer, + ) + config["body_type"] = "dynamic" + return config + + +def _make_rigid_object_config( + scene_dir: Path, + obj: _SceneObject, + runtime_uid: str, + body_scale: Any, + max_convex_hull_num: int, + mesh_fpath: str | Path | None = None, + mesh_normalizer: MeshFrameNormalizer | None = None, +) -> dict[str, Any]: + shape = _make_shape_config( + scene_dir, + obj.config, + mesh_fpath=mesh_fpath, + mesh_normalizer=mesh_normalizer, + ) + config = { + "uid": runtime_uid, + "shape": shape, + "attrs": dict(_RIGID_OBJECT_ATTRS), + "init_pos": _clean_vector3(obj.config.get("init_pos", [0.0, 0.0, 0.0])), + "init_rot": _clean_vector3(obj.config.get("init_rot", [0.0, 0.0, 0.0])), + "body_scale": _clean_vector3(body_scale), + "max_convex_hull_num": int(max_convex_hull_num), + } + if "body_type" in obj.config: + config["body_type"] = str(obj.config["body_type"]) + return config + + +def _role_limited_max_convex_hull_num( + obj: _SceneObject, + role_max_convex_hull_num: int, +) -> int: + source_max_convex_hull_num = obj.config.get("max_convex_hull_num") + if source_max_convex_hull_num is None: + return role_max_convex_hull_num + return max(1, min(int(source_max_convex_hull_num), role_max_convex_hull_num)) + + +def _relative_rigid_object_max_convex_hull_num( + runtime_uid: str, + spec: _RelativePlacementSpec, +) -> int: + for placement in spec.placements: + if ( + placement.relation == "inside" + and runtime_uid == placement.reference_runtime_uid + ): + return _CONTAINER_MAX_CONVEX_HULL_NUM + task_uids = { + uid + for placement in spec.placements + for uid in (placement.moved_runtime_uid, placement.reference_runtime_uid) + } + if runtime_uid in task_uids: + return _TARGET_MAX_CONVEX_HULL_NUM + return _EXTRA_RIGID_MAX_CONVEX_HULL_NUM + + +def _relative_static_background_max_convex_hull_num( + runtime_uid: str, + spec: _RelativePlacementSpec, +) -> int: + for placement in spec.placements: + if ( + placement.relation == "inside" + and runtime_uid == placement.reference_runtime_uid + ): + return _CONTAINER_MAX_CONVEX_HULL_NUM + return _BACKGROUND_MAX_CONVEX_HULL_NUM + + +def _make_shape_config( + scene_dir: Path, + source_config: Mapping[str, Any], + *, + mesh_fpath: str | Path | None = None, + mesh_normalizer: MeshFrameNormalizer | None = None, +) -> dict[str, Any]: + shape = copy.deepcopy(dict(source_config.get("shape", {}))) + if mesh_fpath is not None: + shape["shape_type"] = "Mesh" + shape["fpath"] = str(mesh_fpath) + if shape.get("shape_type") == "Mesh" and "fpath" in shape: + mesh_path = Path(_asset_path_for_config(scene_dir, str(shape["fpath"]))) + if mesh_normalizer is not None: + mesh_path = mesh_normalizer.normalize_path(mesh_path) + shape["fpath"] = mesh_path.as_posix() + shape.setdefault("compute_uv", False) + return shape + + +def _asset_path_for_config(scene_dir: Path, fpath: str) -> str: + raw_path = Path(fpath) + if raw_path.is_absolute(): + return raw_path.resolve().as_posix() + return (scene_dir / raw_path).resolve().as_posix() + + +def _repo_root() -> Path: + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "setup.py").exists() and (parent / "embodichain").exists(): + return parent + return Path.cwd().resolve() + + +def _validate_bundle(bundle: Mapping[str, Any], roles: _BasketTaskRoles) -> None: + gym_config = bundle["gym_config"] + if gym_config.get("id") != "AtomicActionsAgent-v3": + raise ValueError("Generated gym config must use AtomicActionsAgent-v3.") + if gym_config.get("robot", {}).get("uid") != "DualUR5": + raise ValueError("Generated UR5 basket config must use DualUR5.") + + rigid_uids = {obj["uid"] for obj in gym_config.get("rigid_object", [])} + background_uids = {obj["uid"] for obj in gym_config.get("background", [])} + scene_uids = rigid_uids | background_uids + required_rigid = { + roles.left_target_runtime_uid, + roles.right_target_runtime_uid, + } + if not required_rigid.issubset(rigid_uids): + raise ValueError( + f"Generated rigid objects missing: {sorted(required_rigid - rigid_uids)}" + ) + if roles.container_runtime_uid not in scene_uids: + raise ValueError( + f"Generated scene objects missing container: {roles.container_runtime_uid}" + ) + + success = gym_config["env"]["extensions"]["agent_success"] + for term in success.get("terms", []): + if ( + term.get("object") not in rigid_uids + or term.get("container") not in scene_uids + ): + raise ValueError(f"Invalid success term uid reference: {term}") + + +def _validate_relative_bundle( + bundle: Mapping[str, Any], + spec: _RelativePlacementSpec, +) -> None: + gym_config = bundle["gym_config"] + if gym_config.get("id") != "AtomicActionsAgent-v3": + raise ValueError("Generated gym config must use AtomicActionsAgent-v3.") + if gym_config.get("robot", {}).get("uid") != "DualUR5": + raise ValueError("Generated relative placement config must use DualUR5.") + + rigid_uid_list = [obj["uid"] for obj in gym_config.get("rigid_object", [])] + if len(rigid_uid_list) != len(set(rigid_uid_list)): + raise ValueError(f"Duplicate rigid object runtime uid(s): {rigid_uid_list}") + rigid_uids = set(rigid_uid_list) + background_uids = {obj["uid"] for obj in gym_config.get("background", [])} + scene_uids = rigid_uids | background_uids + moved_required = {placement.moved_runtime_uid for placement in spec.placements} + missing_moved = moved_required - rigid_uids + if missing_moved: + raise ValueError( + f"Generated relative config missing moved rigid object(s): {missing_moved}" + ) + reference_required = { + placement.reference_runtime_uid for placement in spec.placements + } + missing_reference = reference_required - scene_uids + if missing_reference: + raise ValueError( + f"Generated relative config missing reference object(s): {missing_reference}" + ) + + _validate_success_uids( + gym_config["env"]["extensions"]["agent_success"], + rigid_uids=rigid_uids, + scene_uids=scene_uids, + ) + registry = gym_config["env"]["events"]["register_info_to_env"]["params"]["registry"] + registered = {entry["entity_cfg"]["uid"] for entry in registry} + required = moved_required | reference_required + if not required.issubset(registered): + raise ValueError( + f"Relative config registry missing: {sorted(required - registered)}" + ) + + +def _validate_success_uids( + success: Mapping[str, Any], + *, + rigid_uids: set[str], + scene_uids: set[str], +) -> None: + if success.get("op") in {"all", "and", "any", "or"}: + for term in success.get("terms", []): + _validate_success_uids(term, rigid_uids=rigid_uids, scene_uids=scene_uids) + return + + success_type = str(success.get("type", success.get("func", ""))).lower() + if success_type == "object_in_container": + required_keys = ("object", "container") + elif success_type in {"object_on_object", "object_on", "on_object"}: + required_keys = ("object", "support") + elif success_type in { + "object_axis_offset_near", + "object_relative_axis_near", + }: + required_keys = ("object", "reference") + elif success_type in {"object_axis_near", "object_coordinate_near"}: + required_keys = ("object",) + elif success_type in {"object_not_fallen", "not_fallen"}: + required_keys = ("object",) + else: + raise ValueError(f"Unsupported generated success term: {success_type!r}.") + + for key in required_keys: + uid = success.get(key) + valid_uids = rigid_uids if key == "object" else scene_uids + if uid not in valid_uids: + raise ValueError(f"Invalid success uid reference {key}={uid!r}.") + + +def _write_config_bundle( + *, + output_dir: Path, + bundle: Mapping[str, Any], + overwrite: bool, +) -> GeneratedUR5BasketConfigPaths: + paths = GeneratedUR5BasketConfigPaths( + output_dir=output_dir, + gym_config=output_dir / "fast_gym_config.json", + agent_config=output_dir / "agent_config.json", + task_prompt=output_dir / "task_prompt.txt", + basic_background=output_dir / "basic_background.txt", + atom_actions=output_dir / "atom_actions.txt", + summary=dict(bundle.get("summary", {})), + ) + _raise_if_generated_files_exist(output_dir, overwrite) + + output_dir.mkdir(parents=True, exist_ok=True) + _write_json(paths.gym_config, bundle["gym_config"]) + _write_json(paths.agent_config, bundle["agent_config"]) + _write_text(paths.task_prompt, bundle["task_prompt"]) + _write_text(paths.basic_background, bundle["basic_background"]) + _write_text(paths.atom_actions, bundle["atom_actions"]) + return paths + + +def _raise_if_generated_files_exist(output_dir: Path, overwrite: bool) -> None: + if overwrite: + return + output_files = [ + output_dir / "fast_gym_config.json", + output_dir / "agent_config.json", + output_dir / "task_prompt.txt", + output_dir / "basic_background.txt", + output_dir / "atom_actions.txt", + ] + existing = [path for path in output_files if path.exists()] + if existing: + existing_text = ", ".join(path.as_posix() for path in existing) + raise FileExistsError( + f"Generated file(s) already exist: {existing_text}. " + "Pass overwrite=True or --overwrite to replace them." + ) + + +def _write_json(path: Path, data: Mapping[str, Any]) -> None: + path.write_text( + json.dumps(data, ensure_ascii=False, indent=4) + "\n", + encoding="utf-8", + ) + + +def _write_text(path: Path, content: str) -> None: + path.write_text(content.rstrip() + "\n", encoding="utf-8") + + +def _read_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as file: + return json.load(file) + + +def _vector3(value: Any) -> list[float]: + if not isinstance(value, (list, tuple)) or len(value) != 3: + raise ValueError(f"Expected a 3-vector, got {value!r}.") + return [float(item) for item in value] + + +def _clean_vector3(value: Any) -> list[float]: + cleaned = [] + for item in _vector3(value): + if abs(item - 1.0) < 1e-9: + cleaned.append(1.0) + elif abs(item) < 1e-12: + cleaned.append(0.0) + else: + cleaned.append(item) + return cleaned diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/image2tabletop_client.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/image2tabletop_client.py new file mode 100644 index 00000000..0da04292 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/image2tabletop_client.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Client for the Image2Tabletop API.""" + +from __future__ import annotations + +import argparse +import json +import re +import shutil +import sys +import time +import zipfile +from pathlib import Path +from tempfile import TemporaryDirectory + +import requests +from requests import exceptions as request_exceptions + +_IMAGE_SUFFIXES = frozenset({".bmp", ".jpeg", ".jpg", ".png", ".webp"}) +_PROJECT_NAME_RE = re.compile(r"^[0-9]+_gym_project$") +_PROJECT_ID_RE = re.compile(r"Image2Tabletop-([0-9]+)-v[0-9]+") +_DEFAULT_SERVER = "http://192.168.3.23:4523" + + +def _repo_root() -> Path: + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "setup.py").is_file() and (parent / "embodichain").is_dir(): + return parent + return Path.cwd().resolve() + + +_REPO_ROOT = _repo_root() +_DEFAULT_OUTPUT_ROOT = _REPO_ROOT / "gym_project" +_DEFAULT_IMAGE_INPUT = _DEFAULT_OUTPUT_ROOT / "action_agent_pipeline/images" + + +def _server_url(base_url: str, path: str) -> str: + return f"{base_url.rstrip('/')}{path}" + + +def check_health(server: str) -> None: + try: + response = requests.get(_server_url(server, "/health"), timeout=10) + except request_exceptions.ConnectionError as exc: + raise RuntimeError( + f"cannot connect to Image2Tabletop demo API: {server}. " + "Start the server with: " + "python demo_api/server/image2tabletop_api.py --host 0.0.0.0 --port 4523" + ) from exc + response.raise_for_status() + + +def submit_job(server: str, image_path: Path) -> str: + try: + with image_path.open("rb") as image_file: + response = requests.post( + _server_url(server, "/api/image2tabletop/start"), + files={"image": (image_path.name, image_file)}, + timeout=60, + ) + except request_exceptions.ConnectionError as exc: + raise RuntimeError( + f"cannot connect to API server: {server}. " + "Make sure the server is running and listening on this host/port." + ) from exc + response.raise_for_status() + data = response.json() + job_id = data.get("job_id") + if not job_id: + raise RuntimeError(f"API response does not contain job_id: {data}") + return str(job_id) + + +def wait_for_job(server: str, job_id: str, poll_interval: float) -> dict: + status_url = _server_url(server, f"/api/image2tabletop/status/{job_id}") + while True: + response = requests.get(status_url, timeout=30) + response.raise_for_status() + data = response.json() + status = data.get("status") + print(f"[{time.strftime('%H:%M:%S')}] job={job_id} status={status}", flush=True) + if status == "completed": + return data + if status == "failed": + raise RuntimeError(f"job failed: {data}") + time.sleep(poll_interval) + + +def download_zip(server: str, job_id: str, output_dir: Path) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + zip_path = output_dir / f"{job_id}_formatted_tabletop_scene.zip" + response = requests.get( + _server_url(server, f"/api/image2tabletop/download/{job_id}"), + stream=True, + timeout=300, + ) + response.raise_for_status() + with zip_path.open("wb") as file: + for chunk in response.iter_content(chunk_size=1024 * 1024): + if chunk: + file.write(chunk) + return zip_path + + +def collect_image_paths(image_input: Path) -> list[Path]: + image_input = image_input.expanduser().resolve() + if image_input.is_file(): + if image_input.suffix.lower() not in _IMAGE_SUFFIXES: + raise ValueError(f"unsupported image suffix: {image_input}") + return [image_input] + if image_input.is_dir(): + image_paths = sorted( + path + for path in image_input.iterdir() + if path.is_file() and path.suffix.lower() in _IMAGE_SUFFIXES + ) + if image_paths: + return image_paths + raise FileNotFoundError(f"no supported image files found under: {image_input}") + raise FileNotFoundError(f"image input not found: {image_input}") + + +def extract_gym_project( + zip_path: Path, output_root: Path, job_id: str, overwrite: bool +) -> Path: + output_root = output_root.expanduser().resolve() + output_root.mkdir(parents=True, exist_ok=True) + + with TemporaryDirectory(prefix=f"{job_id}_image2tabletop_") as temp_dir_name: + extract_dir = Path(temp_dir_name).resolve() + _safe_extract_zip(zip_path, extract_dir) + gym_config_paths = sorted(extract_dir.rglob("gym_config.json")) + if not gym_config_paths: + raise FileNotFoundError( + f"gym_config.json not found in downloaded archive: {zip_path}" + ) + if len(gym_config_paths) > 1: + matches = ", ".join(path.as_posix() for path in gym_config_paths) + raise ValueError( + f"multiple gym_config.json files found in archive: {matches}" + ) + + gym_config_path = gym_config_paths[0] + project_name = _infer_project_name(gym_config_path, extract_dir, job_id) + source_root = _infer_source_project_root( + gym_config_path, extract_dir, project_name + ) + destination = output_root / project_name + if destination.exists(): + if not overwrite: + raise FileExistsError( + f"output project already exists: {destination}. " + "Pass --overwrite to replace it." + ) + shutil.rmtree(destination) + shutil.copytree(source_root, destination) + return destination + + +def _safe_extract_zip(zip_path: Path, extract_dir: Path) -> None: + with zipfile.ZipFile(zip_path) as archive: + for member in archive.infolist(): + target_path = (extract_dir / member.filename).resolve() + if not target_path.is_relative_to(extract_dir): + raise RuntimeError(f"unsafe archive member path: {member.filename}") + archive.extractall(extract_dir) + + +def _infer_project_name(gym_config_path: Path, extract_dir: Path, job_id: str) -> str: + for part in gym_config_path.relative_to(extract_dir).parts: + if _PROJECT_NAME_RE.match(part): + return part + + try: + config = json.loads(gym_config_path.read_text(encoding="utf-8")) + except json.JSONDecodeError: + config = {} + project_id = str(config.get("id", "")) + match = _PROJECT_ID_RE.match(project_id) + if match: + return f"{match.group(1)}_gym_project" + return f"{job_id}_gym_project" + + +def _infer_source_project_root( + gym_config_path: Path, extract_dir: Path, project_name: str +) -> Path: + current = extract_dir + for part in gym_config_path.relative_to(extract_dir).parts: + current = current / part + if part == project_name: + return current + return gym_config_path.parent + + +def process_image( + server: str, + image_path: Path, + output_root: Path, + poll_interval: float, + overwrite: bool, +) -> Path: + job_id = submit_job(server, image_path) + print(f"submitted job: {job_id} image={image_path}", flush=True) + wait_for_job(server, job_id, poll_interval) + with TemporaryDirectory( + prefix=f"{job_id}_image2tabletop_download_" + ) as temp_dir_name: + zip_path = download_zip(server, job_id, Path(temp_dir_name)) + project_path = extract_gym_project(zip_path, output_root, job_id, overwrite) + print(f"generated gym project: {project_path}", flush=True) + return project_path + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Submit image files to Image2Tabletop API." + ) + parser.add_argument( + "--server", + default=_DEFAULT_SERVER, + help=f"Image2Tabletop demo API server. Defaults to {_DEFAULT_SERVER}", + ) + parser.add_argument( + "--image", + default=str(_DEFAULT_IMAGE_INPUT), + help=( + "Input image file or directory. Defaults to " + f"{_DEFAULT_IMAGE_INPUT.as_posix()}" + ), + ) + parser.add_argument( + "--output-root", + default=None, + help=f"Directory where generated gym projects are written. Defaults to {_DEFAULT_OUTPUT_ROOT.as_posix()}", + ) + parser.add_argument( + "--download-dir", + dest="output_root", + default=None, + help=argparse.SUPPRESS, + ) + parser.add_argument("--poll-interval", type=float, default=10.0) + parser.add_argument( + "--skip-health-check", + action="store_true", + default=False, + help="Skip GET /health before submitting images.", + ) + parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Replace an existing generated gym project with the same name.", + ) + args = parser.parse_args() + + image_paths = collect_image_paths(Path(args.image)) + if not args.skip_health_check: + check_health(args.server) + + project_paths = [] + for image_path in image_paths: + project_paths.append( + process_image( + server=args.server, + image_path=image_path, + output_root=Path(args.output_root or _DEFAULT_OUTPUT_ROOT), + poll_interval=args.poll_interval, + overwrite=args.overwrite, + ) + ) + + print("gym_project paths:", flush=True) + for project_path in project_paths: + print(project_path, flush=True) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/.gitignore b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/.gitignore new file mode 100644 index 00000000..ede6bbf2 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/.gitignore @@ -0,0 +1,4 @@ +# Python cache +__pycache__/ +*.py[cod] + diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/__init__.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/__init__.py new file mode 100644 index 00000000..bdac8600 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/__init__.py @@ -0,0 +1,57 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from .pipeline import ( + Prompt2GeometryRequest, + run_prompt2geometry, +) +from .config import ( + Prompt2GeometryConfig, + load_prompt2geometry_config, +) +from .llm_client import ( + OpenAICompatibleClient, + OpenAICompatibleClientError, +) +from .sam3_client import ( + SAM3Client, + SAM3ClientError, +) +from .sam3d_client import ( + SAM3DClient, + SAM3DClientError, +) +from .zimage_client import ( + ZImageClient, + ZImageClientError, +) + +__all__ = [ + "Prompt2GeometryRequest", + "Prompt2GeometryConfig", + "OpenAICompatibleClient", + "OpenAICompatibleClientError", + "SAM3Client", + "SAM3ClientError", + "SAM3DClient", + "SAM3DClientError", + "ZImageClient", + "ZImageClientError", + "run_prompt2geometry", + "load_prompt2geometry_config", +] diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/config.json b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/config.json new file mode 100644 index 00000000..740a5710 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/config.json @@ -0,0 +1,21 @@ +{ + "services": { + "zimage": { + "base_url": "http://192.168.3.23:5013" + }, + "sam3": { + "base_url": "http://192.168.3.23:5015" + }, + "sam3d": { + "base_url": "http://192.168.3.23:5016" + } + }, + "llm": { + "openai_compatible": { + "api_key": "sk-7hjyRgBLrhUYUSCpLgPSARk8sz1Sc2vZ2bnt3fy1bkHsI7ak", + "model": "gpt-5.5", + "base_url": "https://airouter.cloud/v1", + "timeout_s": 120 + } + } +} diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/config.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/config.py new file mode 100644 index 00000000..cf7dda1d --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/config.py @@ -0,0 +1,109 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.action_agent_pipeline.utils.llm_config import ( + get_openai_compatible_llm_config, +) + +__all__ = ["Prompt2GeometryConfig", "load_prompt2geometry_config"] + +DEFAULT_CONFIG_PATH = Path(__file__).resolve().parent / "config.json" + + +@dataclass(frozen=True) +class Prompt2GeometryConfig: + """Prompt2Geometry runtime configuration.""" + + zimage_base_url: str + sam3_base_url: str + sam3d_base_url: str + llm_api_key: str + llm_model: str + llm_base_url: str + llm_timeout_s: float + + +def load_prompt2geometry_config( + config_path: Path | None = None, +) -> Prompt2GeometryConfig: + """Load prompt2geometry config from a local JSON file and environment.""" + path = (config_path or DEFAULT_CONFIG_PATH).expanduser().resolve() + if not path.is_file(): + raise FileNotFoundError(f"Prompt2Geometry config not found: {path}") + raw = json.loads(path.read_text(encoding="utf-8")) + services = _mapping(raw.get("services"), "services") + llm = _mapping( + _mapping(raw.get("llm"), "llm").get("openai_compatible"), + "llm.openai_compatible", + ) + shared_llm = get_openai_compatible_llm_config( + required=False, + require_base_url=False, + ) + + return Prompt2GeometryConfig( + zimage_base_url=_env_or_config( + "PROMPT2GEOMETRY_ZIMAGE_BASE_URL", + _service_base_url(services, "zimage"), + ), + sam3_base_url=_env_or_config( + "PROMPT2GEOMETRY_SAM3_BASE_URL", + _service_base_url(services, "sam3"), + ), + sam3d_base_url=_env_or_config( + "PROMPT2GEOMETRY_SAM3D_BASE_URL", + _service_base_url(services, "sam3d"), + ), + llm_api_key=_env_or_config( + "PROMPT2GEOMETRY_LLM_API_KEY", + str(shared_llm.get("api_key") or llm.get("api_key", "")), + ), + llm_model=_env_or_config( + "PROMPT2GEOMETRY_LLM_MODEL", + str(shared_llm.get("model") or llm.get("model", "")), + ), + llm_base_url=_env_or_config( + "PROMPT2GEOMETRY_LLM_BASE_URL", + str(shared_llm.get("base_url") or llm.get("base_url", "")), + ).rstrip("/"), + llm_timeout_s=float( + os.getenv("PROMPT2GEOMETRY_LLM_TIMEOUT_S") + or llm.get("timeout_s", 120.0) + ), + ) + + +def _service_base_url(services: dict[str, Any], name: str) -> str: + section = _mapping(services.get(name), f"services.{name}") + return str(section.get("base_url", "")).rstrip("/") + + +def _env_or_config(env_name: str, config_value: str) -> str: + return str(os.getenv(env_name) or config_value).strip() + + +def _mapping(value: Any, name: str) -> dict[str, Any]: + if not isinstance(value, dict): + raise ValueError(f"Prompt2Geometry config key {name} must be an object.") + return value diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/dimensions.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/dimensions.py new file mode 100644 index 00000000..3c0dec17 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/dimensions.py @@ -0,0 +1,128 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import time +from typing import Any + +try: + from .llm_client import OpenAICompatibleClient +except ImportError: + from llm_client import OpenAICompatibleClient + +__all__ = ["DIMENSION_ESTIMATION_SYSTEM_PROMPT", "estimate_real_dimensions"] + + +DIMENSION_ESTIMATION_SYSTEM_PROMPT = """ + +You are a careful real-world object size estimation assistant. + + + +Estimate the plausible real-world bounding-box dimensions of one physical object +from the user's object description. + + + +- Units are meters. +- length_m is the object's longest horizontal dimension. +- width_m is the object's shorter horizontal dimension. +- height_m is the vertical dimension when the object is in its common upright pose. +- Use common real-world size priors for everyday objects. +- If the object category is ambiguous, choose a conservative typical tabletop size. +- Do not include decorative background, shadows, or image canvas in the dimensions. + + + +{ + "length_m": 0.08, + "width_m": 0.08, + "height_m": 0.08, + "confidence": 0.7, + "reason": "A typical apple is roughly 8 cm across." +} + + + +- Output JSON only. Do not include markdown or text outside JSON. +- length_m, width_m, height_m, and confidence must be numbers. +- length_m, width_m, and height_m must be positive. +- confidence must be between 0 and 1. +- Keep reason short and specific. + +""".strip() + + +def estimate_real_dimensions( + *, + object_prompt: str, + client: OpenAICompatibleClient, + max_attempts: int | None = None, +) -> dict[str, Any]: + """Estimate real-world object dimensions with schema validation and retry.""" + messages = [ + {"role": "system", "content": DIMENSION_ESTIMATION_SYSTEM_PROMPT}, + { + "role": "user", + "content": ( + "Object description:\n" + f"{object_prompt.strip()}\n\n" + "Return the dimensions JSON only." + ), + }, + ] + attempt = 1 + while max_attempts is None or attempt <= max_attempts: + try: + raw = client.chat_json(messages=messages) + return _validate_dimension_output(raw) + except Exception: + attempt += 1 + time.sleep(1.0) + continue + raise ValueError( + "Failed to estimate object dimensions after " + f"{max_attempts} attempts." + ) + + +def _validate_dimension_output(raw: dict[str, Any]) -> dict[str, Any]: + allowed = {"length_m", "width_m", "height_m", "confidence", "reason"} + extra = set(raw) - allowed + if extra: + raise ValueError(f"Unexpected dimension keys: {sorted(extra)}") + result: dict[str, Any] = {} + for key in ("length_m", "width_m", "height_m"): + value = raw.get(key) + if not isinstance(value, int | float): + raise ValueError(f"{key} must be a number.") + value = float(value) + if value <= 0: + raise ValueError(f"{key} must be positive.") + result[key] = value + confidence = raw.get("confidence") + if not isinstance(confidence, int | float): + raise ValueError("confidence must be a number.") + confidence = float(confidence) + if confidence < 0 or confidence > 1: + raise ValueError("confidence must be between 0 and 1.") + reason = raw.get("reason") + if not isinstance(reason, str) or not reason.strip(): + raise ValueError("reason must be a non-empty string.") + result["confidence"] = confidence + result["reason"] = reason.strip() + return result diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/llm_client.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/llm_client.py new file mode 100644 index 00000000..d3a0f826 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/llm_client.py @@ -0,0 +1,134 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from embodichain.gen_sim.action_agent_pipeline.utils.llm_usage import record_llm_usage + +__all__ = ["OpenAICompatibleClient", "OpenAICompatibleClientError"] + + +class OpenAICompatibleClientError(RuntimeError): + """Raised when an OpenAI-compatible chat request fails.""" + + +class OpenAICompatibleClient: + """Small OpenAI-compatible chat completions client.""" + + def __init__( + self, + *, + api_key: str, + model: str, + base_url: str, + timeout_s: float = 120.0, + usage_stage: str | None = None, + ): + if not api_key.strip(): + raise ValueError("LLM api_key must be non-empty.") + if not model.strip(): + raise ValueError("LLM model must be non-empty.") + if not base_url.strip(): + raise ValueError("LLM base_url must be non-empty.") + self.api_key = api_key + self.model = model + self.base_url = base_url.rstrip("/") + self.timeout_s = timeout_s + self.usage_stage = usage_stage or "prompt2geometry.chat_json" + + def chat_json(self, *, messages: list[dict[str, str]]) -> dict[str, Any]: + """Call chat completions and return the decoded JSON response content.""" + payload = { + "model": self.model, + "messages": messages, + "temperature": 0, + "response_format": {"type": "json_object"}, + } + request = Request( + f"{self.base_url}/chat/completions", + data=json.dumps(payload).encode("utf-8"), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + }, + method="POST", + ) + try: + with urlopen(request, timeout=self.timeout_s) as response: + body = response.read().decode("utf-8") + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise OpenAICompatibleClientError( + f"LLM request failed with HTTP {exc.code}: {detail}" + ) from exc + except URLError as exc: + raise OpenAICompatibleClientError( + f"LLM server is unreachable at {request.full_url}: {exc.reason}" + ) from exc + except TimeoutError as exc: + raise OpenAICompatibleClientError( + f"LLM request timed out after {self.timeout_s}s." + ) from exc + + try: + decoded = json.loads(body) + choice = decoded["choices"][0] + content = choice["message"]["content"] + except (KeyError, IndexError, TypeError, json.JSONDecodeError) as exc: + raise OpenAICompatibleClientError( + f"LLM response has unsupported format: {body}" + ) from exc + record_llm_usage( + stage=self.usage_stage, + provider="openai_compatible_http", + model=str(decoded.get("model") or self.model), + usage=decoded.get("usage") if isinstance(decoded, dict) else None, + request_id=str(decoded.get("id")) if decoded.get("id") else None, + finish_reason=( + str(choice.get("finish_reason")) + if isinstance(choice, dict) and choice.get("finish_reason") + else None + ), + raw_usage=( + decoded.get("usage") + if isinstance(decoded, dict) and isinstance(decoded.get("usage"), dict) + else None + ), + ) + if not isinstance(content, str): + raise OpenAICompatibleClientError("LLM message content must be a string.") + return _parse_json_text(content) + + +def _parse_json_text(content: str) -> dict[str, Any]: + stripped = content.strip() + if stripped.startswith("```"): + lines = stripped.splitlines() + if lines and lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].startswith("```"): + lines = lines[:-1] + stripped = "\n".join(lines).strip() + parsed = json.loads(stripped) + if not isinstance(parsed, dict): + raise ValueError("LLM output must be a JSON object.") + return parsed diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/mesh_scaling.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/mesh_scaling.py new file mode 100644 index 00000000..e0418419 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/mesh_scaling.py @@ -0,0 +1,225 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +__all__ = ["scale_mesh_to_real_dimensions"] + + +def scale_mesh_to_real_dimensions( + *, + mesh_path: Path, + output_path: Path, + dimensions_m: dict[str, Any], + report_path: Path, +) -> dict[str, Any]: + """Scale a canonical GLB mesh for Blender with object up along -Y. + + glTF/GLB stores assets in y-up coordinates. Blender converts glTF y-up + assets to its z-up scene coordinates during import. The exported vertices + are arranged so that after Blender import the object's original y-up axis + becomes Blender -Y, and the bbox bottom-center is at the world origin. + """ + trimesh = _require_trimesh() + np = _require_numpy() + mesh_path = mesh_path.expanduser().resolve() + output_path = output_path.expanduser().resolve() + report_path = report_path.expanduser().resolve() + scene = trimesh.load(str(mesh_path), force="scene") + mesh = _scene_to_world_mesh(scene) + bounds = _mesh_bounds(mesh) + extents = bounds[1] - bounds[0] + axis_map = _axis_mapping(extents) + target_extents = np.asarray( + [ + dimensions_m[axis_map["x"]], + dimensions_m[axis_map["y"]], + dimensions_m[axis_map["z"]], + ], + dtype=np.float64, + ) + source_max_extent = float(max(extents) or 1.0) + target_max_extent = float(max(target_extents)) + uniform_scale = target_max_extent / source_max_extent + scale = np.asarray([uniform_scale, uniform_scale, uniform_scale], dtype=np.float64) + bottom_center_y_up = _bottom_center_y_up(bounds) + gltf_to_blender = _gltf_y_up_to_blender_z_up_matrix(np) + original_to_blender = _original_y_up_to_blender_negative_y_up_matrix(np) + original_to_export = np.linalg.inv(gltf_to_blender) @ original_to_blender + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = original_to_export @ np.diag(scale) + transform[:3, 3] = -(original_to_export @ np.diag(scale) @ bottom_center_y_up) + mesh.apply_transform(transform) + exported_bounds = _mesh_bounds(mesh) + blender_bounds = _bounds(_transform_vertices(mesh.vertices, gltf_to_blender)) + output_path.parent.mkdir(parents=True, exist_ok=True) + mesh.export(str(output_path)) + + report = { + "input_mesh_path": str(mesh_path), + "scaled_mesh_path": str(output_path), + "axis_convention": ( + "Input GLB is treated as y-up. After Blender's glTF import, the " + "object's original y-up axis is aligned to Blender -Y. length_m " + "maps to the larger generated horizontal axis among input x/z; " + "width_m maps to the other." + ), + "scaling_policy": ( + "The mesh is scaled uniformly to preserve generated geometry " + "proportions. The source mesh is first considered normalized by " + "its maximum bbox extent; the uniform scale is computed as " + "estimated_max_real_extent / mesh_max_extent." + ), + "origin_policy": ( + "The input y-up bbox bottom-center is subtracted before GLB export. " + "After Blender import, the -Y-up bbox bottom-center is at " + "(0, 0, 0), so its XZ-plane location is (0, 0)." + ), + "axis_map": axis_map, + "estimated_dimensions_m": dimensions_m, + "estimated_target_extents_by_mesh_axes": target_extents.tolist(), + "source_max_extent": source_max_extent, + "estimated_max_real_extent": target_max_extent, + "original_bounds": bounds.tolist(), + "original_extents": extents.tolist(), + "bottom_center_y_up_subtracted": bottom_center_y_up.tolist(), + "gltf_to_blender_matrix": gltf_to_blender.tolist(), + "original_to_blender_matrix": original_to_blender.tolist(), + "original_to_export_matrix": original_to_export.tolist(), + "uniform_scale": uniform_scale, + "applied_transform": transform.tolist(), + "exported_gltf_bounds": exported_bounds.tolist(), + "exported_gltf_extents": (exported_bounds[1] - exported_bounds[0]).tolist(), + "blender_import_bounds": blender_bounds.tolist(), + "blender_import_extents": (blender_bounds[1] - blender_bounds[0]).tolist(), + "blender_import_bottom_center_negative_y_up": _bottom_center_negative_y_up( + blender_bounds + ).tolist(), + } + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text( + json.dumps(report, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + return report + + +def _axis_mapping(extents: Any) -> dict[str, str]: + if float(extents[0]) >= float(extents[2]): + return {"x": "length_m", "y": "height_m", "z": "width_m"} + return {"x": "width_m", "y": "height_m", "z": "length_m"} + + +def _bottom_center_negative_y_up(bounds: Any) -> Any: + np = _require_numpy() + return np.asarray( + [ + 0.5 * (bounds[0][0] + bounds[1][0]), + bounds[1][1], + 0.5 * (bounds[0][2] + bounds[1][2]), + ], + dtype=np.float64, + ) + + +def _bottom_center_y_up(bounds: Any) -> Any: + np = _require_numpy() + return np.asarray( + [ + 0.5 * (bounds[0][0] + bounds[1][0]), + bounds[0][1], + 0.5 * (bounds[0][2] + bounds[1][2]), + ], + dtype=np.float64, + ) + + +def _bounds(vertices: Any) -> Any: + np = _require_numpy() + return np.vstack([vertices.min(axis=0), vertices.max(axis=0)]) + + +def _transform_vertices(vertices: Any, matrix: Any) -> Any: + np = _require_numpy() + vertices_array = np.asarray(vertices, dtype=np.float64) + matrix_array = np.asarray(matrix, dtype=np.float64) + return vertices_array @ matrix_array.T + + +def _gltf_y_up_to_blender_z_up_matrix(np: Any) -> Any: + return np.asarray( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 1.0, 0.0], + ], + dtype=np.float64, + ) + + +def _original_y_up_to_blender_negative_y_up_matrix(np: Any) -> Any: + return np.asarray( + [ + [1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float64, + ) + + +def _mesh_bounds(mesh: Any) -> Any: + np = _require_numpy() + vertices = np.asarray(mesh.vertices, dtype=np.float64) + if vertices.size == 0: + raise ValueError("Mesh contains no vertices.") + return _bounds(vertices) + + +def _scene_to_world_mesh(scene: Any) -> Any: + """Convert a loaded GLB scene to one world-space mesh. + + This intentionally bakes scene graph transforms into vertex coordinates so + later z-up conversion and origin anchoring are visible to downstream tools + that only inspect mesh vertices. + """ + try: + mesh = scene.dump(concatenate=True) + except Exception as exc: + raise ValueError("Failed to concatenate GLB scene into a mesh.") from exc + if not hasattr(mesh, "vertices") or len(mesh.vertices) == 0: + raise ValueError("GLB scene contains no mesh vertices.") + return mesh + + +def _require_trimesh() -> Any: + try: + import trimesh + except ImportError as exc: + raise ImportError("trimesh is required to scale GLB meshes.") from exc + return trimesh + + +def _require_numpy() -> Any: + try: + import numpy as np + except ImportError as exc: + raise ImportError("numpy is required to scale GLB meshes.") from exc + return np diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/pipeline.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/pipeline.py new file mode 100644 index 00000000..2154da48 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/pipeline.py @@ -0,0 +1,589 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import os +import re +import shutil +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +try: + from .dimensions import estimate_real_dimensions + from .llm_client import OpenAICompatibleClient + from .mesh_scaling import scale_mesh_to_real_dimensions + from .sam3_client import SAM3Client + from .sam3d_client import SAM3DClient + from .schemas import SelectedBox + from .segmentation_outputs import save_segmentation_outputs + from .zimage_client import ZImageClient +except ImportError: + from dimensions import estimate_real_dimensions + from llm_client import OpenAICompatibleClient + from mesh_scaling import scale_mesh_to_real_dimensions + from sam3_client import SAM3Client + from sam3d_client import SAM3DClient + from schemas import SelectedBox + from segmentation_outputs import save_segmentation_outputs + from zimage_client import ZImageClient + +__all__ = ["Prompt2GeometryRequest", "run_prompt2geometry"] + + +@dataclass(frozen=True) +class Prompt2GeometryRequest: + """Request for prompt-to-single-asset geometry generation.""" + + prompt: str + output_root: Path + target_id: str = "asset_0" + request_id: str = "prompt2geometry_asset_0" + output_name: str | None = None + zimage_base_url: str = "http://192.168.3.23:5013" + zimage_width: int = 1024 + zimage_height: int = 1024 + zimage_seed: int = 42 + zimage_num_inference_steps: int = 8 + zimage_prompt_suffix: str = "a complete single object, with pure-black background" + sam3_base_url: str = "http://192.168.3.23:5015" + sam3d_base_url: str = "http://192.168.3.23:5016" + sam3d_seed: int = 42 + llm_api_key: str | None = None + llm_model: str | None = None + llm_base_url: str | None = None + llm_timeout_s: float = 120.0 + verbose: bool = True + + +def run_prompt2geometry(request: Prompt2GeometryRequest) -> dict[str, Any]: + """Run z-image, SAM3 segmentation, and SAM3D generation.""" + output_root = request.output_root.expanduser().resolve() + output_root.mkdir(parents=True, exist_ok=True) + final_glb_path: Path | None = None + success = False + try: + _log_status(request, "start", f"output_root={output_root}") + _write_json( + output_root / "prompt2geometry_request.json", + _request_manifest(request), + ) + + _log_status(request, "z-image", "generating source image") + image_path, zimage_manifest = _generate_image(request, output_root) + _log_status(request, "segmentation", "segmenting generated image") + raw_mask_path, segmentation_manifest = _segment_image( + request, + image_path, + output_root, + ) + _log_status(request, "mask", "checking mask orientation with center prior") + corrected_mask_path = _correct_mask_with_center_prior( + image_path=image_path, + raw_mask_path=raw_mask_path, + output_dir=output_root / "mask_correction", + ) + _log_status(request, "3D-generation", "generating raw mesh") + generation_manifest = _generate_geometry( + request=request, + image_path=image_path, + mask_path=corrected_mask_path, + output_root=output_root, + ) + _log_status(request, "dimensions", "estimating real-world dimensions") + dimension_manifest = _estimate_dimensions(request, output_root) + _log_status(request, "naming", "resolving final GLB file name") + final_glb_path = _final_scaled_glb_path(request, output_root) + _log_status(request, "scale", f"writing final mesh to {final_glb_path.name}") + scaling_manifest = _scale_generated_mesh( + mesh_path=Path(str(generation_manifest["local_glb_path"])), + dimensions_m=dimension_manifest, + output_path=final_glb_path, + output_root=output_root, + ) + manifest = { + "prompt": request.prompt, + "zimage_prompt": _zimage_prompt(request), + "output_root": str(output_root), + "image_path": str(image_path), + "raw_mask_path": str(raw_mask_path), + "corrected_mask_path": str(corrected_mask_path), + "zimage": zimage_manifest, + "sam3_segmentation": segmentation_manifest, + "sam3d_generation": generation_manifest, + "dimension_estimation": dimension_manifest, + "mesh_scaling": scaling_manifest, + "mesh_path": generation_manifest.get("local_glb_path"), + "scaled_mesh_path": scaling_manifest.get("scaled_mesh_path"), + "transform_metadata_path": generation_manifest.get( + "local_transform_metadata_path" + ), + } + _write_json(output_root / "prompt2geometry_result.json", manifest) + success = True + _log_status(request, "done", f"final_glb={final_glb_path}") + return manifest + finally: + _cleanup_output_root(output_root, keep_path=final_glb_path if success else None) + + +def _generate_image( + request: Prompt2GeometryRequest, + output_root: Path, +) -> tuple[Path, dict[str, Any]]: + image_path = output_root / "zimage" / "zimage.png" + client = ZImageClient(base_url=request.zimage_base_url) + manifest = client.generate_png( + prompt=_zimage_prompt(request), + output_path=image_path, + width=request.zimage_width, + height=request.zimage_height, + seed=request.zimage_seed, + num_inference_steps=request.zimage_num_inference_steps, + ) + _write_json(output_root / "zimage" / "zimage_result.json", manifest) + return image_path, manifest + + +def _segment_image( + request: Prompt2GeometryRequest, + image_path: Path, + output_root: Path, +) -> tuple[Path, dict[str, Any]]: + width, height = _image_size(image_path) + full_image_box = SelectedBox( + target_id=request.target_id, + target_kind="asset", + phrase=request.target_id, + bbox_xyxy=[0.0, 0.0, float(width), float(height)], + source_candidate_ids=["full_image_bbox"], + selection_reason="Use the full generated image as a bbox prompt.", + ) + sam3_client = SAM3Client( + base_url=os.getenv("PROMPT2GEOMETRY_SAM3_BASE_URL") or request.sam3_base_url, + ) + health = sam3_client.health() + _write_json(output_root / "sam3_health.json", health) + + sam3_request = { + "image": str(image_path), + "request_id": f"{request.request_id}_sam3_box", + "selected_boxes": [full_image_box.to_manifest()], + "save_visualizations": False, + } + _write_json(output_root / "sam3_box_segmentation_request.json", sam3_request) + result = sam3_client.segment_boxes_image( + image_path, + selected_boxes=[full_image_box], + request_id=f"{request.request_id}_sam3_box", + save_visualizations=False, + progress_path=output_root / "sam3_progress.jsonl", + verbose=request.verbose, + ) + _write_json(output_root / "sam3_segmentation_result.json", result) + + local_outputs = save_segmentation_outputs( + image_path=image_path, + segmentation_result=result, + output_dir=output_root / "segment_box", + ) + _write_json(output_root / "sam3_local_outputs.json", local_outputs) + segmentations = local_outputs.get("segmentations", []) + if not isinstance(segmentations, list) or not segmentations: + raise RuntimeError("SAM3 box segmentation produced no local masks.") + first = segmentations[0] + mask_path = first.get("local_mask_path") + if not isinstance(mask_path, str) or not mask_path: + raise RuntimeError("SAM3 local segmentation output missing local_mask_path.") + return Path(mask_path).expanduser().resolve(), local_outputs + + +def _generate_geometry( + *, + request: Prompt2GeometryRequest, + image_path: Path, + mask_path: Path, + output_root: Path, +) -> dict[str, Any]: + output_name = request.output_name or f"{request.request_id}.glb" + local_glb_path = output_root / "sam3d" / output_name + local_metadata_path = ( + output_root / "sam3d" / f"{Path(output_name).stem}_transform.json" + ) + + client = SAM3DClient( + base_url=os.getenv("PROMPT2GEOMETRY_SAM3D_BASE_URL") or request.sam3d_base_url, + ) + health = client.health() + _write_json(output_root / "sam3d_health.json", health) + generation_request = { + "image": str(image_path), + "mask": str(mask_path), + "request_id": request.request_id, + "output_name": output_name, + "prompt": request.prompt, + "seed": request.sam3d_seed, + "local_glb_path": str(local_glb_path), + "local_transform_metadata_path": str(local_metadata_path), + } + _write_json(output_root / "sam3d_generation_request.json", generation_request) + result = client.generate_asset( + image_path=image_path, + mask_path=mask_path, + request_id=request.request_id, + output_name=output_name, + prompt=request.prompt, + seed=request.sam3d_seed, + output_path=local_glb_path, + metadata_path=local_metadata_path, + progress_path=output_root / "sam3d_progress.jsonl", + verbose=request.verbose, + ) + _write_json(output_root / "sam3d_generation_result.json", result) + return result + + +def _estimate_dimensions( + request: Prompt2GeometryRequest, + output_root: Path, +) -> dict[str, Any]: + client = _llm_client_from_request(request, purpose="dimension estimation") + dimensions = estimate_real_dimensions( + object_prompt=request.prompt, + client=client, + ) + _write_json(output_root / "dimension_estimation.json", dimensions) + return dimensions + + +def _scale_generated_mesh( + *, + mesh_path: Path, + dimensions_m: dict[str, Any], + output_path: Path, + output_root: Path, +) -> dict[str, Any]: + report_path = output_root / "mesh_scaling_report.json" + return scale_mesh_to_real_dimensions( + mesh_path=mesh_path, + output_path=output_path, + dimensions_m=dimensions_m, + report_path=report_path, + ) + + +def _final_scaled_glb_path( + request: Prompt2GeometryRequest, + output_root: Path, +) -> Path: + if request.output_name: + stem = _safe_glb_stem(Path(request.output_name).stem) + else: + client = _llm_client_from_request(request, purpose="GLB file naming") + stem = _extract_glb_stem_from_prompt(request.prompt, client) + return output_root / f"{stem}.glb" + + +def _extract_glb_stem_from_prompt( + prompt: str, + client: OpenAICompatibleClient, +) -> str: + system_prompt = """ + +You extract a concise object file name from a prompt. + + + +Return a JSON object with one field, object_name, containing a short ASCII +snake_case name for the single main object described by the user. + + + +- Output JSON only. +- Required schema: {"object_name": "red_ceramic_mug"} +- object_name must be non-empty. +- Do not include a file extension. +- Use only lowercase English letters, numbers, and underscores. +- Prefer the concrete object noun with one or two useful modifiers. + +""".strip() + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": ( + "Prompt:\n" f"{prompt.strip()}\n\n" "Return the object_name JSON only." + ), + }, + ] + while True: + try: + raw = client.chat_json(messages=messages) + return _validate_glb_stem_output(raw) + except Exception: + time.sleep(1.0) + continue + + +def _validate_glb_stem_output(raw: dict[str, Any]) -> str: + value = raw.get("object_name") + if not isinstance(value, str) or not value.strip(): + raise ValueError("object_name must be a non-empty string.") + return _safe_glb_stem(value) + + +def _safe_glb_stem(value: str) -> str: + stem = value.strip().lower() + if stem.endswith(".glb"): + stem = stem[:-4] + stem = re.sub(r"[^a-z0-9]+", "_", stem) + stem = re.sub(r"_+", "_", stem).strip("_") + if not stem: + raise ValueError("GLB file name stem is empty after sanitization.") + return stem + + +def _llm_client_from_request( + request: Prompt2GeometryRequest, + *, + purpose: str, +) -> OpenAICompatibleClient: + api_key = os.getenv("PROMPT2GEOMETRY_LLM_API_KEY") or request.llm_api_key + model = os.getenv("PROMPT2GEOMETRY_LLM_MODEL") or request.llm_model + base_url = os.getenv("PROMPT2GEOMETRY_LLM_BASE_URL") or request.llm_base_url + missing = [ + name + for name, value in { + "PROMPT2GEOMETRY_LLM_API_KEY or --llm-api-key": api_key, + "PROMPT2GEOMETRY_LLM_MODEL or --llm-model": model, + "PROMPT2GEOMETRY_LLM_BASE_URL or --llm-base-url": base_url, + }.items() + if not value + ] + if missing: + raise ValueError(f"Missing required LLM config for {purpose}: {missing}") + return OpenAICompatibleClient( + api_key=str(api_key), + model=str(model), + base_url=str(base_url), + timeout_s=request.llm_timeout_s, + usage_stage=f"prompt2geometry.{purpose}", + ) + + +def _cleanup_output_root(output_root: Path, *, keep_path: Path | None) -> None: + output_root = output_root.expanduser().resolve() + keep_path = keep_path.expanduser().resolve() if keep_path is not None else None + if keep_path is not None and not keep_path.is_file(): + keep_path = None + for child in output_root.iterdir(): + if keep_path is not None and child.resolve() == keep_path: + continue + if child.is_dir() and not child.is_symlink(): + shutil.rmtree(child) + else: + child.unlink() + + +def _correct_mask_with_center_prior( + *, + image_path: Path, + raw_mask_path: Path, + output_dir: Path, +) -> Path: + cv2 = _require_cv2() + np = _require_numpy() + + image = cv2.imread(str(image_path), cv2.IMREAD_COLOR) + if image is None: + raise ValueError(f"Failed to read image for mask correction: {image_path}") + raw_mask = cv2.imread(str(raw_mask_path), cv2.IMREAD_GRAYSCALE) + if raw_mask is None: + raise ValueError(f"Failed to read raw mask for correction: {raw_mask_path}") + height, width = image.shape[:2] + if raw_mask.shape[:2] != (height, width): + raise ValueError( + "Raw mask shape does not match image shape: " + f"{raw_mask.shape[:2]} vs {(height, width)}" + ) + + output_dir.mkdir(parents=True, exist_ok=True) + raw_bool = raw_mask > 0 + inverted_bool = ~raw_bool + center_bool, edge_bool = _center_prior_regions(height, width) + normal_score = _center_prior_score(raw_bool, center_bool, edge_bool) + inverted_score = _center_prior_score(inverted_bool, center_bool, edge_bool) + used_inverted = inverted_score["score"] > normal_score["score"] + corrected_bool = inverted_bool if used_inverted else raw_bool + + raw_output = output_dir / "sam3_raw_mask.png" + center_prior_output = output_dir / "center_prior_reference_mask.png" + edge_prior_output = output_dir / "edge_prior_reference_mask.png" + corrected_output = output_dir / "sam3_corrected_mask.png" + cv2.imwrite(str(raw_output), raw_bool.astype("uint8") * 255) + cv2.imwrite(str(center_prior_output), center_bool.astype("uint8") * 255) + cv2.imwrite(str(edge_prior_output), edge_bool.astype("uint8") * 255) + cv2.imwrite(str(corrected_output), corrected_bool.astype("uint8") * 255) + _write_json( + output_dir / "mask_correction_report.json", + { + "image_path": str(image_path), + "raw_mask_path": str(raw_mask_path), + "raw_mask_copy_path": str(raw_output), + "center_prior_reference_mask_path": str(center_prior_output), + "edge_prior_reference_mask_path": str(edge_prior_output), + "corrected_mask_path": str(corrected_output), + "normal_center_prior_score": normal_score, + "inverted_center_prior_score": inverted_score, + "used_inverted_mask": used_inverted, + "raw_mask_area_ratio": float(raw_bool.mean()), + "corrected_mask_area_ratio": float(corrected_bool.mean()), + "foreground_rule": ( + "prefer masks with high center foreground density and low edge " + "foreground density" + ), + }, + ) + return corrected_output + + +def _request_manifest(request: Prompt2GeometryRequest) -> dict[str, Any]: + return { + "prompt": request.prompt, + "output_root": str(request.output_root.expanduser().resolve()), + "target_id": request.target_id, + "request_id": request.request_id, + "output_name": request.output_name, + "zimage_base_url": request.zimage_base_url, + "zimage_width": request.zimage_width, + "zimage_height": request.zimage_height, + "zimage_seed": request.zimage_seed, + "zimage_num_inference_steps": request.zimage_num_inference_steps, + "zimage_prompt_suffix": request.zimage_prompt_suffix, + "sam3_base_url": request.sam3_base_url, + "sam3d_base_url": request.sam3d_base_url, + "sam3d_seed": request.sam3d_seed, + "llm_model": request.llm_model, + "llm_base_url": request.llm_base_url, + "has_llm_api_key": bool(request.llm_api_key), + "llm_timeout_s": request.llm_timeout_s, + "verbose": request.verbose, + } + + +def _zimage_prompt(request: Prompt2GeometryRequest) -> str: + prompt = request.prompt.strip() + suffix = request.zimage_prompt_suffix.strip() + if not suffix: + return prompt + lowered = prompt.lower() + additions = [] + if "single object" not in lowered and "one object" not in lowered: + additions.append("a complete single object") + if "background" not in lowered: + additions.append(_normalize_background_suffix(suffix)) + if not additions: + return prompt + return f"{prompt}, {', '.join(additions)}" + + +def _normalize_background_suffix(suffix: str) -> str: + lowered = suffix.lower() + if "black background" in lowered or "pure-black background" in lowered: + return "with pure-black background" + if "white background" in lowered or "pure-white background" in lowered: + return "with pure-white background" + return suffix + + +def _image_size(image_path: Path) -> tuple[int, int]: + try: + from PIL import Image + except ImportError as exc: + raise ImportError("Pillow is required to read generated image size.") from exc + with Image.open(image_path) as image: + return image.size + + +def _center_prior_regions(height: int, width: int) -> tuple[Any, Any]: + np = _require_numpy() + center_x1 = int(width * 0.2) + center_x2 = int(width * 0.8) + center_y1 = int(height * 0.2) + center_y2 = int(height * 0.8) + center_bool = np.zeros((height, width), dtype=bool) + center_bool[center_y1:center_y2, center_x1:center_x2] = True + + edge_x = max(1, int(width * 0.08)) + edge_y = max(1, int(height * 0.08)) + edge_bool = np.zeros((height, width), dtype=bool) + edge_bool[:edge_y, :] = True + edge_bool[-edge_y:, :] = True + edge_bool[:, :edge_x] = True + edge_bool[:, -edge_x:] = True + return center_bool, edge_bool + + +def _center_prior_score( + mask_bool: Any, + center_bool: Any, + edge_bool: Any, +) -> dict[str, float]: + center_density = _masked_mean(mask_bool, center_bool) + edge_density = _masked_mean(mask_bool, edge_bool) + return { + "score": center_density - edge_density, + "center_foreground_density": center_density, + "edge_foreground_density": edge_density, + } + + +def _masked_mean(mask_bool: Any, region_bool: Any) -> float: + if not region_bool.any(): + return 0.0 + return float(mask_bool[region_bool].mean()) + + +def _write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + + +def _log_status(request: Prompt2GeometryRequest, stage: str, message: str) -> None: + if request.verbose: + print(f"[prompt2geometry:{stage}] {message}", flush=True) + + +def _require_cv2() -> Any: + try: + import cv2 + except ImportError as exc: + raise ImportError("opencv-python is required for mask correction.") from exc + return cv2 + + +def _require_numpy() -> Any: + try: + import numpy as np + except ImportError as exc: + raise ImportError("numpy is required for mask correction.") from exc + return np diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/run.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/run.py new file mode 100644 index 00000000..0890da91 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/run.py @@ -0,0 +1,135 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +if __package__: + from .config import load_prompt2geometry_config + from .pipeline import Prompt2GeometryRequest, run_prompt2geometry +else: + from config import load_prompt2geometry_config + from pipeline import Prompt2GeometryRequest, run_prompt2geometry + +__all__ = ["main"] + + +def main() -> None: + """Run prompt-to-geometry from the command line.""" + parser = argparse.ArgumentParser( + description=( + "Generate one object mesh from a prompt via z-image, segmentation, " + "and 3D-generation." + ) + ) + parser.add_argument( + "--prompt", + required=True, + help=( + "Object description. Complete single-object and pure-black background " + "constraints are appended automatically." + ), + ) + parser.add_argument( + "--output-root", + type=Path, + default=Path("prompt2geometry_output"), + help="Local output directory.", + ) + parser.add_argument("--target-id", default="asset_0") + parser.add_argument("--request-id", default="prompt2geometry_asset_0") + parser.add_argument( + "--output-name", + default=None, + help=( + "Final scaled GLB file name. If omitted, the LLM extracts one " + "from the prompt." + ), + ) + parser.add_argument( + "--config", + type=Path, + default=None, + help="Prompt2Geometry local config JSON path.", + ) + parser.add_argument("--zimage-base-url", default=None) + parser.add_argument("--sam3-base-url", default=None) + parser.add_argument("--sam3d-base-url", default=None) + parser.add_argument( + "--llm-api-key", + default=None, + help="OpenAI-compatible API key for real-world dimension estimation.", + ) + parser.add_argument( + "--llm-model", + default=None, + help="OpenAI-compatible model for real-world dimension estimation.", + ) + parser.add_argument( + "--llm-base-url", + default=None, + help="OpenAI-compatible base URL for real-world dimension estimation.", + ) + parser.add_argument("--llm-timeout-s", type=float, default=None) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--zimage-seed", type=int, default=42) + parser.add_argument("--num-inference-steps", type=int, default=8) + parser.add_argument( + "--zimage-prompt-suffix", + default="a complete single object, with pure-black background", + help="Suffix appended to the object description before z-image generation.", + ) + parser.add_argument("--sam3d-seed", type=int, default=42) + parser.add_argument( + "--quiet", + action="store_true", + help="Disable live progress logs.", + ) + args = parser.parse_args() + cfg = load_prompt2geometry_config(args.config) + + result = run_prompt2geometry( + Prompt2GeometryRequest( + prompt=args.prompt, + output_root=args.output_root, + target_id=args.target_id, + request_id=args.request_id, + output_name=args.output_name, + zimage_base_url=args.zimage_base_url or cfg.zimage_base_url, + zimage_width=args.width, + zimage_height=args.height, + zimage_seed=args.zimage_seed, + zimage_num_inference_steps=args.num_inference_steps, + zimage_prompt_suffix=args.zimage_prompt_suffix, + sam3_base_url=args.sam3_base_url or cfg.sam3_base_url, + sam3d_base_url=args.sam3d_base_url or cfg.sam3d_base_url, + sam3d_seed=args.sam3d_seed, + llm_api_key=args.llm_api_key or cfg.llm_api_key, + llm_model=args.llm_model or cfg.llm_model, + llm_base_url=args.llm_base_url or cfg.llm_base_url, + llm_timeout_s=args.llm_timeout_s or cfg.llm_timeout_s, + verbose=not args.quiet, + ) + ) + print(json.dumps(result, indent=2, ensure_ascii=False)) + + +if __name__ == "__main__": + main() diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/sam3_client.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/sam3_client.py new file mode 100644 index 00000000..7bc60abe --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/sam3_client.py @@ -0,0 +1,266 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import mimetypes +import time +import uuid +from pathlib import Path +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.request import ProxyHandler, Request, build_opener + +try: + from .schemas import SelectedBox +except ImportError: + from schemas import SelectedBox + +__all__ = ["SAM3Client", "SAM3ClientError"] + + +class SAM3ClientError(RuntimeError): + """Raised when the SAM3 segmentation service fails.""" + + +class SAM3Client: + """Self-contained HTTP client for SAM3 box segmentation.""" + + def __init__( + self, + *, + base_url: str, + boxes_path: str = "/segment_boxes", + health_path: str = "/health", + timeout_s: float = 120.0, + poll_interval_s: float = 2.0, + ): + self.base_url = base_url.rstrip("/") + self.boxes_path = boxes_path + self.health_path = health_path + self.timeout_s = timeout_s + self.poll_interval_s = poll_interval_s + self._opener = build_opener(ProxyHandler({})) + + def health(self) -> dict[str, Any]: + """Check SAM3 service health.""" + request = Request( + self._url(self.health_path), + headers={"Accept": "application/json"}, + method="GET", + ) + return self._open_json_request(request) + + def segment_boxes_image( + self, + image_path: Path, + *, + selected_boxes: list[SelectedBox], + request_id: str | None = None, + save_visualizations: bool = False, + progress_path: Path | None = None, + verbose: bool = False, + ) -> dict[str, Any]: + """Segment an image using box prompts.""" + payload: dict[str, object] = { + "mode": "box", + "async": True, + "selected_boxes": [box.to_manifest() for box in selected_boxes], + "save_visualizations": save_visualizations, + } + if request_id is not None: + payload["request_id"] = request_id + result = self._post_multipart_json( + self.boxes_path, + payload=payload, + image_path=image_path, + ) + result = self._resolve_async_result( + result, + progress_path=progress_path, + verbose=verbose, + ) + _validate_segmentation_result(result) + return result + + def _post_multipart_json( + self, + path: str, + *, + payload: dict[str, object], + image_path: Path, + ) -> dict[str, Any]: + body, content_type = _build_multipart_body( + payload=payload, + image_path=image_path, + ) + request = Request( + self._url(path), + data=body, + headers={ + "Accept": "application/json", + "Content-Type": content_type, + }, + method="POST", + ) + return self._open_json_request(request) + + def _open_json_request(self, request: Request) -> dict[str, Any]: + try: + with self._opener.open(request, timeout=self.timeout_s) as response: + response_body = response.read().decode("utf-8") + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise SAM3ClientError( + f"SAM3 request to {request.full_url} failed with " + f"HTTP {exc.code}: {detail}" + ) from exc + except URLError as exc: + raise SAM3ClientError( + f"SAM3 server is unreachable at {request.full_url}: {exc.reason}" + ) from exc + except TimeoutError as exc: + raise SAM3ClientError( + f"SAM3 request to {request.full_url} timed out after " + f"{self.timeout_s}s." + ) from exc + + try: + decoded = json.loads(response_body) + except json.JSONDecodeError as exc: + raise SAM3ClientError( + f"SAM3 server returned non-JSON: {response_body}" + ) from exc + if not isinstance(decoded, dict): + raise SAM3ClientError("SAM3 response must be a JSON object.") + return decoded + + def _resolve_async_result( + self, + result: dict[str, Any], + *, + progress_path: Path | None, + verbose: bool, + ) -> dict[str, Any]: + status = str(result.get("status") or "").lower() + status_url = result.get("status_url") + if status not in {"queued", "running"} or not isinstance(status_url, str): + _append_progress(progress_path, result) + _print_progress("segmentation", result, verbose=verbose) + return result + + _append_progress(progress_path, result) + _print_progress("segmentation", result, verbose=verbose) + while True: + time.sleep(self.poll_interval_s) + job = self._get_json(status_url) + _append_progress(progress_path, job) + _print_progress("segmentation", job, verbose=verbose) + job_status = str(job.get("status") or "").lower() + if job_status in {"queued", "running"}: + continue + if job_status == "succeeded": + final_result = job.get("result") + if not isinstance(final_result, dict): + raise SAM3ClientError("SAM3 async job succeeded without result.") + return final_result + if job_status == "failed": + raise SAM3ClientError(f"SAM3 async job failed: {job}") + raise SAM3ClientError(f"SAM3 async job returned unknown status: {job}") + + def _get_json(self, path: str) -> dict[str, Any]: + request = Request( + self._url(path), + headers={"Accept": "application/json"}, + method="GET", + ) + return self._open_json_request(request) + + def _url(self, path: str) -> str: + if path.startswith("http://") or path.startswith("https://"): + return path + normalized_path = path if path.startswith("/") else f"/{path}" + return f"{self.base_url}{normalized_path}" + + +def _build_multipart_body( + *, + payload: dict[str, object], + image_path: Path, +) -> tuple[bytes, str]: + image_path = image_path.expanduser().resolve() + if not image_path.is_file(): + raise FileNotFoundError(f"Image upload path is not a file: {image_path}") + + boundary = f"----prompt2geometry-sam3-{uuid.uuid4().hex}" + content_type = mimetypes.guess_type(image_path.name)[0] or "image/png" + chunks = [ + f"--{boundary}\r\n".encode("utf-8"), + b'Content-Disposition: form-data; name="payload"\r\n', + b"Content-Type: application/json\r\n\r\n", + json.dumps(payload).encode("utf-8"), + b"\r\n", + f"--{boundary}\r\n".encode("utf-8"), + ( + 'Content-Disposition: form-data; name="image"; ' + f'filename="{image_path.name}"\r\n' + ).encode("utf-8"), + f"Content-Type: {content_type}\r\n\r\n".encode("utf-8"), + image_path.read_bytes(), + b"\r\n", + f"--{boundary}--\r\n".encode("utf-8"), + ] + return b"".join(chunks), f"multipart/form-data; boundary={boundary}" + + +def _append_progress(progress_path: Path | None, payload: dict[str, Any]) -> None: + if progress_path is None: + return + progress_path = progress_path.expanduser().resolve() + progress_path.parent.mkdir(parents=True, exist_ok=True) + with progress_path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False) + "\n") + + +def _print_progress(stage: str, payload: dict[str, Any], *, verbose: bool) -> None: + if not verbose: + return + status = payload.get("status") or payload.get("ok") or "unknown" + job_id = payload.get("job_id") or payload.get("request_id") or payload.get("id") + progress = payload.get("progress") + parts = [f"[{stage}] status={status}"] + if job_id is not None: + parts.append(f"job={job_id}") + if progress is not None: + parts.append(f"progress={progress}") + print(" ".join(parts), flush=True) + + +def _validate_segmentation_result(result: dict[str, Any]) -> None: + if result.get("ok") is not True: + raise SAM3ClientError(f"SAM3 segmentation failed: {result}") + segmentations = result.get("segmentations") + if not isinstance(segmentations, list): + raise SAM3ClientError("SAM3 response missing segmentations list.") + for index, segmentation in enumerate(segmentations): + if not isinstance(segmentation, dict): + raise SAM3ClientError(f"SAM3 segmentation {index} must be an object.") + target_id = segmentation.get("target_id") + if not isinstance(target_id, str) or not target_id.strip(): + raise SAM3ClientError( + f"SAM3 segmentation {index} must contain target_id." + ) diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/sam3d_client.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/sam3d_client.py new file mode 100644 index 00000000..d8e4d8f8 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/sam3d_client.py @@ -0,0 +1,324 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import mimetypes +import time +import uuid +from pathlib import Path +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.request import ProxyHandler, Request, build_opener + +__all__ = ["SAM3DClient", "SAM3DClientError"] + + +class SAM3DClientError(RuntimeError): + """Raised when the SAM3D service fails.""" + + +class SAM3DClient: + """Self-contained HTTP client for SAM3D image/mask-to-GLB generation.""" + + def __init__( + self, + *, + base_url: str, + generation_path: str = "/generate", + health_path: str = "/health", + timeout_s: float = 1800.0, + poll_interval_s: float = 5.0, + ): + self.base_url = base_url.rstrip("/") + self.generation_path = generation_path + self.health_path = health_path + self.timeout_s = timeout_s + self.poll_interval_s = poll_interval_s + self._opener = build_opener(ProxyHandler({})) + + def health(self) -> dict[str, Any]: + """Check SAM3D service health.""" + request = Request( + self._url(self.health_path), + headers={"Accept": "application/json"}, + method="GET", + ) + return self._open_json_request(request) + + def generate_asset( + self, + *, + image_path: Path, + mask_path: Path, + request_id: str, + output_name: str, + prompt: str, + seed: int, + output_path: Path, + metadata_path: Path, + progress_path: Path | None = None, + verbose: bool = False, + ) -> dict[str, Any]: + """Generate one 3D asset and download the returned GLB and metadata.""" + payload: dict[str, object] = { + "response_format": "json", + "async": True, + "request_id": request_id, + "output_name": output_name, + "prompt": prompt, + "seed": seed, + } + result = self._post_multipart_json( + self.generation_path, + payload=payload, + image_path=image_path, + mask_path=mask_path, + ) + result = self._resolve_async_result( + result, + progress_path=progress_path, + verbose=verbose, + ) + _validate_generation_result(result) + self._download_required(result, "glb_url", output_path, "model/gltf-binary") + result["local_glb_path"] = str(output_path.expanduser().resolve()) + self._download_required( + result, + "transform_metadata_url", + metadata_path, + "application/json", + ) + result["local_transform_metadata_path"] = str( + metadata_path.expanduser().resolve() + ) + return result + + def _post_multipart_json( + self, + path: str, + *, + payload: dict[str, object], + image_path: Path, + mask_path: Path, + ) -> dict[str, Any]: + body, content_type = _build_multipart_body( + payload=payload, + image_path=image_path, + mask_path=mask_path, + ) + request = Request( + self._url(path), + data=body, + headers={ + "Accept": "application/json", + "Content-Type": content_type, + }, + method="POST", + ) + return self._open_json_request(request) + + def _download_required( + self, + manifest: dict[str, Any], + key: str, + output_path: Path, + accept: str, + ) -> None: + url_path = manifest.get(key) + if not isinstance(url_path, str) or not url_path.strip(): + raise SAM3DClientError(f"SAM3D manifest missing {key}.") + url = url_path if url_path.startswith("http") else self._url(url_path) + output_path = output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + request = Request(url, headers={"Accept": accept}, method="GET") + try: + with self._opener.open(request, timeout=self.timeout_s) as response: + output_path.write_bytes(response.read()) + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise SAM3DClientError( + f"SAM3D download from {url} failed with HTTP {exc.code}: {detail}" + ) from exc + except URLError as exc: + raise SAM3DClientError( + f"SAM3D server is unreachable at {url}: {exc.reason}" + ) from exc + + def _open_json_request(self, request: Request) -> dict[str, Any]: + try: + with self._opener.open(request, timeout=self.timeout_s) as response: + response_body = response.read().decode("utf-8") + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise SAM3DClientError( + f"SAM3D request to {request.full_url} failed with " + f"HTTP {exc.code}: {detail}" + ) from exc + except URLError as exc: + raise SAM3DClientError( + f"SAM3D server is unreachable at {request.full_url}: {exc.reason}" + ) from exc + except TimeoutError as exc: + raise SAM3DClientError( + f"SAM3D request to {request.full_url} timed out after " + f"{self.timeout_s}s." + ) from exc + + try: + decoded = json.loads(response_body) + except json.JSONDecodeError as exc: + raise SAM3DClientError( + f"SAM3D server returned non-JSON: {response_body}" + ) from exc + if not isinstance(decoded, dict): + raise SAM3DClientError("SAM3D response must be a JSON object.") + return decoded + + def _resolve_async_result( + self, + result: dict[str, Any], + *, + progress_path: Path | None, + verbose: bool, + ) -> dict[str, Any]: + status = str(result.get("status") or "").lower() + status_url = result.get("status_url") + if status not in {"queued", "running"} or not isinstance(status_url, str): + _append_progress(progress_path, result) + _print_progress("3D-generation", result, verbose=verbose) + return result + + _append_progress(progress_path, result) + _print_progress("3D-generation", result, verbose=verbose) + while True: + time.sleep(self.poll_interval_s) + job = self._get_json(status_url) + _append_progress(progress_path, job) + _print_progress("3D-generation", job, verbose=verbose) + job_status = str(job.get("status") or "").lower() + if job_status in {"queued", "running"}: + continue + if job_status == "succeeded": + final_result = job.get("result") + if not isinstance(final_result, dict): + raise SAM3DClientError("SAM3D async job succeeded without result.") + return final_result + if job_status == "failed": + raise SAM3DClientError(f"SAM3D async job failed: {job}") + raise SAM3DClientError(f"SAM3D async job returned unknown status: {job}") + + def _get_json(self, path: str) -> dict[str, Any]: + request = Request( + self._url(path), + headers={"Accept": "application/json"}, + method="GET", + ) + return self._open_json_request(request) + + def _url(self, path: str) -> str: + if path.startswith("http://") or path.startswith("https://"): + return path + normalized_path = path if path.startswith("/") else f"/{path}" + return f"{self.base_url}{normalized_path}" + + +def _build_multipart_body( + *, + payload: dict[str, object], + image_path: Path, + mask_path: Path, +) -> tuple[bytes, str]: + image_path = image_path.expanduser().resolve() + mask_path = mask_path.expanduser().resolve() + if not image_path.is_file(): + raise FileNotFoundError(f"Image upload path is not a file: {image_path}") + if not mask_path.is_file(): + raise FileNotFoundError(f"Mask upload path is not a file: {mask_path}") + + boundary = f"----prompt2geometry-sam3d-{uuid.uuid4().hex}" + chunks = [ + _multipart_text(boundary, "payload", json.dumps(payload), "application/json"), + _multipart_file(boundary, "image", image_path), + _multipart_file(boundary, "mask", mask_path), + f"--{boundary}--\r\n".encode("utf-8"), + ] + return b"".join(chunks), f"multipart/form-data; boundary={boundary}" + + +def _append_progress(progress_path: Path | None, payload: dict[str, Any]) -> None: + if progress_path is None: + return + progress_path = progress_path.expanduser().resolve() + progress_path.parent.mkdir(parents=True, exist_ok=True) + with progress_path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False) + "\n") + + +def _print_progress(stage: str, payload: dict[str, Any], *, verbose: bool) -> None: + if not verbose: + return + status = payload.get("status") or payload.get("ok") or "unknown" + job_id = payload.get("job_id") or payload.get("request_id") or payload.get("id") + progress = payload.get("progress") + parts = [f"[{stage}] status={status}"] + if job_id is not None: + parts.append(f"job={job_id}") + if progress is not None: + parts.append(f"progress={progress}") + print(" ".join(parts), flush=True) + + +def _multipart_text( + boundary: str, + field_name: str, + value: str, + content_type: str, +) -> bytes: + return b"".join( + [ + f"--{boundary}\r\n".encode("utf-8"), + f'Content-Disposition: form-data; name="{field_name}"\r\n'.encode("utf-8"), + f"Content-Type: {content_type}\r\n\r\n".encode("utf-8"), + value.encode("utf-8"), + b"\r\n", + ] + ) + + +def _multipart_file(boundary: str, field_name: str, path: Path) -> bytes: + content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" + return b"".join( + [ + f"--{boundary}\r\n".encode("utf-8"), + ( + f'Content-Disposition: form-data; name="{field_name}"; ' + f'filename="{path.name}"\r\n' + ).encode("utf-8"), + f"Content-Type: {content_type}\r\n\r\n".encode("utf-8"), + path.read_bytes(), + b"\r\n", + ] + ) + + +def _validate_generation_result(result: dict[str, Any]) -> None: + if result.get("ok") is not True: + raise SAM3DClientError(f"SAM3D generation failed: {result}") + glb_url = result.get("glb_url") + if not isinstance(glb_url, str) or not glb_url.strip(): + raise SAM3DClientError("SAM3D response missing glb_url.") diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/schemas.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/schemas.py new file mode 100644 index 00000000..b86ee284 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/schemas.py @@ -0,0 +1,46 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field + +__all__ = ["SelectedBox"] + + +@dataclass(frozen=True) +class SelectedBox: + """One box prompt passed to the SAM3 segmentation service.""" + + target_id: str + target_kind: str + phrase: str + bbox_xyxy: list[float] + source_candidate_ids: list[str] = field(default_factory=list) + selection_reason: str | None = None + + def to_manifest(self) -> dict[str, object]: + """Convert the selected box to JSON-safe data.""" + manifest: dict[str, object] = { + "target_id": self.target_id, + "target_kind": self.target_kind, + "phrase": self.phrase, + "bbox_xyxy": self.bbox_xyxy, + "source_candidate_ids": self.source_candidate_ids, + } + if self.selection_reason is not None: + manifest["selection_reason"] = self.selection_reason + return manifest diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/segmentation_outputs.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/segmentation_outputs.py new file mode 100644 index 00000000..bad645ec --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/segmentation_outputs.py @@ -0,0 +1,245 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +__all__ = ["save_segmentation_outputs"] + + +def save_segmentation_outputs( + *, + image_path: Path, + segmentation_result: dict[str, Any], + output_dir: Path, +) -> dict[str, Any]: + """Save local mask and transparent crop images from mask RLE output.""" + cv2 = _require_cv2() + image_path = image_path.expanduser().resolve() + output_dir = output_dir.expanduser().resolve() + if not image_path.is_file(): + raise FileNotFoundError(f"Segmentation source image not found: {image_path}") + + image = cv2.imread(str(image_path)) + if image is None: + raise ValueError(f"Failed to read segmentation source image: {image_path}") + + segmentations = segmentation_result.get("segmentations", []) + if not isinstance(segmentations, list): + raise ValueError("Segmentation result key segmentations must be a list.") + + output_dir.mkdir(parents=True, exist_ok=True) + local_segmentations = [] + height, width = image.shape[:2] + used_stems: set[str] = set() + + for index, segmentation in enumerate(segmentations): + if not isinstance(segmentation, dict): + continue + mask_rle = segmentation.get("mask_rle") or segmentation.get("segmentation") + if not isinstance(mask_rle, dict): + continue + mask_bool = _decode_mask_rle(mask_rle).astype(bool) + if mask_bool.shape[:2] != (height, width): + raise ValueError( + "Decoded mask shape does not match source image: " + f"{mask_bool.shape[:2]} vs {(height, width)}" + ) + + bbox = _bbox_from_segmentation(segmentation, mask_bool) + target_id = str(segmentation.get("target_id") or f"segment_{index}") + phrase = str(segmentation.get("phrase") or target_id) + file_stem = _unique_file_stem(_safe_name(target_id), used_stems) + mask_path = output_dir / f"{file_stem}_mask.png" + crop_path = output_dir / f"{file_stem}_crop.png" + + _save_mask(mask_bool, mask_path) + _save_transparent_crop(image, mask_bool, bbox, crop_path) + local_segmentations.append( + { + "target_id": target_id, + "target_kind": segmentation.get("target_kind"), + "phrase": phrase, + "bbox_xyxy": [float(value) for value in bbox], + "local_mask_path": str(mask_path), + "local_crop_path": str(crop_path), + } + ) + + manifest = { + "output_dir": str(output_dir), + "source_image_path": str(image_path), + "segmentations": local_segmentations, + "num_segmentations": len(local_segmentations), + } + manifest_path = output_dir / "segmentation_outputs.json" + manifest["manifest_path"] = str(manifest_path) + manifest_path.write_text( + json.dumps(manifest, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + return manifest + + +def _decode_mask_rle(rle: dict[str, Any]) -> Any: + try: + from pycocotools import mask as mask_util + except ImportError: + mask_util = None + + if mask_util is not None: + try: + return mask_util.decode(rle) + except Exception: + pass + + np = _require_numpy() + size = rle.get("size") + if not isinstance(size, list) or len(size) != 2: + raise ValueError("Mask RLE must contain size [height, width].") + height, width = int(size[0]), int(size[1]) + counts = rle.get("counts") + if isinstance(counts, str): + runs = _decode_compressed_coco_rle_counts(counts) + elif isinstance(counts, list): + runs = [int(value) for value in counts] + else: + raise ValueError("Mask RLE counts must be a string or list.") + + flat = np.zeros(height * width, dtype=np.uint8) + offset = 0 + value = 0 + for run_length in runs: + next_offset = min(offset + int(run_length), flat.size) + if value == 1: + flat[offset:next_offset] = 1 + offset = next_offset + value = 1 - value + return flat.reshape((height, width), order="F") + + +def _decode_compressed_coco_rle_counts(counts: str) -> list[int]: + runs = [] + index = 0 + while index < len(counts): + value = 0 + shift = 0 + more = True + while more: + char_value = ord(counts[index]) - 48 + index += 1 + value |= (char_value & 0x1F) << shift + more = bool(char_value & 0x20) + shift += 5 + if not more and (char_value & 0x10): + value |= -1 << shift + if len(runs) > 2: + value += runs[-2] + runs.append(value) + return runs + + +def _bbox_from_segmentation( + segmentation: dict[str, Any], + mask_bool: Any, +) -> tuple[int, int, int, int]: + bbox = segmentation.get("bbox_xyxy") + if isinstance(bbox, list) and len(bbox) == 4: + return tuple(int(round(float(value))) for value in bbox) + + np = _require_numpy() + ys, xs = np.where(mask_bool) + if len(xs) == 0 or len(ys) == 0: + return 0, 0, 0, 0 + return int(xs.min()), int(ys.min()), int(xs.max() + 1), int(ys.max() + 1) + + +def _save_mask(mask_bool: Any, output_path: Path) -> None: + cv2 = _require_cv2() + cv2.imwrite(str(output_path), mask_bool.astype("uint8") * 255) + + +def _save_transparent_crop( + image: Any, + mask_bool: Any, + bbox: tuple[int, int, int, int], + output_path: Path, +) -> None: + cv2 = _require_cv2() + np = _require_numpy() + x1, y1, x2, y2 = _clip_bbox(bbox, image=image) + if x2 <= x1 or y2 <= y1: + return + crop_bgr = image[y1:y2, x1:x2] + crop_mask = mask_bool[y1:y2, x1:x2].astype("uint8") * 255 + crop_bgra = np.dstack([crop_bgr, crop_mask]) + cv2.imwrite(str(output_path), crop_bgra) + + +def _clip_bbox( + bbox: tuple[int, int, int, int], + *, + image: Any, +) -> tuple[int, int, int, int]: + height, width = image.shape[:2] + x1, y1, x2, y2 = bbox + return ( + max(0, min(width, x1)), + max(0, min(height, y1)), + max(0, min(width, x2)), + max(0, min(height, y2)), + ) + + +def _safe_name(value: str) -> str: + safe = "".join( + char if char.isalnum() or char in {"-", "_"} else "_" + for char in value.strip().lower() + ) + return safe or "object" + + +def _unique_file_stem(stem: str, used_stems: set[str]) -> str: + if stem not in used_stems: + used_stems.add(stem) + return stem + suffix = 1 + while f"{stem}_{suffix}" in used_stems: + suffix += 1 + unique_stem = f"{stem}_{suffix}" + used_stems.add(unique_stem) + return unique_stem + + +def _require_cv2() -> Any: + try: + import cv2 + except ImportError as exc: + raise ImportError( + "opencv-python is required to save segmentation outputs." + ) from exc + return cv2 + + +def _require_numpy() -> Any: + try: + import numpy as np + except ImportError as exc: + raise ImportError("numpy is required to save segmentation outputs.") from exc + return np diff --git a/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/zimage_client.py b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/zimage_client.py new file mode 100644 index 00000000..e9d7b287 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/gym_project_api/prompt2geometry/zimage_client.py @@ -0,0 +1,115 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.request import ProxyHandler, Request, build_opener + +__all__ = ["ZImageClient", "ZImageClientError"] + + +class ZImageClientError(RuntimeError): + """Raised when the z-image service request fails.""" + + +class ZImageClient: + """HTTP client for the deployed z-image PNG generation service.""" + + def __init__( + self, + *, + base_url: str = "http://192.168.3.23:5013", + generation_path: str = "/generate.png", + timeout_s: float = 300.0, + ): + """Initialize the z-image client.""" + self.base_url = base_url.rstrip("/") + self.generation_path = generation_path + self.timeout_s = timeout_s + self._opener = build_opener(ProxyHandler({})) + + def generate_png( + self, + *, + prompt: str, + output_path: Path, + width: int = 1024, + height: int = 1024, + seed: int = 42, + num_inference_steps: int = 8, + ) -> dict[str, Any]: + """Generate a PNG image and write it to ``output_path``.""" + payload = { + "prompt": prompt, + "width": width, + "height": height, + "seed": seed, + "num_inference_steps": num_inference_steps, + } + body = json.dumps(payload).encode("utf-8") + request = Request( + self._url(self.generation_path), + data=body, + headers={ + "Accept": "image/png", + "Content-Type": "application/json", + }, + method="POST", + ) + output_path = output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + try: + with self._opener.open(request, timeout=self.timeout_s) as response: + content = response.read() + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise ZImageClientError( + f"z-image request to {request.full_url} failed with " + f"HTTP {exc.code}: {detail}" + ) from exc + except URLError as exc: + raise ZImageClientError( + f"z-image server is unreachable at {request.full_url}: {exc.reason}" + ) from exc + except TimeoutError as exc: + raise ZImageClientError( + f"z-image request to {request.full_url} timed out after " + f"{self.timeout_s}s." + ) from exc + + if not content: + raise ZImageClientError("z-image server returned an empty image response.") + output_path.write_bytes(content) + return { + "provider": "z-image", + "base_url": self.base_url, + "generation_path": self.generation_path, + "prompt": prompt, + "width": width, + "height": height, + "seed": seed, + "num_inference_steps": num_inference_steps, + "output_path": str(output_path), + "num_bytes": len(content), + } + + def _url(self, path: str) -> str: + normalized_path = path if path.startswith("/") else f"/{path}" + return f"{self.base_url}{normalized_path}" diff --git a/embodichain/gen_sim/action_agent_pipeline/prompts/__init__.py b/embodichain/gen_sim/action_agent_pipeline/prompts/__init__.py new file mode 100644 index 00000000..88168e41 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/prompts/__init__.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from .task_prompt import TaskPrompt + +__all__ = ["TaskPrompt"] diff --git a/embodichain/gen_sim/action_agent_pipeline/prompts/atom_actions.txt b/embodichain/gen_sim/action_agent_pipeline/prompts/atom_actions.txt new file mode 100644 index 00000000..596fc455 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/prompts/atom_actions.txt @@ -0,0 +1,57 @@ +### Atomic Action Class JSON Specs for Robot Arm Control + +Each non-null graph edge action must be a JSON object with these common fields +and exactly one target field: + +{ + "atomic_action_class": "PickUpAction|MoveAction|PlaceAction", + "robot_name": "left_arm|right_arm", + "control": "arm|hand", + "cfg": {} +} + +Use only these atomic action classes: + +1. `PickUpAction` + - Required target_object: + {"obj_name": "", "affordance": "antipodal"} + - Typical cfg: + {"pre_grasp_distance": 0.08, "sample_interval": 45} + +2. `MoveAction` + - Use `control: "arm"` with target_pose or arm target_qpos. + - Use `control: "hand"` with gripper target_qpos. + - Supported target_pose objects: + {"reference": "object", "obj_name": "", "offset": [x, y, z]} + {"reference": "absolute", "position": [x, y, z]} + {"reference": "relative", "offset": [dx, dy, dz], "frame": "world|eef"} + - Supported target_qpos objects: + {"source": "initial"} + {"source": "gripper_state", "state": "open|close"} + {"source": "joint_delta", "joint_index": 5, "delta_degrees": -90} + - Typical cfg: + {"sample_interval": 30} + +3. `PlaceAction` + - Prefer this for placement because one action lowers, opens the gripper, + and retreats upward. + - Required target_pose. Supported pose targets are the same target_pose objects + accepted by `MoveAction`. + - Typical cfg: + {"sample_interval": 80, "lift_height": 0.1} + +Rules: +- Do not output Python code, function calls, or `fn`/`kwargs` action objects. +- Do not output legacy `action`-based specs. +- Use `null` for an idle arm. +- Keep all values JSON primitives. +- Each non-null action must contain exactly one of `target_object`, `target_pose`, + or `target_qpos`. +- To keep a holding arm closed while the other arm moves, use: + { + "atomic_action_class": "MoveAction", + "robot_name": "", + "control": "hand", + "cfg": {"sample_interval": 10}, + "target_qpos": {"source": "gripper_state", "state": "close"} + } diff --git a/embodichain/gen_sim/action_agent_pipeline/prompts/basic_background.txt b/embodichain/gen_sim/action_agent_pipeline/prompts/basic_background.txt new file mode 100644 index 00000000..65088dd0 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/prompts/basic_background.txt @@ -0,0 +1,20 @@ +The environment uses a right-handed world coordinate system, where 1 unit equals 1 meter. +All robot poses are represented as 4×4 homogeneous transformation matrices. + +The robot base coordinate frame is the ONLY authoritative frame for all spatial reasoning, planning, and action generation. + +ROBOT BASE COORDINATE DEFINITIONS + +All directions below are defined strictly in the robot base frame: + +* Moving forward decreases x +* Moving backward increases x +* Moving left decreases y +* Moving right increases y +* Moving up increases z +* Moving down decreases z + +ROBOT INITIALIZATION AND TERMINATION + +Both robot arms start in predefined initial configurations with their end-effectors open. +At task completion, both arms must be returned to their initial poses. diff --git a/embodichain/gen_sim/action_agent_pipeline/prompts/task_prompt.py b/embodichain/gen_sim/action_agent_pipeline/prompts/task_prompt.py new file mode 100644 index 00000000..8f512216 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/prompts/task_prompt.py @@ -0,0 +1,122 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +import torch +from langchain_core.messages import SystemMessage +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, +) +from embodichain.utils.utility import encode_image + +__all__ = ["TaskPrompt"] + + +class TaskPrompt: + @staticmethod + def generate_task_graph(observations: dict[str, Any], **kwargs: Any) -> Any: + """Build a prompt that asks the task agent for a nominal JSON graph.""" + schema = """{ + "task": "", + "start": "v0_start", + "goal": "vN_done", + "nodes": [ + {"id": "v0_start", "semantic": ""}, + {"id": "v1_", "semantic": ""} + ], + "edges": [ + { + "id": "e01_", + "source": "v0_start", + "target": "v1_", + "left_arm_action": { + "atomic_action_class": "PickUpAction|MoveAction|PlaceAction", + "robot_name": "left_arm|right_arm", + "control": "arm|hand", + "target_object": {"obj_name": "", "affordance": "antipodal"}, + "cfg": {} + }, + "right_arm_action": null + } + ] +}""" + + observation = ( + observations["rgb"].cpu().numpy() + if isinstance(observations["rgb"], torch.Tensor) + else observations["rgb"] + ) + kwargs.update( + { + "graph_schema": schema, + "observation": encode_image(observation), + } + ) + + prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage( + content=( + "You are a precise robotic manipulation graph planner. " + "Given a camera observation and task description, produce only " + "the nominal atomic-action graph. Do not add failure monitors, " + "error injection, recovery branches, Python code, or prose. " + "All actions must strictly use the provided atomic action class JSON specs." + ) + ), + HumanMessagePromptTemplate.from_template( + [ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,{observation}", + }, + }, + { + "type": "text", + "text": ( + "Use the current camera observation and context below to " + "generate a nominal atomic-action graph for the task.\n\n" + "**Environment background:**\n{basic_background}\n\n" + '**Task goal:**\n"{task_prompt}"\n\n' + "**Available atomic actions:**\n{atom_actions}\n\n" + "**Required JSON schema:**\n" + "{graph_schema}\n\n" + "Rules:\n" + "- Output exactly one JSON object and nothing else.\n" + "- The nominal graph must be one deterministic start-to-goal chain with no branches, cycles, or orphan edges.\n" + "- Each edge is one semantic task step from source node to target node.\n" + "- Every edge must define at least one non-null arm action.\n" + "- Use `null` for an idle arm action.\n" + "- Each non-null arm action must use the atomic action class JSON spec with `atomic_action_class`, `robot_name`, `control`, `cfg`, and exactly one of `target_object`, `target_pose`, or `target_qpos`.\n" + "- Do not output legacy function calls, `action`-based specs, or `fn`/`kwargs` action objects.\n" + "- Put only JSON primitives inside action specs: strings, numbers, booleans, null, arrays, or objects.\n" + "- Do not include `env`, tensors, comments, validation conditions, monitors, errors, or recovery fields.\n" + "- Preserve task order and use both arms on the same edge when they should act simultaneously.\n" + "- Use stable ids such as `v0_start`, `v1_grasped`, `e01_grasp_objects`.\n" + "- Replace `N` with the concrete final step index; do not literally output `vN_done`.\n" + "- The final edge target must equal the `goal` field." + ), + }, + ] + ), + ] + ) + return prompt.invoke(kwargs) diff --git a/embodichain/gen_sim/action_agent_pipeline/runtime/__init__.py b/embodichain/gen_sim/action_agent_pipeline/runtime/__init__.py new file mode 100644 index 00000000..a6bb7005 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/runtime/__init__.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +"""Runtime graph compilation and atomic-action execution.""" + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/action_agent_pipeline/runtime/atom_action_utils.py b/embodichain/gen_sim/action_agent_pipeline/runtime/atom_action_utils.py new file mode 100644 index 00000000..806ad2f0 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/runtime/atom_action_utils.py @@ -0,0 +1,102 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.utils.logger import log_error + + +def _available_arm_sides(env) -> list[str]: + sides = [] + for side in ("left", "right"): + if len(getattr(env, f"{side}_arm_joints", []) or []) > 0: + sides.append(side) + return sides + + +def resolve_arm_side(env, robot_name: str) -> str: + """Resolve robot_name to an available left/right graph slot.""" + name = robot_name or "" + if "right" in name: + side = "right" + elif "left" in name: + side = "left" + else: + sides = _available_arm_sides(env) + side = "right" if sides == ["right"] else "left" + + if side not in _available_arm_sides(env): + log_error( + f"Requested {side}_arm for robot_name='{robot_name}', but available " + f"control parts are {getattr(env.robot, 'control_parts', None)}.", + error_type=ValueError, + ) + return side + + +def get_arm_states(env, robot_name): + """Get the current state of the specified robot arm. + + Args: + env: The simulation environment. + robot_name: Name of the robot arm (should contain "left" or "right"). + + Returns: + Tuple of (is_left, select_arm, current_qpos, current_pose, current_gripper_state): + - is_left: bool, whether this is the left arm + - select_arm: str, arm identifier ("left_arm" or "right_arm") + - current_qpos: Current joint positions + - current_pose: Current end-effector pose (4x4 matrix) + - current_gripper_state: Current gripper state + """ + left_arm_current_qpos, right_arm_current_qpos = env.get_current_qpos_agent() + left_arm_current_pose, right_arm_current_pose = env.get_current_xpos_agent() + left_arm_current_gripper_state, right_arm_current_gripper_state = ( + env.get_current_gripper_state_agent() + ) + + side = resolve_arm_side(env, robot_name) + is_left = True if side == "left" else False + if hasattr(env, "get_agent_arm_control_part"): + select_arm = env.get_agent_arm_control_part(is_left) + else: + select_arm = "left_arm" if is_left else "right_arm" + + arms = { + "left": ( + left_arm_current_qpos, + left_arm_current_pose, + left_arm_current_gripper_state, + ), + "right": ( + right_arm_current_qpos, + right_arm_current_pose, + right_arm_current_gripper_state, + ), + } + ( + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) = arms[side] + + return ( + is_left, + select_arm, + select_arm_current_qpos, + select_arm_current_pose, + select_arm_current_gripper_state, + ) diff --git a/embodichain/gen_sim/action_agent_pipeline/runtime/atom_actions.py b/embodichain/gen_sim/action_agent_pipeline/runtime/atom_actions.py new file mode 100644 index 00000000..ec6733a5 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/runtime/atom_actions.py @@ -0,0 +1,1170 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import hashlib +import os +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Mapping + +import numpy as np +import torch +from tqdm import tqdm + +from embodichain.gen_sim.action_agent_pipeline.runtime.atom_action_utils import ( + get_arm_states, + resolve_arm_side, +) +from embodichain.lab.sim.atomic_actions import ( + AntipodalAffordance, + MoveAction, + MoveActionCfg, + ObjectSemantics, + PickUpAction, + PickUpActionCfg, + PlaceAction, + PlaceActionCfg, +) +from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg, ToppraPlannerCfg +from embodichain.toolkits.graspkit.pg_grasp import ( + AntipodalSamplerCfg, + GraspGeneratorCfg, + GripperCollisionCfg, +) +from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import ( + GRASP_ANNOTATOR_CACHE_DIR, +) +from embodichain.utils.logger import log_info +from embodichain.utils.math import get_offset_pose + +__all__ = [ + "AtomicActionSpec", + "execute_atomic_action", + "execute_parallel_atomic_actions", + "normalize_atomic_action_spec", +] + + +SUPPORTED_ATOMIC_ACTION_CLASSES = {"PickUpAction", "MoveAction", "PlaceAction"} +SUPPORTED_CONTROLS = {"arm", "hand"} +TARGET_SPEC_FIELDS = ("target_object", "target_pose", "target_qpos") +ACTION_SPEC_FIELDS = { + "atomic_action_class", + "robot_name", + "control", + "cfg", + *TARGET_SPEC_FIELDS, +} +SUPPORTED_POSE_REFERENCES = {"object", "absolute", "relative"} +SUPPORTED_QPOS_SOURCES = {"initial", "gripper_state", "joint_delta"} +SUPPORTED_CFG_KEYS = { + "sample_interval", + "pre_grasp_distance", + "lift_height", + "hand_interp_steps", + "post_hold_steps", +} + + +ATOMIC_ACTION_REGISTRY = { + "PickUpAction": (PickUpAction, PickUpActionCfg), + "MoveAction": (MoveAction, MoveActionCfg), + "PlaceAction": (PlaceAction, PlaceActionCfg), +} + + +@dataclass(frozen=True) +class AtomicActionSpec: + """JSON-serializable atomic action specification.""" + + atomic_action_class: str + robot_name: str + control: str = "arm" + target_object: dict[str, Any] = field(default_factory=dict) + target_pose: dict[str, Any] = field(default_factory=dict) + target_qpos: dict[str, Any] = field(default_factory=dict) + cfg: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_mapping(cls, spec: Mapping[str, Any]) -> "AtomicActionSpec": + normalized = normalize_atomic_action_spec(spec) + return cls( + atomic_action_class=normalized["atomic_action_class"], + robot_name=normalized["robot_name"], + control=normalized["control"], + target_object=normalized.get("target_object", {}), + target_pose=normalized.get("target_pose", {}), + target_qpos=normalized.get("target_qpos", {}), + cfg=normalized["cfg"], + ) + + def to_dict(self) -> dict[str, Any]: + spec = { + "atomic_action_class": self.atomic_action_class, + "robot_name": self.robot_name, + "control": self.control, + "cfg": deepcopy(self.cfg), + } + if self.target_object: + spec["target_object"] = deepcopy(self.target_object) + if self.target_pose: + spec["target_pose"] = deepcopy(self.target_pose) + if self.target_qpos: + spec["target_qpos"] = deepcopy(self.target_qpos) + return spec + + +def normalize_atomic_action_spec(spec: Mapping[str, Any]) -> dict[str, Any]: + """Validate and normalize an atomic action JSON spec.""" + if not isinstance(spec, Mapping): + raise TypeError(f"Action spec must be a mapping, got {type(spec)}.") + if "fn" in spec: + raise ValueError( + "Legacy fn/kwargs action schema is not supported. Use atomic action class " + "JSON spec with atomic_action_class, robot_name, control, cfg, and " + "exactly one of target_object, target_pose, or target_qpos." + ) + + if "action" in spec: + raise ValueError( + "Legacy action schema is not supported. Use atomic_action_class with " + "PickUpAction, MoveAction, or PlaceAction." + ) + if "target" in spec: + raise ValueError( + "Legacy target.kind schema is not supported. Use exactly one of " + "target_object, target_pose, or target_qpos." + ) + unknown_fields = set(spec) - ACTION_SPEC_FIELDS + if unknown_fields: + raise ValueError( + f"Unsupported atomic action spec fields: " + f"{', '.join(sorted(unknown_fields))}." + ) + + atomic_action_class = spec.get("atomic_action_class") + if atomic_action_class not in SUPPORTED_ATOMIC_ACTION_CLASSES: + raise ValueError( + f"Unsupported atomic action class {atomic_action_class!r}; expected " + f"one of {sorted(SUPPORTED_ATOMIC_ACTION_CLASSES)}." + ) + + robot_name = spec.get("robot_name") + if not isinstance(robot_name, str) or not robot_name: + raise ValueError("Atomic action spec requires non-empty robot_name.") + + control = spec.get("control", "arm") + if control not in SUPPORTED_CONTROLS: + raise ValueError( + f"Unsupported atomic action control {control!r}; expected one of " + f"{sorted(SUPPORTED_CONTROLS)}." + ) + + cfg = dict(spec.get("cfg") or {}) + unknown_cfg = set(cfg) - SUPPORTED_CFG_KEYS + if unknown_cfg: + raise ValueError( + f"Unsupported atomic action cfg keys: {', '.join(sorted(unknown_cfg))}." + ) + + target_field, target_spec = _normalize_action_target( + spec, + atomic_action_class=atomic_action_class, + control=control, + ) + + normalized = { + "atomic_action_class": atomic_action_class, + "robot_name": robot_name, + "control": control, + "cfg": cfg, + } + normalized[target_field] = target_spec + return normalized + + +def _normalize_action_target( + spec: Mapping[str, Any], + *, + atomic_action_class: str, + control: str, +) -> tuple[str, dict[str, Any]]: + target_fields = [field for field in TARGET_SPEC_FIELDS if field in spec] + if len(target_fields) != 1: + raise ValueError( + "Atomic action spec requires exactly one of target_object, target_pose, " + f"or target_qpos; got {target_fields}." + ) + + target_field = target_fields[0] + target_spec = spec[target_field] + if not isinstance(target_spec, Mapping) or not target_spec: + raise ValueError(f"{target_field} must be a non-empty object.") + target_spec = dict(target_spec) + + if atomic_action_class == "PickUpAction": + if control != "arm" or target_field != "target_object": + raise ValueError("PickUpAction requires control='arm' and target_object.") + _validate_target_object(target_spec) + return target_field, target_spec + + if atomic_action_class == "PlaceAction": + if control != "arm" or target_field != "target_pose": + raise ValueError("PlaceAction requires control='arm' and target_pose.") + _validate_target_pose(target_spec) + return target_field, target_spec + + if target_field == "target_pose": + if control != "arm": + raise ValueError("MoveAction target_pose requires control='arm'.") + _validate_target_pose(target_spec) + return target_field, target_spec + + if target_field == "target_qpos": + _validate_target_qpos(target_spec, control=control) + return target_field, target_spec + + raise ValueError("MoveAction requires target_pose or target_qpos.") + + +def _validate_target_object(target_object: Mapping[str, Any]) -> None: + unknown_fields = set(target_object) - {"obj_name", "affordance"} + if unknown_fields: + raise ValueError( + f"Unsupported target_object fields: {', '.join(sorted(unknown_fields))}." + ) + obj_name = target_object.get("obj_name") + if not isinstance(obj_name, str) or not obj_name: + raise ValueError("target_object requires non-empty obj_name.") + affordance = target_object.get("affordance", "antipodal") + if affordance != "antipodal": + raise ValueError("target_object only supports affordance='antipodal'.") + + +def _validate_target_pose(target_pose: Mapping[str, Any]) -> None: + reference = target_pose.get("reference") + if reference not in SUPPORTED_POSE_REFERENCES: + raise ValueError( + f"target_pose reference must be one of {sorted(SUPPORTED_POSE_REFERENCES)}." + ) + + if reference == "object": + _validate_target_fields( + target_pose, + {"reference", "obj_name", "offset"}, + "target_pose", + ) + obj_name = target_pose.get("obj_name") + if not isinstance(obj_name, str) or not obj_name: + raise ValueError("object target_pose requires non-empty obj_name.") + _xyz(target_pose.get("offset", [0.0, 0.0, 0.0]), "offset") + return + + if reference == "absolute": + _validate_target_fields( + target_pose, + {"reference", "position"}, + "target_pose", + ) + position = target_pose.get("position") + if not isinstance(position, list) or len(position) != 3: + raise ValueError( + "absolute target_pose requires position with three entries." + ) + return + + _validate_target_fields( + target_pose, + {"reference", "offset", "frame"}, + "target_pose", + ) + _xyz(target_pose.get("offset", [0.0, 0.0, 0.0]), "offset") + frame = target_pose.get("frame", "world") + if frame not in {"world", "eef"}: + raise ValueError("relative target_pose frame must be 'world' or 'eef'.") + + +def _validate_target_qpos( + target_qpos: Mapping[str, Any], + *, + control: str, +) -> None: + source = target_qpos.get("source") + if source not in SUPPORTED_QPOS_SOURCES: + raise ValueError( + f"target_qpos source must be one of {sorted(SUPPORTED_QPOS_SOURCES)}." + ) + + if source == "initial": + _validate_target_fields(target_qpos, {"source"}, "target_qpos") + if control != "arm": + raise ValueError("initial target_qpos requires control='arm'.") + return + + if source == "gripper_state": + _validate_target_fields(target_qpos, {"source", "state"}, "target_qpos") + if control != "hand": + raise ValueError("gripper_state target_qpos requires control='hand'.") + state = target_qpos.get("state") + if state not in {"open", "close"}: + raise ValueError( + "gripper_state target_qpos state must be 'open' or 'close'." + ) + return + + _validate_target_fields( + target_qpos, + {"source", "joint_index", "delta_degrees"}, + "target_qpos", + ) + if control != "arm": + raise ValueError("joint_delta target_qpos requires control='arm'.") + if "joint_index" not in target_qpos: + raise ValueError("joint_delta target_qpos requires joint_index.") + int(target_qpos["joint_index"]) + float(target_qpos.get("delta_degrees", 0.0)) + + +def _validate_target_fields( + target_spec: Mapping[str, Any], + allowed_fields: set[str], + target_name: str, +) -> None: + unknown_fields = set(target_spec) - allowed_fields + if unknown_fields: + raise ValueError( + f"Unsupported {target_name} fields: {', '.join(sorted(unknown_fields))}." + ) + + +def execute_atomic_action( + action_spec: Mapping[str, Any] | AtomicActionSpec, + *, + env, + **runtime_kwargs, +) -> np.ndarray: + """Execute one atomic action spec and return local arm+eef qpos actions.""" + spec = ( + action_spec + if isinstance(action_spec, AtomicActionSpec) + else AtomicActionSpec.from_mapping(action_spec) + ) + if spec.atomic_action_class == "MoveAction" and spec.target_qpos: + action_np = _execute_move_qpos_action(env, spec) + action_np = _append_hold_steps( + action_np, + int(spec.cfg.get("post_hold_steps", 0)), + "atomic qpos action", + ) + _sync_agent_state_from_atomic_action( + env, + spec.robot_name, + action_np, + spec.control, + ) + log_info( + "Using action-agent qpos action: " + f"control={spec.control}, target={_target_summary(spec)}, " + f"steps={len(action_np)}.", + color="green", + ) + return action_np + + target = _resolve_target(env, spec, runtime_kwargs) + is_left, arm_part, hand_part, arm_joints, eef_joints = _select_arm_parts( + env, spec.robot_name + ) + cfg = _build_action_cfg(env, spec, arm_part, hand_part, len(eef_joints)) + start_qpos = _resolve_action_start_qpos( + env, + spec, + is_left=is_left, + arm_joints=arm_joints, + eef_joints=eef_joints, + ) + action_cls = _get_atomic_action_class(spec.atomic_action_class) + action = action_cls(motion_generator=_make_motion_generator(env), cfg=cfg) + is_success, trajectory, joint_ids = action.execute( + target=target, + start_qpos=start_qpos, + ) + if not is_success: + raise RuntimeError( + f"Atomic action failed: atomic_action_class={spec.atomic_action_class}, " + f"robot_name={spec.robot_name}, target={_target_summary(spec)}." + ) + + action_np = _trajectory_to_agent_action( + env, + spec.robot_name, + trajectory, + joint_ids, + ) + action_np = _append_hold_steps( + action_np, + int(spec.cfg.get("post_hold_steps", 0)), + "atomic action", + ) + _sync_agent_state_from_atomic_action(env, spec.robot_name, action_np, spec.control) + log_info( + "Using atomic action: " + f"atomic_action_class={spec.atomic_action_class}, cfg={cfg.__class__.__name__}, " + f"control={spec.control}, target={_target_summary(spec)}, " + f"steps={len(action_np)}.", + color="green", + ) + return action_np + + +def execute_parallel_atomic_actions( + left_arm_action=None, + right_arm_action=None, + env=None, + return_result: bool = False, + **runtime_kwargs, +): + """Execute left/right atomic action specs as one synchronized stream.""" + left_arm_action = _resolve_action_spec(left_arm_action, env, runtime_kwargs) + right_arm_action = _resolve_action_spec(right_arm_action, env, runtime_kwargs) + + left_arm_action = _as_2d_action(left_arm_action, "left_arm_action") + right_arm_action = _as_2d_action(right_arm_action, "right_arm_action") + arm_actions = {"left": left_arm_action, "right": right_arm_action} + + if all(action is None for action in arm_actions.values()): + raise ValueError("At least one atomic arm action must be provided.") + + action_len = max( + len(action) for action in arm_actions.values() if action is not None + ) + for side, action in arm_actions.items(): + if action is not None and len(action) < action_len: + diff = action_len - len(action) + padding = np.repeat(action[-1:], diff, axis=0) + arm_actions[side] = np.concatenate([action, padding], axis=0) + + current_qpos = ( + env.robot.get_qpos().squeeze(0).detach().cpu().numpy().astype(np.float32) + ) + actions = np.repeat(current_qpos[None, :], action_len, axis=0) + + for side, action in arm_actions.items(): + if action is None: + continue + + arm_index = list(getattr(env, f"{side}_arm_joints", [])) + list( + getattr(env, f"{side}_eef_joints", []) + ) + if not arm_index: + raise ValueError( + f"{side}_arm_action was provided, but {side}_arm is not configured " + f"on robot control parts {getattr(env.robot, 'control_parts', None)}." + ) + if action.shape[-1] != len(arm_index): + raise ValueError( + f"{side}_arm_action width {action.shape[-1]} does not match " + f"{side}_arm joints plus eef joints ({len(arm_index)})." + ) + actions[:, arm_index] = action + + actions = torch.from_numpy(actions).to(dtype=torch.float32).unsqueeze(1) + actions = list(actions.unbind(dim=0)) + + for action in tqdm(actions): + env.step(action) + env.update_obj_info() + + if return_result: + return { + "actions": actions, + } + return actions + + +def _resolve_action_spec(action_spec, env, runtime_kwargs: dict[str, Any]): + if action_spec is None: + return None + if isinstance(action_spec, np.ndarray): + return action_spec + if isinstance(action_spec, torch.Tensor): + return action_spec + return execute_atomic_action(action_spec, env=env, **runtime_kwargs) + + +def _execute_move_qpos_action(env, spec: AtomicActionSpec) -> np.ndarray: + """Execute MoveAction target_qpos locally without extending core MoveAction.""" + target_qpos = _resolve_qpos_target(env, spec) + start_qpos, joint_ids = _qpos_start_and_joint_ids(env, spec) + target_qpos = _resolve_batched_qpos( + target_qpos, + expected_dof=len(joint_ids), + device=env.robot.device, + name="target_qpos", + ) + sample_interval = int(spec.cfg.get("sample_interval", 80)) + trajectory = _interpolate_qpos_trajectory( + start_qpos, + target_qpos, + sample_interval, + ) + return _trajectory_to_agent_action( + env, + spec.robot_name, + trajectory, + joint_ids, + ) + + +def _qpos_start_and_joint_ids( + env, + spec: AtomicActionSpec, +) -> tuple[torch.Tensor, list[int]]: + is_left, _, _, arm_joints, eef_joints = _select_arm_parts(env, spec.robot_name) + if spec.control == "hand": + _, _, _, _, current_gripper_state = get_arm_states(env, spec.robot_name) + start_qpos = _state_to_hand_qpos( + current_gripper_state, + len(eef_joints), + env.robot.device, + ) + return start_qpos.reshape(1, len(eef_joints)), eef_joints + return _current_arm_qpos(env, is_left, arm_joints), arm_joints + + +def _resolve_batched_qpos( + qpos, + *, + expected_dof: int, + device, + name: str, +) -> torch.Tensor: + qpos = torch.as_tensor(qpos, dtype=torch.float32, device=device) + if qpos.shape == (expected_dof,): + qpos = qpos.reshape(1, expected_dof) + if qpos.ndim != 2 or qpos.shape[1] != expected_dof: + raise ValueError( + f"{name} must have shape ({expected_dof},) or (num_envs, {expected_dof}), " + f"got {tuple(qpos.shape)}." + ) + return qpos + + +def _interpolate_qpos_trajectory( + start_qpos: torch.Tensor, + target_qpos: torch.Tensor, + sample_interval: int, +) -> torch.Tensor: + if sample_interval < 2: + raise ValueError("sample_interval must be at least 2 for qpos interpolation.") + if target_qpos.shape[0] == 1 and start_qpos.shape[0] > 1: + target_qpos = target_qpos.repeat(start_qpos.shape[0], 1) + if start_qpos.shape != target_qpos.shape: + raise ValueError( + f"start_qpos and target_qpos must have matching shapes, got " + f"{tuple(start_qpos.shape)} and {tuple(target_qpos.shape)}." + ) + weights = torch.linspace( + 0.0, + 1.0, + steps=sample_interval, + dtype=start_qpos.dtype, + device=start_qpos.device, + ).reshape(1, sample_interval, 1) + return ( + start_qpos.unsqueeze(1) + + (target_qpos.unsqueeze(1) - start_qpos.unsqueeze(1)) * weights + ) + + +def _select_arm_parts(env, robot_name: str): + is_left = resolve_arm_side(env, robot_name) == "left" + if hasattr(env, "get_agent_arm_control_part"): + arm_part = env.get_agent_arm_control_part(is_left) + hand_part = env.get_agent_eef_control_part(is_left) + else: + arm_part = "left_arm" if is_left else "right_arm" + hand_part = "left_eef" if is_left else "right_eef" + arm_joints = env.left_arm_joints if is_left else env.right_arm_joints + eef_joints = env.left_eef_joints if is_left else env.right_eef_joints + return is_left, arm_part, hand_part, list(arm_joints), list(eef_joints) + + +def _make_motion_generator(env): + return MotionGenerator( + cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=env.robot.uid)) + ) + + +def _get_atomic_action_class(atomic_action_class: str): + action_class, _ = ATOMIC_ACTION_REGISTRY[atomic_action_class] + return action_class + + +def _build_action_cfg( + env, + spec: AtomicActionSpec, + arm_part: str, + hand_part: str, + hand_dof: int, +): + cfg_values = dict(spec.cfg) + cfg_values.pop("post_hold_steps", None) + device = env.robot.device + + if spec.atomic_action_class == "PickUpAction": + if spec.control != "arm": + raise ValueError("PickUpAction atomic action requires control='arm'.") + return PickUpActionCfg( + control_part=arm_part, + hand_control_part=hand_part, + hand_open_qpos=_state_to_hand_qpos(env.open_state, hand_dof, device), + hand_close_qpos=_state_to_hand_qpos(env.close_state, hand_dof, device), + **_cfg_supported_kwargs(PickUpActionCfg, cfg_values), + ) + + if spec.atomic_action_class == "PlaceAction": + if spec.control != "arm": + raise ValueError("PlaceAction atomic action requires control='arm'.") + return PlaceActionCfg( + control_part=arm_part, + hand_control_part=hand_part, + hand_open_qpos=_state_to_hand_qpos(env.open_state, hand_dof, device), + hand_close_qpos=_state_to_hand_qpos(env.close_state, hand_dof, device), + **_cfg_supported_kwargs(PlaceActionCfg, cfg_values), + ) + + control_part = arm_part if spec.control == "arm" else hand_part + return MoveActionCfg( + control_part=control_part, + **_cfg_supported_kwargs(MoveActionCfg, cfg_values), + ) + + +def _resolve_action_start_qpos( + env, + spec: AtomicActionSpec, + *, + is_left: bool, + arm_joints: list[int], + eef_joints: list[int], +): + if spec.control == "hand": + _, _, _, _, current_gripper_state = get_arm_states(env, spec.robot_name) + return _state_to_hand_qpos( + current_gripper_state, + len(eef_joints), + env.robot.device, + ).reshape(1, len(eef_joints)) + return _current_arm_qpos(env, is_left, arm_joints) + + +def _resolve_target(env, spec: AtomicActionSpec, runtime_kwargs: dict[str, Any]): + if spec.atomic_action_class == "PickUpAction": + return _resolve_pickup_target(env, spec, runtime_kwargs) + if spec.atomic_action_class == "MoveAction": + return _resolve_move_target(env, spec) + if spec.atomic_action_class == "PlaceAction": + return _resolve_place_target(env, spec) + raise ValueError(f"Unsupported atomic action class: {spec.atomic_action_class}.") + + +def _resolve_pickup_target( + env, + spec: AtomicActionSpec, + runtime_kwargs: dict[str, Any], +): + if not spec.target_object: + raise ValueError("PickUpAction requires target_object.") + return _build_object_semantics(env, spec.target_object, runtime_kwargs) + + +def _resolve_move_target(env, spec: AtomicActionSpec): + if spec.target_pose: + return _resolve_pose_target(env, spec) + if spec.target_qpos: + return _resolve_qpos_target(env, spec) + raise ValueError("MoveAction requires target_pose or target_qpos.") + + +def _resolve_place_target(env, spec: AtomicActionSpec): + if not spec.target_pose: + raise ValueError("PlaceAction requires target_pose.") + return _resolve_pose_target(env, spec) + + +def _resolve_pose_target(env, spec: AtomicActionSpec): + reference = spec.target_pose["reference"] + if reference == "object": + return _resolve_object_pose_target(env, spec) + if reference == "absolute": + return _resolve_absolute_pose_target(env, spec) + if reference == "relative": + return _resolve_relative_pose_target(env, spec) + raise ValueError(f"Unsupported target_pose reference: {reference}.") + + +def _resolve_qpos_target(env, spec: AtomicActionSpec): + source = spec.target_qpos["source"] + if source == "initial": + return _resolve_initial_qpos_target(env, spec) + if source == "gripper_state": + return _resolve_gripper_qpos_target(env, spec) + if source == "joint_delta": + return _resolve_joint_delta_qpos_target(env, spec) + raise ValueError(f"Unsupported target_qpos source: {source}.") + + +def _resolve_object_pose_target(env, spec: AtomicActionSpec): + obj_name = spec.target_pose.get("obj_name") + target_obj = env.sim.get_rigid_object(obj_name) + if target_obj is None: + raise ValueError(f"No rigid object found for {obj_name}.") + offset = _xyz(spec.target_pose.get("offset", [0.0, 0.0, 0.0]), "offset") + _, _, _, current_pose, _ = get_arm_states(env, spec.robot_name) + target_pose = deepcopy(current_pose) + target_obj_pose = target_obj.get_local_pose(to_matrix=True).squeeze(0) + target_pose[:3, 3] = target_obj_pose[:3, 3] + target_pose[0, 3] += offset[0] + target_pose[1, 3] += offset[1] + target_pose[2, 3] += offset[2] + return torch.as_tensor(target_pose, dtype=torch.float32, device=env.robot.device) + + +def _resolve_absolute_pose_target(env, spec: AtomicActionSpec): + position = spec.target_pose.get("position") + if not isinstance(position, list) or len(position) != 3: + raise ValueError("absolute target_pose requires position with three entries.") + _, _, _, current_pose, _ = get_arm_states(env, spec.robot_name) + target_pose = deepcopy(current_pose) + for index, value in enumerate(position): + if value is not None: + target_pose[index, 3] = float(value) + return torch.as_tensor(target_pose, dtype=torch.float32, device=env.robot.device) + + +def _resolve_relative_pose_target(env, spec: AtomicActionSpec): + offset = _xyz(spec.target_pose.get("offset", [0.0, 0.0, 0.0]), "offset") + frame = spec.target_pose.get("frame", "world") + if frame not in {"world", "eef"}: + raise ValueError("relative target_pose frame must be 'world' or 'eef'.") + mode = "extrinsic" if frame == "world" else "intrinsic" + _, _, _, current_pose, _ = get_arm_states(env, spec.robot_name) + target_pose = deepcopy(current_pose) + target_pose = get_offset_pose(target_pose, offset[0], "x", mode) + target_pose = get_offset_pose(target_pose, offset[1], "y", mode) + target_pose = get_offset_pose(target_pose, offset[2], "z", mode) + return torch.as_tensor(target_pose, dtype=torch.float32, device=env.robot.device) + + +def _resolve_initial_qpos_target(env, spec: AtomicActionSpec): + if spec.control != "arm": + raise ValueError("initial target_qpos requires control='arm'.") + is_left, _, _, _, _ = _select_arm_parts(env, spec.robot_name) + target_qpos = env.left_arm_init_qpos if is_left else env.right_arm_init_qpos + return torch.as_tensor(target_qpos, dtype=torch.float32, device=env.robot.device) + + +def _resolve_gripper_qpos_target(env, spec: AtomicActionSpec): + if spec.control != "hand": + raise ValueError("gripper_state target_qpos requires control='hand'.") + state = spec.target_qpos.get("state") + if state == "open": + source = env.open_state + elif state == "close": + source = env.close_state + else: + raise ValueError("gripper_state target_qpos state must be 'open' or 'close'.") + _, _, _, _, eef_joints = _select_arm_parts(env, spec.robot_name) + return _state_to_hand_qpos(source, len(eef_joints), env.robot.device) + + +def _resolve_joint_delta_qpos_target(env, spec: AtomicActionSpec): + if spec.control != "arm": + raise ValueError("joint_delta target_qpos requires control='arm'.") + joint_index = int(spec.target_qpos["joint_index"]) + delta_degrees = float(spec.target_qpos.get("delta_degrees", 0.0)) + _, _, current_qpos, _, _ = get_arm_states(env, spec.robot_name) + target_qpos = torch.as_tensor( + current_qpos, + dtype=torch.float32, + device=env.robot.device, + ).clone() + if joint_index < 0 or joint_index >= target_qpos.numel(): + raise ValueError(f"joint_index {joint_index} is out of range.") + target_qpos[joint_index] += float(np.deg2rad(delta_degrees)) + return target_qpos + + +def _target_summary(spec: AtomicActionSpec) -> str: + if spec.target_object: + return f"target_object:{spec.target_object.get('obj_name')}" + if spec.target_pose: + return f"target_pose:{spec.target_pose.get('reference')}" + if spec.target_qpos: + return f"target_qpos:{spec.target_qpos.get('source')}" + return "target:none" + + +def _build_object_semantics( + env, + target: Mapping[str, Any], + runtime_kwargs: dict[str, Any], +): + obj_name = target.get("obj_name") + if target.get("affordance", "antipodal") != "antipodal": + raise ValueError("target_object only supports antipodal affordance.") + target_obj = env.sim.get_rigid_object(obj_name) + if target_obj is None: + raise ValueError(f"No rigid object found for {obj_name}.") + + _stabilize_affordance_object(env, target_obj, runtime_kwargs) + + mesh_vertices = target_obj.get_vertices(env_ids=[0], scale=True)[0] + mesh_triangles = target_obj.get_triangles(env_ids=[0])[0] + mesh_vertices = torch.as_tensor(mesh_vertices, dtype=torch.float32) + mesh_triangles = torch.as_tensor(mesh_triangles, dtype=torch.int64) + if ( + mesh_vertices.numel() == 0 + or mesh_triangles.numel() == 0 + or mesh_vertices.shape[-1] != 3 + or mesh_triangles.shape[-1] != 3 + ): + raise ValueError(f"Object {obj_name} has empty or invalid mesh geometry.") + + allow_annotation = bool(runtime_kwargs.get("allow_grasp_annotation", True)) + force_reannotate = bool(runtime_kwargs.get("force_grasp_reannotate", False)) + cache_path = _affordance_cache_path(mesh_vertices, mesh_triangles) + if not os.path.exists(cache_path) and not allow_annotation: + raise RuntimeError( + "Grasp annotation cache is missing and annotation is disabled; " + "set allow_grasp_annotation=True." + ) + + antipodal_sampler_cfg = AntipodalSamplerCfg( + **_cfg_supported_kwargs( + AntipodalSamplerCfg, + { + "n_sample": int(runtime_kwargs.get("grasp_antipodal_n_sample", 20000)), + "max_angle": runtime_kwargs.get( + "grasp_antipodal_max_angle", np.pi / 12 + ), + "max_length": runtime_kwargs.get("grasp_max_open_length", 0.088), + "min_length": runtime_kwargs.get("grasp_min_open_length", 0.003), + }, + ) + ) + generator_cfg = GraspGeneratorCfg( + **_cfg_supported_kwargs( + GraspGeneratorCfg, + { + "viser_port": int(runtime_kwargs.get("grasp_viser_port", 11801)), + "antipodal_sampler_cfg": antipodal_sampler_cfg, + "max_deviation_angle": runtime_kwargs.get( + "grasp_max_deviation_angle", + np.pi / 6, + ), + }, + ) + ) + max_decomposition_hulls = _max_decomposition_hulls(target_obj, runtime_kwargs) + source_mesh_path = _rigid_object_mesh_path(target_obj) + body_scale = _rigid_object_body_scale(target_obj) + _prepare_grasp_collision_cache_from_env_coacd( + obj_name=obj_name, + mesh_vertices=mesh_vertices, + mesh_triangles=mesh_triangles, + source_mesh_path=source_mesh_path, + max_decomposition_hulls=max_decomposition_hulls, + body_scale=body_scale, + runtime_kwargs=runtime_kwargs, + ) + + gripper_collision_cfg = GripperCollisionCfg( + **_cfg_supported_kwargs( + GripperCollisionCfg, + { + "max_open_length": runtime_kwargs.get("grasp_max_open_length", 0.088), + "finger_length": runtime_kwargs.get("grasp_finger_length", 0.078), + "point_sample_dense": runtime_kwargs.get( + "grasp_point_sample_dense", + 0.012, + ), + "max_decomposition_hulls": max_decomposition_hulls, + "env_coacd_source_mesh_path": source_mesh_path, + "env_coacd_body_scale": body_scale, + }, + ) + ) + affordance = AntipodalAffordance( + object_label=obj_name, + force_reannotate=force_reannotate, + custom_config={ + "gripper_collision_cfg": gripper_collision_cfg, + "generator_cfg": generator_cfg, + }, + ) + return ObjectSemantics( + label=obj_name, + geometry={ + "mesh_vertices": mesh_vertices, + "mesh_triangles": mesh_triangles, + }, + affordance=affordance, + entity=target_obj, + ) + + +def _prepare_grasp_collision_cache_from_env_coacd( + *, + obj_name: str, + mesh_vertices: torch.Tensor, + mesh_triangles: torch.Tensor, + source_mesh_path: str | None, + max_decomposition_hulls: int, + body_scale: list[float] | None, + runtime_kwargs: Mapping[str, Any], +) -> None: + if not bool(runtime_kwargs.get("reuse_env_coacd_for_grasp", True)): + return + + try: + from embodichain.gen_sim.action_agent_pipeline.runtime.coacd_cache_bridge import ( + ensure_grasp_collision_cache_from_env_coacd, + ) + + result = ensure_grasp_collision_cache_from_env_coacd( + mesh_vertices=mesh_vertices, + mesh_triangles=mesh_triangles, + source_mesh_path=source_mesh_path, + max_decomposition_hulls=max_decomposition_hulls, + body_scale=body_scale, + ) + except Exception: + return + + if result.get("status") == "generated": + log_info( + "Prepared grasp collision cache from environment CoACD cache: " + f"target={obj_name}, cache={result.get('grasp_cache_path')}.", + color="green", + ) + + +def _stabilize_affordance_object( + env, + target_obj, + runtime_kwargs: Mapping[str, Any], +) -> None: + if not bool(runtime_kwargs.get("stabilize_affordance_object", True)): + return + + update_steps = int(runtime_kwargs.get("affordance_stabilization_steps", 5)) + if update_steps > 0 and hasattr(env.sim, "update"): + env.sim.update(step=update_steps) + if hasattr(target_obj, "clear_dynamics"): + target_obj.clear_dynamics() + + +def _trajectory_to_agent_action(env, robot_name, trajectory, joint_ids): + _, _, current_arm_qpos, _, current_gripper_state = get_arm_states(env, robot_name) + _, _, _, arm_joints, eef_joints = _select_arm_parts(env, robot_name) + + if isinstance(trajectory, torch.Tensor): + trajectory = trajectory.detach() + else: + trajectory = torch.as_tensor(trajectory) + + if trajectory.dim() == 3: + trajectory = trajectory[0] + if trajectory.dim() != 2 or trajectory.shape[0] == 0: + raise ValueError( + "Atomic action trajectory must have shape (T, D) or (N, T, D), " + f"got {trajectory.shape}." + ) + + joint_ids = [int(joint_id) for joint_id in joint_ids] + if len(joint_ids) != trajectory.shape[-1]: + raise ValueError( + f"Atomic action joint_ids length {len(joint_ids)} does not match " + f"trajectory width {trajectory.shape[-1]}." + ) + + device = trajectory.device + agent_action = torch.cat( + [ + torch.as_tensor( + current_arm_qpos, dtype=torch.float32, device=device + ).flatten(), + _state_to_hand_qpos(current_gripper_state, len(eef_joints), device), + ], + dim=0, + ) + agent_action = agent_action.unsqueeze(0).repeat(trajectory.shape[0], 1) + + joint_id_to_col = {joint_id: col for col, joint_id in enumerate(joint_ids)} + for out_col, joint_id in enumerate(arm_joints + eef_joints): + if joint_id in joint_id_to_col: + agent_action[:, out_col] = trajectory[:, joint_id_to_col[joint_id]] + + return agent_action.detach().cpu().numpy().astype(np.float32) + + +def _sync_agent_state_from_atomic_action(env, robot_name, action_np, control): + if action_np is None or len(action_np) == 0: + raise ValueError("Atomic action is empty; cannot sync agent state.") + + is_left, _, _, arm_joints, eef_joints = _select_arm_parts(env, robot_name) + final_action = np.asarray(action_np[-1], dtype=np.float32) + arm_dof = len(arm_joints) + + if control == "arm": + arm_qpos = torch.as_tensor( + final_action[:arm_dof], + dtype=torch.float32, + device=env.robot.device, + ) + env.set_current_qpos_agent(arm_qpos, is_left=is_left) + env.set_current_xpos_agent( + env.get_arm_fk(qpos=arm_qpos, is_left=is_left), + is_left=is_left, + ) + + if len(eef_joints) == 0: + return + + _, _, _, _, current_gripper_state = get_arm_states(env, robot_name) + eef_qpos = final_action[arm_dof : arm_dof + len(eef_joints)] + state_dof = max(int(torch.as_tensor(current_gripper_state).numel()), 1) + if len(eef_qpos) >= state_dof: + gripper_qpos = eef_qpos[:state_dof] + else: + gripper_qpos = np.resize(eef_qpos, state_dof) + + current_gripper_state = torch.as_tensor(current_gripper_state) + env.set_current_gripper_state_agent( + torch.as_tensor( + gripper_qpos, + dtype=current_gripper_state.dtype, + device=current_gripper_state.device, + ), + is_left=is_left, + ) + + +def _current_arm_qpos(env, is_left: bool, arm_joints: list[int]) -> torch.Tensor: + source = env.left_arm_current_qpos if is_left else env.right_arm_current_qpos + return torch.as_tensor( + source, + dtype=torch.float32, + device=env.robot.device, + ).reshape(1, len(arm_joints)) + + +def _state_to_hand_qpos(state, hand_dof: int, device): + if hand_dof <= 0: + return torch.empty(0, dtype=torch.float32, device=device) + + state = torch.as_tensor(state, dtype=torch.float32, device=device).flatten() + if state.numel() == 0: + return torch.zeros(hand_dof, dtype=torch.float32, device=device) + if state.numel() == hand_dof: + return state + if state.numel() == 1: + return state.repeat(hand_dof) + if state.numel() > hand_dof: + return state[:hand_dof] + + repeat_num = int(np.ceil(hand_dof / state.numel())) + return state.repeat(repeat_num)[:hand_dof] + + +def _as_2d_action(action, action_name: str): + if action is None: + return None + if isinstance(action, torch.Tensor): + action = action.detach().cpu().numpy() + action = np.asarray(action, dtype=np.float32) + if action.ndim == 1: + action = action[None, :] + if action.ndim != 2 or len(action) == 0: + raise ValueError( + f"{action_name} must have shape (T, D) with T > 0, got {action.shape}." + ) + return action + + +def _append_hold_steps(action_np, hold_steps: int, log_name: str): + hold_steps = int(hold_steps) + if hold_steps <= 0: + return action_np + if action_np is None or len(action_np) == 0: + raise ValueError(f"{log_name} action is empty; cannot append hold steps.") + + hold_actions = np.repeat(action_np[-1:], hold_steps, axis=0) + action_np = np.concatenate([action_np, hold_actions], axis=0) + log_info( + f"Append {hold_steps} hold steps after {log_name}; " + f"total trajectory length is {len(action_np)}.", + color="green", + ) + return action_np + + +def _cfg_supported_kwargs(cfg_cls, values: Mapping[str, Any]): + supported = set() + for cls in reversed(cfg_cls.__mro__): + supported.update(getattr(cls, "__annotations__", {}).keys()) + return {key: value for key, value in values.items() if key in supported} + + +def _affordance_cache_path(mesh_vertices, mesh_triangles): + vert_bytes = mesh_vertices.to("cpu").numpy().tobytes() + face_bytes = mesh_triangles.to("cpu").numpy().tobytes() + md5_hash = hashlib.md5(vert_bytes + face_bytes).hexdigest() + return os.path.join(GRASP_ANNOTATOR_CACHE_DIR, f"antipodal_cache_{md5_hash}.npy") + + +def _rigid_object_mesh_path(obj) -> str | None: + shape = getattr(getattr(obj, "cfg", None), "shape", None) + fpath = getattr(shape, "fpath", None) + return str(fpath) if fpath else None + + +def _rigid_object_body_scale(obj) -> list[float] | None: + body_scale = obj.get_body_scale(env_ids=[0])[0] + return body_scale.detach().to("cpu", dtype=torch.float32).tolist() + + +def _max_decomposition_hulls(target_obj, runtime_kwargs: Mapping[str, Any]) -> int: + if "grasp_max_decomposition_hulls" in runtime_kwargs: + return int(runtime_kwargs["grasp_max_decomposition_hulls"]) + + max_convex_hull_num = getattr( + getattr(target_obj, "cfg", None), + "max_convex_hull_num", + None, + ) + if max_convex_hull_num is not None and int(max_convex_hull_num) > 1: + return int(max_convex_hull_num) + return 8 + + +def _xyz(value, field_name: str) -> list[float]: + if not isinstance(value, list) or len(value) != 3: + raise ValueError(f"{field_name} must be a three-element list.") + return [float(item) for item in value] diff --git a/embodichain/gen_sim/action_agent_pipeline/runtime/coacd_cache_bridge.py b/embodichain/gen_sim/action_agent_pipeline/runtime/coacd_cache_bridge.py new file mode 100644 index 00000000..b0212fec --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/runtime/coacd_cache_bridge.py @@ -0,0 +1,211 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import hashlib +import pickle +from pathlib import Path +from typing import Any + +import numpy as np +import torch + +from embodichain.gen_sim.action_agent_pipeline.generation.coacd_cache import ( + coacd_cache_path_for_mesh, +) + +__all__ = [ + "ensure_grasp_collision_cache_from_env_coacd", + "grasp_collision_cache_path", +] + + +_DEFAULT_CONVEX_DECOMP_DIR = ( + Path.home() / ".cache" / "embodichain_cache" / "convex_decomposition" +) + + +def grasp_collision_cache_path( + mesh_vertices: torch.Tensor | np.ndarray, + mesh_triangles: torch.Tensor | np.ndarray, + max_decomposition_hulls: int, + *, + cache_dir: str | Path | None = None, +) -> Path: + """Return the grasp collision checker cache path for a scaled mesh.""" + + vertices = _as_numpy(mesh_vertices) + triangles = _as_numpy(mesh_triangles) + mesh_hash = hashlib.md5(vertices.tobytes() + triangles.tobytes()).hexdigest() + return _resolve_cache_dir(cache_dir) / ( + f"{mesh_hash}_{int(max_decomposition_hulls)}.pkl" + ) + + +def ensure_grasp_collision_cache_from_env_coacd( + *, + mesh_vertices: torch.Tensor | np.ndarray, + mesh_triangles: torch.Tensor | np.ndarray, + source_mesh_path: str | Path | None, + max_decomposition_hulls: int, + body_scale: Any = None, + cache_dir: str | Path | None = None, +) -> dict[str, Any]: + """Prepare grasp collision cache from the environment CoACD OBJ cache. + + The environment and grasp collision paths use different cache formats. This + bridge avoids running CoACD again during grasp annotation when the + environment-side convex OBJ cache is already available. + """ + + grasp_cache_path = grasp_collision_cache_path( + mesh_vertices, + mesh_triangles, + max_decomposition_hulls, + cache_dir=cache_dir, + ) + if grasp_cache_path.is_file(): + return { + "status": "hit", + "grasp_cache_path": grasp_cache_path.as_posix(), + } + + if source_mesh_path is None: + return { + "status": "missing_source_mesh", + "grasp_cache_path": grasp_cache_path.as_posix(), + } + + env_cache_path = coacd_cache_path_for_mesh( + source_mesh_path, + max_decomposition_hulls, + _resolve_cache_dir(cache_dir), + ) + if not env_cache_path.is_file(): + return { + "status": "missing_env_cache", + "env_cache_path": env_cache_path.as_posix(), + "grasp_cache_path": grasp_cache_path.as_posix(), + } + + try: + plane_equations = _plane_equations_from_env_cache(env_cache_path, body_scale) + _write_grasp_collision_cache(grasp_cache_path, plane_equations) + except Exception as exc: + return { + "status": "skipped", + "reason": str(exc), + "env_cache_path": env_cache_path.as_posix(), + "grasp_cache_path": grasp_cache_path.as_posix(), + } + + return { + "status": "generated", + "env_cache_path": env_cache_path.as_posix(), + "grasp_cache_path": grasp_cache_path.as_posix(), + } + + +def _plane_equations_from_env_cache( + env_cache_path: Path, + body_scale: Any, +) -> list[tuple[np.ndarray, np.ndarray]]: + from dexsim.kit.meshproc.convex_cache import load_obj_as_convex_parts + + from embodichain.toolkits.graspkit.pg_grasp.collision_checker import ( + extract_plane_equations, + ) + + convex_parts = load_obj_as_convex_parts(env_cache_path.as_posix()) + if not convex_parts: + raise ValueError(f"No convex parts found in {env_cache_path}.") + + scale = _body_scale(body_scale) + if not np.allclose(scale, np.ones(3, dtype=np.float32)): + convex_parts = [ + (vertices.astype(np.float32, copy=False) * scale, faces) + for vertices, faces in convex_parts + ] + + plane_equations = extract_plane_equations(convex_parts) + if not plane_equations: + raise ValueError(f"No plane equations extracted from {env_cache_path}.") + return plane_equations + + +def _write_grasp_collision_cache( + cache_path: Path, + plane_equations_np: list[tuple[np.ndarray, np.ndarray]], +) -> None: + cache_path.parent.mkdir(parents=True, exist_ok=True) + n_convex = len(plane_equations_np) + n_max_equation = max(normals.shape[0] for normals, _ in plane_equations_np) + plane_equations = torch.zeros( + size=(n_convex, n_max_equation, 4), + dtype=torch.float32, + device="cpu", + ) + plane_equation_counts = torch.zeros(n_convex, dtype=torch.int32, device="cpu") + for index, (normals, offsets) in enumerate(plane_equations_np): + n_equation = normals.shape[0] + plane_equations[index, :n_equation, :3] = torch.as_tensor( + normals, + dtype=torch.float32, + ) + plane_equations[index, :n_equation, 3] = torch.as_tensor( + offsets, + dtype=torch.float32, + ) + plane_equation_counts[index] = n_equation + + with cache_path.open("wb") as cache_file: + pickle.dump( + { + "plane_equations": plane_equations, + "plane_equation_counts": plane_equation_counts, + }, + cache_file, + ) + + +def _resolve_cache_dir(cache_dir: str | Path | None) -> Path: + if cache_dir is not None: + return Path(cache_dir).expanduser().resolve() + try: + from embodichain.lab.sim import CONVEX_DECOMP_DIR + except Exception: + return _DEFAULT_CONVEX_DECOMP_DIR + return Path(CONVEX_DECOMP_DIR).expanduser().resolve() + + +def _as_numpy(value: torch.Tensor | np.ndarray) -> np.ndarray: + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + return np.ascontiguousarray(value) + + +def _body_scale(body_scale: Any) -> np.ndarray: + if body_scale is None: + return np.ones(3, dtype=np.float32) + if isinstance(body_scale, torch.Tensor): + body_scale = body_scale.detach().cpu().numpy() + scale = np.asarray(body_scale, dtype=np.float32).reshape(-1) + if scale.size == 1: + scale = np.repeat(scale, 3) + if scale.size != 3 or not np.all(np.isfinite(scale)): + raise ValueError(f"Invalid body scale: {body_scale!r}.") + return scale.reshape(1, 3) diff --git a/embodichain/gen_sim/action_agent_pipeline/runtime/graph_compiler.py b/embodichain/gen_sim/action_agent_pipeline/runtime/graph_compiler.py new file mode 100644 index 00000000..4393c37c --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/runtime/graph_compiler.py @@ -0,0 +1,254 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import importlib +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.action_agent_pipeline.utils.llm_json import extract_json_object + +__all__ = [ + "compile_agent_graph_from_file", + "compile_agent_graph_spec", + "load_agent_graph_bundle", +] + +_RECOVERY_KEYS = { + "recovery_graph", + "recovery_spec", + "recovery_bindings", + "recovery_nodes", + "recovery_edges", + "recovery_branches", + "recoveries", +} +_COMPILED_BUNDLE_KEYS = {"task_graph", "metadata"} +_EDGE_KEYS = {"id", "source", "target", "left_arm_action", "right_arm_action"} + + +def load_agent_graph_bundle(path: str | Path) -> dict[str, Any]: + """Load a compiled graph JSON bundle from disk.""" + return extract_json_object(Path(path).read_text(encoding="utf-8")) + + +def compile_agent_graph_from_file( + path: str | Path, + *, + graph_cls: type | None = None, + action_module: Any = None, +) -> Any: + """Compile a graph JSON bundle from disk into an executable graph.""" + bundle = load_agent_graph_bundle(path) + if "task_graph" in bundle: + unknown_bundle_keys = set(bundle) - _COMPILED_BUNDLE_KEYS + if unknown_bundle_keys: + raise ValueError( + "Compiled graph artifact contains unsupported top-level fields: " + f"{', '.join(sorted(unknown_bundle_keys))}." + ) + task_graph = bundle["task_graph"] + else: + task_graph = bundle + return compile_agent_graph_spec( + task_graph, + graph_cls=graph_cls, + action_module=action_module, + ) + + +def compile_agent_graph_spec( + task_graph: str | Mapping[str, Any], + *, + graph_cls: type | None = None, + action_module: Any = None, +) -> Any: + """Compile a nominal JSON graph into ``AgentTaskGraph``.""" + task_spec = extract_json_object(task_graph) + _reject_recovery_keys(task_spec) + _validate_task_spec(task_spec) + graph_cls, action_module = _resolve_runtime( + graph_cls=graph_cls, + action_module=action_module, + ) + + graph = graph_cls( + start=task_spec["start"], + goal=task_spec["goal"], + max_transitions=int(task_spec.get("max_transitions", 1000)), + ) + + for node in task_spec.get("nodes", []): + graph.add_node(node["id"], node.get("semantic", "")) + + for edge in task_spec.get("edges", []): + graph.add_edge( + edge["id"], + edge["source"], + edge["target"], + left_arm_action=_compile_action(edge.get("left_arm_action"), action_module), + right_arm_action=_compile_action( + edge.get("right_arm_action"), action_module + ), + ) + + return graph + + +def _resolve_runtime( + *, + graph_cls: type | None, + action_module: Any, +) -> tuple[type, Any]: + if graph_cls is None: + graph_cls = _resolve_attr( + importlib.import_module( + "embodichain.gen_sim.action_agent_pipeline.runtime.task_graph" + ), + "AgentTaskGraph", + ) + if action_module is None: + action_module = importlib.import_module( + "embodichain.gen_sim.action_agent_pipeline.runtime.atom_actions" + ) + return graph_cls, action_module + + +def _validate_task_spec(task_spec: Mapping[str, Any]) -> None: + node_ids = set() + for node in task_spec.get("nodes", []): + node_id = node["id"] + if node_id in node_ids: + raise ValueError(f"Duplicate graph node id '{node_id}'.") + node_ids.add(node_id) + + for required_node in (task_spec["start"], task_spec["goal"]): + if required_node not in node_ids: + raise ValueError(f"Graph node '{required_node}' is not defined.") + + edge_specs = list(task_spec.get("edges", [])) + edge_ids = set() + for edge in edge_specs: + unknown_edge_keys = set(edge) - _EDGE_KEYS + if unknown_edge_keys: + raise ValueError( + f"Nominal edge '{edge.get('id', '')}' contains unsupported " + f"fields: {', '.join(sorted(unknown_edge_keys))}." + ) + edge_id = edge["id"] + if edge_id in edge_ids: + raise ValueError(f"Duplicate graph edge id '{edge_id}'.") + edge_ids.add(edge_id) + if edge.get("left_arm_action") is None and edge.get("right_arm_action") is None: + raise ValueError(f"Nominal edge '{edge_id}' must define an arm action.") + + for node_key in ("source", "target"): + node_id = edge[node_key] + if node_id not in node_ids: + raise ValueError( + f"Edge '{edge_id}' references unknown {node_key} node '{node_id}'." + ) + + _validate_nominal_path(task_spec, edge_specs) + + +def _validate_nominal_path( + task_spec: Mapping[str, Any], + edge_specs: list[Mapping[str, Any]], +) -> None: + outgoing_edges: dict[str, Mapping[str, Any]] = {} + for edge in edge_specs: + source = edge["source"] + if source in outgoing_edges: + raise ValueError( + f"Nominal node '{source}' has multiple outgoing edges. " + "The current graph executor expects one deterministic nominal path." + ) + outgoing_edges[source] = edge + + current = task_spec["start"] + goal = task_spec["goal"] + visited_edges = set() + visited_nodes = {current} + + while current != goal: + edge = outgoing_edges.get(current) + if edge is None: + raise ValueError( + f"Nominal graph has no start-to-goal path from node '{current}'." + ) + edge_id = edge["id"] + if edge_id in visited_edges: + raise ValueError("Nominal graph contains a cycle.") + + visited_edges.add(edge_id) + current = edge["target"] + if current in visited_nodes and current != goal: + raise ValueError("Nominal graph contains a cycle.") + visited_nodes.add(current) + + all_edge_ids = {edge["id"] for edge in edge_specs} + unused_edge_ids = all_edge_ids - visited_edges + if unused_edge_ids: + unused = ", ".join(sorted(unused_edge_ids)) + raise ValueError( + f"Nominal graph contains edges outside the start-to-goal path: {unused}." + ) + + +def _compile_action(spec: Any, action_module: Any) -> Any: + if spec is None: + return None + if isinstance(spec, str) and spec.strip().lower() in {"", "none", "null"}: + return None + if not isinstance(spec, Mapping): + raise TypeError(f"Action spec must be a mapping or null, but got {type(spec)}.") + if "fn" in spec: + raise ValueError( + "Legacy fn/kwargs action schema is not supported. Use atomic action " + "class JSON spec with atomic_action_class, robot_name, control, cfg, " + "and exactly one of target_object, target_pose, or target_qpos." + ) + if "action" in spec: + raise ValueError( + "Legacy action schema is not supported. Use atomic_action_class with " + "PickUpAction, MoveAction, or PlaceAction." + ) + if spec.get("atomic_action_class") is None: + raise ValueError( + "Atomic action class schema requires atomic_action_class, robot_name, " + "control, cfg, and exactly one of target_object, target_pose, or " + "target_qpos." + ) + + return action_module.normalize_atomic_action_spec(spec) + + +def _reject_recovery_keys(task_spec: Mapping[str, Any]) -> None: + present = _RECOVERY_KEYS & set(task_spec) + if present: + raise ValueError( + "Recovery graph fields are no longer supported: " + f"{', '.join(sorted(present))}." + ) + + +def _resolve_attr(namespace: Any, name: str) -> Any: + if isinstance(namespace, Mapping): + return namespace[name] + return getattr(namespace, name) diff --git a/embodichain/gen_sim/action_agent_pipeline/runtime/task_graph.py b/embodichain/gen_sim/action_agent_pipeline/runtime/task_graph.py new file mode 100644 index 00000000..53ea5a1f --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/runtime/task_graph.py @@ -0,0 +1,134 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from typing import Any + +from embodichain.gen_sim.action_agent_pipeline.runtime.atom_actions import ( + execute_parallel_atomic_actions, +) + +__all__ = [ + "AgentGraphEdge", + "AgentGraphNode", + "AgentTaskGraph", + "ExecutedActionList", +] + + +@dataclass +class AgentGraphNode: + """Semantic keyframe in an atomic-action task graph.""" + + id: str + semantic: str = "" + + +@dataclass +class AgentGraphEdge: + """Executable transition between two graph nodes.""" + + id: str + source: str + target: str + left_arm_action: Any = None + right_arm_action: Any = None + + +class ExecutedActionList: + """Action sequence already executed online by the graph runtime.""" + + already_executed = True + + def __init__(self, actions: list[Any]) -> None: + self.actions = actions + + def __len__(self) -> int: + return len(self.actions) + + def __iter__(self): + return iter(self.actions) + + def __getitem__(self, index): + return self.actions[index] + + +class AgentTaskGraph: + """Deterministic atomic-action graph with one nominal start-to-goal path.""" + + def __init__(self, start: str, goal: str, max_transitions: int = 1000) -> None: + self.start = start + self.goal = goal + self.max_transitions = max_transitions + self.nodes: dict[str, AgentGraphNode] = {} + self.edges: dict[str, AgentGraphEdge] = {} + self.outgoing: dict[str, list[str]] = defaultdict(list) + + def add_node(self, node_id: str, semantic: str = "") -> "AgentTaskGraph": + self.nodes[node_id] = AgentGraphNode(node_id, semantic) + return self + + def add_edge( + self, + edge_id: str, + source: str, + target: str, + *, + left_arm_action=None, + right_arm_action=None, + ) -> "AgentTaskGraph": + self.edges[edge_id] = AgentGraphEdge( + id=edge_id, + source=source, + target=target, + left_arm_action=left_arm_action, + right_arm_action=right_arm_action, + ) + self.outgoing[source].append(edge_id) + return self + + def run(self, env=None, **kwargs) -> ExecutedActionList: + current = self.start + executed_actions: list[Any] = [] + transitions = 0 + + while current != self.goal: + transitions += 1 + if transitions > self.max_transitions: + raise RuntimeError("Agent task graph exceeded max_transitions.") + + edge = self.edges[self._next_edge(current)] + actions = execute_parallel_atomic_actions( + left_arm_action=edge.left_arm_action, + right_arm_action=edge.right_arm_action, + env=env, + **kwargs, + ) + executed_actions.extend(actions) + current = edge.target + + return ExecutedActionList(executed_actions) + + def _next_edge(self, node_id: str) -> str: + outgoing_edges = self.outgoing[node_id] + if len(outgoing_edges) != 1: + raise RuntimeError( + f"Nominal node '{node_id}' must have exactly one outgoing edge." + ) + return outgoing_edges[0] diff --git a/embodichain/gen_sim/action_agent_pipeline/utils/__init__.py b/embodichain/gen_sim/action_agent_pipeline/utils/__init__.py new file mode 100644 index 00000000..9cfdb173 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/utils/__init__.py @@ -0,0 +1,18 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + diff --git a/embodichain/gen_sim/action_agent_pipeline/utils/llm_config.py b/embodichain/gen_sim/action_agent_pipeline/utils/llm_config.py new file mode 100644 index 00000000..c267a1c8 --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/utils/llm_config.py @@ -0,0 +1,159 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +__all__ = [ + "DEFAULT_LLM_MODEL", + "ACTION_PIPELINE_LLM_ENV_PATH", + "GEN_CONFIG_PATH", + "LLM_ENV_PATH", + "LEGACY_LLM_ENV_PATH", + "SIMREADY_LLM_ENV_PATH", + "get_openai_compatible_llm_config", +] + +DEFAULT_LLM_MODEL = "gpt-4o" +CONFIG_DIR = Path(__file__).resolve().parent +PROJECT_ROOT = next( + ( + parent + for parent in CONFIG_DIR.parents + if (parent / "setup.py").exists() and (parent / "embodichain").exists() + ), + CONFIG_DIR.parents[3], +) +GEN_CONFIG_PATH = ( + PROJECT_ROOT / "embodichain/gen_sim/simready_pipeline/configs/gen_config.json" +) +LLM_ENV_PATH = PROJECT_ROOT / ".env" +SIMREADY_LLM_ENV_PATH = ( + PROJECT_ROOT / "embodichain/gen_sim/simready_pipeline/configs/.env" +) +ACTION_PIPELINE_LLM_ENV_PATH = CONFIG_DIR / ".env" +LEGACY_LLM_ENV_PATH = SIMREADY_LLM_ENV_PATH + + +def _load_env_file(path: Path | None = None) -> dict[str, str]: + """Read local KEY=VALUE credentials without overriding shell variables.""" + path = path or LLM_ENV_PATH + if not path.exists(): + return {} + + env_values: dict[str, str] = {} + for raw_line in path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + key = key.strip() + value = value.strip().strip("\"'") + if key: + env_values[key] = value + return env_values + + +def _load_env_files(paths: tuple[Path, ...] | None = None) -> dict[str, str]: + """Read local env files, with later paths taking precedence.""" + env_values: dict[str, str] = {} + for path in paths or ( + SIMREADY_LLM_ENV_PATH, + ACTION_PIPELINE_LLM_ENV_PATH, + LLM_ENV_PATH, + ): + env_values.update(_load_env_file(path)) + return env_values + + +def _get_first_value( + local_env: dict[str, str], + *names: str, + default: str | None = None, +) -> str | None: + for name in names: + value = os.getenv(name) + if value: + return value + value = local_env.get(name) + if value: + return value + return default + + +def _load_gen_config(path: Path | None = None) -> dict[str, Any]: + path = path or GEN_CONFIG_PATH + if not path.exists(): + raise FileNotFoundError(f"gen_config.json not found: {path}") + + with path.open("r", encoding="utf-8") as f: + raw_cfg = json.load(f) + return dict(raw_cfg.get("llm", {}).get("openai_compatible", {})) + + +def get_openai_compatible_llm_config( + *, + required: bool = False, + require_base_url: bool = False, + default_model: str = DEFAULT_LLM_MODEL, +) -> dict[str, Any]: + """Return shared OpenAI-compatible LLM config for agents and gen-sim.""" + local_env = _load_env_files() + json_cfg = _load_gen_config() + + cfg = { + "api_key": _get_first_value(local_env, "OPENAI_API_KEY") + or json_cfg.get("api_key", ""), + "model": _get_first_value(local_env, "OPENAI_MODEL", "LLM_MODEL") + or json_cfg.get("model") + or default_model, + "base_url": _get_first_value( + local_env, + "OPENAI_BASE_URL", + "OPENAI_API_BASE", + "LLM_URL", + ) + or json_cfg.get("base_url", ""), + "default_query": json_cfg.get("default_query", {}) or {}, + "proxy_url": _get_first_value( + local_env, + "EMBODICHAIN_LLM_PROXY", + "LLM_PROXY_URL", + ) + or json_cfg.get("proxy_url", ""), + } + + if cfg["base_url"]: + cfg["base_url"] = cfg["base_url"].rstrip("/") + + if required: + required_keys = ["api_key", "model"] + if require_base_url: + required_keys.append("base_url") + missing = [key for key in required_keys if not cfg.get(key)] + if missing: + raise ValueError( + f"Missing required LLM config keys: {missing}. " + "Set them in shell environment variables, " + f"{LLM_ENV_PATH}, {ACTION_PIPELINE_LLM_ENV_PATH}, " + f"{SIMREADY_LLM_ENV_PATH}, or {GEN_CONFIG_PATH}." + ) + + return cfg diff --git a/embodichain/gen_sim/action_agent_pipeline/utils/llm_json.py b/embodichain/gen_sim/action_agent_pipeline/utils/llm_json.py new file mode 100644 index 00000000..68bdacfb --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/utils/llm_json.py @@ -0,0 +1,73 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import re +from collections.abc import Mapping +from typing import Any + +__all__ = [ + "extract_json_object", + "normalize_json_content", +] + +_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE) + + +def extract_json_object(content: str | Mapping[str, Any]) -> dict[str, Any]: + """Extract a JSON object from plain or fenced LLM content. + + Args: + content: Raw LLM text, already parsed JSON-like mapping, or markdown fenced + JSON content. + + Returns: + Parsed JSON object. + + Raises: + ValueError: If no JSON object can be parsed. + """ + if isinstance(content, Mapping): + return dict(content) + + text = str(content).strip() + candidates = [match.group(1).strip() for match in _JSON_FENCE_RE.finditer(text)] + candidates.append(text) + + decoder = json.JSONDecoder() + for candidate in candidates: + try: + value = json.loads(candidate) + except json.JSONDecodeError: + start = candidate.find("{") + if start < 0: + continue + try: + value, _ = decoder.raw_decode(candidate[start:]) + except json.JSONDecodeError: + continue + + if isinstance(value, dict): + return value + + raise ValueError("Expected a JSON object in the LLM response.") + + +def normalize_json_content(content: str | Mapping[str, Any]) -> str: + """Normalize JSON-like LLM content into stable pretty-printed JSON text.""" + return json.dumps(extract_json_object(content), ensure_ascii=False, indent=2) diff --git a/embodichain/gen_sim/action_agent_pipeline/utils/llm_usage.py b/embodichain/gen_sim/action_agent_pipeline/utils/llm_usage.py new file mode 100644 index 00000000..e8919bdd --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/utils/llm_usage.py @@ -0,0 +1,410 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections.abc import Mapping +from datetime import datetime, timezone +import json +import os +from pathlib import Path +import re +from typing import Any + +__all__ = [ + "LLM_USAGE_PATH_ENV", + "LLM_USAGE_PROCESS_ENV", + "LLM_USAGE_RUN_ID_ENV", + "UsageTrackedChatModel", + "build_usage_summary", + "configure_usage_tracking", + "disable_usage_tracking", + "extract_usage_from_langchain_response", + "normalize_usage", + "normalize_usage_stage", + "record_langchain_usage", + "record_llm_usage", + "scrub_usage_tracking_env", + "write_usage_summary", +] + + +LLM_USAGE_PATH_ENV = "EMBODICHAIN_LLM_USAGE_PATH" +LLM_USAGE_RUN_ID_ENV = "EMBODICHAIN_LLM_USAGE_RUN_ID" +LLM_USAGE_PROCESS_ENV = "EMBODICHAIN_LLM_USAGE_PROCESS" + +_USAGE_ENV_KEYS = { + LLM_USAGE_PATH_ENV, + LLM_USAGE_RUN_ID_ENV, + LLM_USAGE_PROCESS_ENV, +} +_TOKEN_FIELDS = ( + "input_tokens", + "output_tokens", + "total_tokens", + "cached_tokens", + "reasoning_tokens", +) + + +class UsageTrackedChatModel: + """Proxy a LangChain chat model and record usage after each invoke call.""" + + def __init__( + self, + inner: Any, + *, + stage: str | None, + provider: str = "langchain_openai", + ) -> None: + self._inner = inner + self._usage_stage = normalize_usage_stage(stage or "chat") + self._usage_provider = provider + + def invoke(self, *args, **kwargs): + response = self._inner.invoke(*args, **kwargs) + record_langchain_usage( + response, + stage=self._usage_stage, + provider=self._usage_provider, + model=_model_name_from_chat_model(self._inner), + ) + return response + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) + + +def configure_usage_tracking( + *, + usage_path: str | Path, + run_id: str, + process_name: str, + reset: bool = False, +) -> Path: + """Configure process-local environment variables for LLM usage logging.""" + path = Path(usage_path).expanduser().resolve() + path.parent.mkdir(parents=True, exist_ok=True) + if reset: + path.write_text("", encoding="utf-8") + os.environ[LLM_USAGE_PATH_ENV] = path.as_posix() + os.environ[LLM_USAGE_RUN_ID_ENV] = str(run_id) + os.environ[LLM_USAGE_PROCESS_ENV] = str(process_name) + return path + + +def disable_usage_tracking() -> None: + """Disable process-local EmbodiChain LLM usage logging.""" + for key in _USAGE_ENV_KEYS: + os.environ.pop(key, None) + + +def scrub_usage_tracking_env(env: Mapping[str, str] | None = None) -> dict[str, str]: + """Return an environment copy without EmbodiChain LLM usage variables.""" + cleaned = dict(os.environ if env is None else env) + for key in _USAGE_ENV_KEYS: + cleaned.pop(key, None) + return cleaned + + +def normalize_usage_stage(stage: str) -> str: + """Normalize a human-readable usage stage into a compact identifier.""" + value = str(stage or "unknown").strip().lower() + value = re.sub(r"[^a-z0-9_.-]+", "_", value) + value = re.sub(r"_+", "_", value).strip("_.-") + return value or "unknown" + + +def normalize_usage(usage: Mapping[str, Any] | None) -> dict[str, int | None]: + """Normalize OpenAI and LangChain token usage shapes.""" + if not isinstance(usage, Mapping): + return {field: None for field in _TOKEN_FIELDS} + + input_tokens = _first_int(usage, "input_tokens", "prompt_tokens") + output_tokens = _first_int(usage, "output_tokens", "completion_tokens") + total_tokens = _first_int(usage, "total_tokens") + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + prompt_details = _mapping_value(usage, "prompt_tokens_details") + input_details = _mapping_value(usage, "input_token_details") + completion_details = _mapping_value(usage, "completion_tokens_details") + output_details = _mapping_value(usage, "output_token_details") + + cached_tokens = _first_int(usage, "cached_tokens", "cache_read") + if cached_tokens is None: + cached_tokens = _first_int( + prompt_details, + "cached_tokens", + "cache_read", + ) + if cached_tokens is None: + cached_tokens = _first_int(input_details, "cached_tokens", "cache_read") + + reasoning_tokens = _first_int(usage, "reasoning_tokens", "reasoning") + if reasoning_tokens is None: + reasoning_tokens = _first_int( + completion_details, + "reasoning_tokens", + "reasoning", + ) + if reasoning_tokens is None: + reasoning_tokens = _first_int( + output_details, + "reasoning_tokens", + "reasoning", + ) + + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + "cached_tokens": cached_tokens, + "reasoning_tokens": reasoning_tokens, + } + + +def extract_usage_from_langchain_response( + response: Any, +) -> tuple[dict[str, int | None], dict[str, Any]]: + """Extract usage fields and lightweight metadata from a LangChain response.""" + metadata = _mapping_value_from_object(response, "response_metadata") + usage = _mapping_value_from_object(response, "usage_metadata") + if not usage: + usage = _mapping_value(metadata, "token_usage") + + usage_values = normalize_usage(usage) + response_metadata = { + "model": _string_value(metadata, "model_name", "model"), + "request_id": _string_value(metadata, "id", "request_id"), + "finish_reason": _finish_reason(metadata), + "raw_usage": _json_safe(usage) if isinstance(usage, Mapping) else None, + } + return usage_values, response_metadata + + +def record_langchain_usage( + response: Any, + *, + stage: str, + provider: str = "langchain_openai", + model: str | None = None, +) -> None: + """Record usage from a LangChain response if usage logging is enabled.""" + usage, metadata = extract_usage_from_langchain_response(response) + record_llm_usage( + stage=stage, + provider=provider, + model=metadata.get("model") or model, + usage=usage, + request_id=metadata.get("request_id"), + finish_reason=metadata.get("finish_reason"), + raw_usage=metadata.get("raw_usage"), + ) + + +def record_llm_usage( + *, + stage: str, + provider: str, + model: str | None, + usage: Mapping[str, Any] | None, + request_id: str | None = None, + finish_reason: str | None = None, + raw_usage: Mapping[str, Any] | None = None, + metadata: Mapping[str, Any] | None = None, +) -> None: + """Append one LLM usage record to the configured JSONL file.""" + usage_path = os.getenv(LLM_USAGE_PATH_ENV) + if not usage_path: + return + + usage_values = normalize_usage(usage) + usage_available = any(usage_values[field] is not None for field in _TOKEN_FIELDS) + record: dict[str, Any] = { + "created_at": datetime.now(timezone.utc).isoformat(timespec="milliseconds"), + "run_id": os.getenv(LLM_USAGE_RUN_ID_ENV), + "process": os.getenv(LLM_USAGE_PROCESS_ENV), + "pid": os.getpid(), + "stage": normalize_usage_stage(stage), + "provider": provider, + "model": model, + "usage_available": usage_available, + "request_id": request_id, + "finish_reason": finish_reason, + } + record.update(usage_values) + if raw_usage is not None: + record["raw_usage"] = _json_safe(raw_usage) + if metadata: + record["metadata"] = _json_safe(metadata) + + path = Path(usage_path).expanduser().resolve() + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as file: + file.write(json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n") + + +def build_usage_summary(usage_path: str | Path) -> dict[str, Any]: + """Build aggregate token usage totals from a JSONL usage file.""" + path = Path(usage_path).expanduser().resolve() + records = _read_usage_records(path) + summary: dict[str, Any] = { + "usage_path": path.as_posix(), + "generated_at": datetime.now(timezone.utc).isoformat(timespec="milliseconds"), + "run_id": os.getenv(LLM_USAGE_RUN_ID_ENV), + "total": _empty_bucket(), + "by_stage": {}, + "by_model": {}, + "by_process": {}, + } + + for record in records: + _add_record(summary["total"], record) + _add_grouped_record(summary["by_stage"], record.get("stage"), record) + _add_grouped_record(summary["by_model"], record.get("model"), record) + _add_grouped_record(summary["by_process"], record.get("process"), record) + + return summary + + +def write_usage_summary( + *, + usage_path: str | Path, + summary_path: str | Path, +) -> dict[str, Any]: + """Write a JSON token usage summary and return it.""" + summary = build_usage_summary(usage_path) + path = Path(summary_path).expanduser().resolve() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(summary, ensure_ascii=False, indent=4, sort_keys=True) + "\n", + encoding="utf-8", + ) + return summary + + +def _read_usage_records(path: Path) -> list[dict[str, Any]]: + if not path.is_file(): + return [] + records: list[dict[str, Any]] = [] + for line in path.read_text(encoding="utf-8").splitlines(): + stripped = line.strip() + if not stripped: + continue + try: + parsed = json.loads(stripped) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + records.append(parsed) + return records + + +def _empty_bucket() -> dict[str, int]: + bucket = { + "calls": 0, + "calls_with_usage": 0, + } + for field in _TOKEN_FIELDS: + bucket[field] = 0 + return bucket + + +def _add_grouped_record( + groups: dict[str, dict[str, int]], + key: Any, + record: Mapping[str, Any], +) -> None: + group_key = str(key or "unknown") + bucket = groups.setdefault(group_key, _empty_bucket()) + _add_record(bucket, record) + + +def _add_record(bucket: dict[str, int], record: Mapping[str, Any]) -> None: + bucket["calls"] += 1 + if record.get("usage_available"): + bucket["calls_with_usage"] += 1 + for field in _TOKEN_FIELDS: + value = record.get(field) + if isinstance(value, int): + bucket[field] += value + + +def _model_name_from_chat_model(model: Any) -> str | None: + for attr in ("model_name", "model"): + value = getattr(model, attr, None) + if value: + return str(value) + return None + + +def _mapping_value_from_object(value: Any, attr_name: str) -> Mapping[str, Any]: + attr = getattr(value, attr_name, None) + return attr if isinstance(attr, Mapping) else {} + + +def _mapping_value(mapping: Mapping[str, Any], key: str) -> Mapping[str, Any]: + value = mapping.get(key) if isinstance(mapping, Mapping) else None + return value if isinstance(value, Mapping) else {} + + +def _first_int(mapping: Mapping[str, Any], *keys: str) -> int | None: + if not isinstance(mapping, Mapping): + return None + for key in keys: + value = mapping.get(key) + if isinstance(value, bool): + continue + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + return None + + +def _string_value(mapping: Mapping[str, Any], *keys: str) -> str | None: + if not isinstance(mapping, Mapping): + return None + for key in keys: + value = mapping.get(key) + if isinstance(value, str) and value: + return value + return None + + +def _finish_reason(metadata: Mapping[str, Any]) -> str | None: + reason = _string_value(metadata, "finish_reason") + if reason: + return reason + response_metadata = ( + metadata.get("response_metadata") if isinstance(metadata, Mapping) else None + ) + if isinstance(response_metadata, Mapping): + return _string_value(response_metadata, "finish_reason") + return None + + +def _json_safe(value: Any) -> Any: + try: + json.dumps(value, ensure_ascii=False) + return value + except TypeError: + if isinstance(value, Mapping): + return {str(key): _json_safe(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_json_safe(item) for item in value] + return str(value) diff --git a/embodichain/gen_sim/action_agent_pipeline/utils/mllm.py b/embodichain/gen_sim/action_agent_pipeline/utils/mllm.py new file mode 100644 index 00000000..f39f2d0f --- /dev/null +++ b/embodichain/gen_sim/action_agent_pipeline/utils/mllm.py @@ -0,0 +1,115 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import os +from collections.abc import Mapping +from typing import Any + +from embodichain.gen_sim.action_agent_pipeline.utils.llm_config import ( + DEFAULT_LLM_MODEL, + get_openai_compatible_llm_config, +) +from embodichain.gen_sim.action_agent_pipeline.utils.llm_usage import ( + UsageTrackedChatModel, +) + +__all__ = [ + "DEFAULT_LLM_MODEL", + "apply_proxy_env", + "create_chat_openai", + "create_openai_client", + "get_openai_compatible_llm_config", +] + + +def apply_proxy_env(proxy_url: str | None) -> None: + """Apply an optional proxy URL for OpenAI-compatible clients.""" + if not proxy_url: + return + os.environ["HTTP_PROXY"] = proxy_url + os.environ["HTTPS_PROXY"] = proxy_url + + +def _resolve_llm_config( + *, + config: Mapping[str, Any] | None, + required: bool, + require_base_url: bool, +) -> dict[str, Any]: + if config is not None: + return dict(config) + return get_openai_compatible_llm_config( + required=required, + require_base_url=require_base_url, + ) + + +def create_openai_client( + *, + config: Mapping[str, Any] | None = None, + required: bool = True, + require_base_url: bool = False, +): + """Create the shared OpenAI-compatible SDK client used by gen-sim MLLM calls.""" + from openai import OpenAI + + cfg = _resolve_llm_config( + config=config, + required=required, + require_base_url=require_base_url, + ) + apply_proxy_env(cfg.get("proxy_url")) + + kwargs: dict[str, Any] = { + "api_key": cfg["api_key"], + "default_query": cfg.get("default_query") or None, + } + if cfg.get("base_url"): + kwargs["base_url"] = cfg["base_url"] + return OpenAI(**kwargs) + + +def create_chat_openai( + *, + temperature: float = 0.0, + model: str | None = None, + config: Mapping[str, Any] | None = None, + required: bool = True, + usage_stage: str | None = None, +): + """Create the shared LangChain OpenAI-compatible chat client for agents.""" + from langchain_openai import ChatOpenAI + + cfg = _resolve_llm_config( + config=config, + required=required, + require_base_url=False, + ) + apply_proxy_env(cfg.get("proxy_url")) + + kwargs: dict[str, Any] = { + "temperature": temperature, + "model": model or cfg.get("model") or DEFAULT_LLM_MODEL, + "api_key": cfg["api_key"], + } + if cfg.get("base_url"): + kwargs["base_url"] = cfg["base_url"] + return UsageTrackedChatModel( + ChatOpenAI(**kwargs), + stage=usage_stage, + ) diff --git a/tests/gen_sim/action_agent_pipeline/test_backend_atomic_runtime.py b/tests/gen_sim/action_agent_pipeline/test_backend_atomic_runtime.py new file mode 100644 index 00000000..2cdbe9aa --- /dev/null +++ b/tests/gen_sim/action_agent_pipeline/test_backend_atomic_runtime.py @@ -0,0 +1,463 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from embodichain.gen_sim.action_agent_pipeline.runtime import atom_actions +from embodichain.gen_sim.action_agent_pipeline.runtime.atom_actions import ( + execute_atomic_action, + normalize_atomic_action_spec, +) +from embodichain.lab.sim.atomic_actions import ( + MoveActionCfg, + PickUpActionCfg, + PlaceActionCfg, +) + + +class _FakeRobot: + uid = "fake_robot" + device = torch.device("cpu") + control_parts = { + "left_arm": [0, 1], + "left_eef": [2], + "right_arm": [3, 4], + "right_eef": [5], + } + + def get_qpos(self): + return torch.zeros(1, 6) + + +class _FakeObject: + cfg = SimpleNamespace(shape=SimpleNamespace(fpath="/tmp/fake.obj")) + + def __init__(self, xyz): + self._pose = torch.eye(4) + self._pose[:3, 3] = torch.tensor(xyz, dtype=torch.float32) + + def get_local_pose(self, to_matrix: bool = True): + return self._pose.unsqueeze(0) + + def get_vertices(self, env_ids=None, scale: bool = True): + return [torch.tensor([[0.0, 0.0, 0.0], [0.01, 0.0, 0.0], [0.0, 0.01, 0.0]])] + + def get_triangles(self, env_ids=None): + return [torch.tensor([[0, 1, 2]])] + + def get_body_scale(self, env_ids=None): + return torch.ones(1, 3) + + +class _FakeSim: + def __init__(self): + self.objects = {"apple": _FakeObject([0.4, -0.2, 0.1])} + + def get_rigid_object(self, uid: str): + return self.objects.get(uid) + + +class _FakeEnv: + def __init__(self): + self.robot = _FakeRobot() + self.sim = _FakeSim() + self.left_arm_joints = [0, 1] + self.left_eef_joints = [2] + self.right_arm_joints = [3, 4] + self.right_eef_joints = [5] + self.left_arm_current_qpos = torch.tensor([0.1, 0.2]) + self.right_arm_current_qpos = torch.tensor([0.3, 0.4]) + self.left_arm_init_qpos = torch.tensor([-0.1, -0.2]) + self.right_arm_init_qpos = torch.tensor([-0.3, -0.4]) + self.left_arm_current_xpos = torch.eye(4) + self.right_arm_current_xpos = torch.eye(4) + self.left_arm_current_gripper_state = torch.tensor([0.0]) + self.right_arm_current_gripper_state = torch.tensor([0.0]) + self.open_state = torch.tensor([0.05]) + self.close_state = torch.tensor([0.0]) + + def get_current_qpos_agent(self): + return self.left_arm_current_qpos, self.right_arm_current_qpos + + def set_current_qpos_agent(self, arm_qpos, is_left): + if is_left: + self.left_arm_current_qpos = arm_qpos + else: + self.right_arm_current_qpos = arm_qpos + + def get_current_xpos_agent(self): + return self.left_arm_current_xpos, self.right_arm_current_xpos + + def set_current_xpos_agent(self, arm_xpos, is_left): + if is_left: + self.left_arm_current_xpos = arm_xpos + else: + self.right_arm_current_xpos = arm_xpos + + def get_current_gripper_state_agent(self): + return self.left_arm_current_gripper_state, self.right_arm_current_gripper_state + + def set_current_gripper_state_agent(self, arm_gripper_state, is_left): + if is_left: + self.left_arm_current_gripper_state = arm_gripper_state + else: + self.right_arm_current_gripper_state = arm_gripper_state + + def get_arm_fk(self, qpos, is_left): + pose = torch.eye(4) + pose[0, 3] = torch.as_tensor(qpos).flatten()[0] + return pose + + +class _FakeBackendAction: + capture: list | None = None + + def __init__(self, motion_generator, cfg): + self.motion_generator = motion_generator + self.cfg = cfg + if self.capture is not None: + self.capture.append( + { + "cfg": self.cfg, + "motion_generator": self.motion_generator, + } + ) + + def execute(self, target, start_qpos=None, **kwargs): + if self.capture is not None: + self.capture[-1].update({"target": target, "start_qpos": start_qpos}) + if self.cfg.name in {"pick_up", "place"}: + trajectory = torch.tensor( + [[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4]]], dtype=torch.float32 + ) + return ( + True, + trajectory, + [0, 1, 2] if "left" in self.cfg.control_part else [3, 4, 5], + ) + if self.cfg.control_part.endswith("eef"): + trajectory = torch.tensor([[[0.0], [0.05]]], dtype=torch.float32) + return True, trajectory, [2 if "left" in self.cfg.control_part else 5] + trajectory = torch.tensor([[[0.1, 0.2], [0.2, 0.3]]], dtype=torch.float32) + return True, trajectory, [0, 1] if "left" in self.cfg.control_part else [3, 4] + + +def test_normalize_atomic_action_spec_rejects_legacy_schema() -> None: + with pytest.raises(ValueError, match="Legacy action schema"): + normalize_atomic_action_spec({"action": "move", "robot_name": "left_arm"}) + + +def test_normalize_atomic_action_spec_rejects_legacy_target_kind_schema() -> None: + with pytest.raises(ValueError, match="Legacy target.kind schema"): + normalize_atomic_action_spec( + { + "atomic_action_class": "MoveAction", + "robot_name": "left_arm", + "control": "arm", + "target": {"kind": "pose_relative_to_object", "obj_name": "apple"}, + "cfg": {}, + } + ) + + +def test_normalize_atomic_action_spec_rejects_unknown_fields() -> None: + with pytest.raises(ValueError, match="Unsupported atomic action spec fields"): + normalize_atomic_action_spec( + { + "atomic_action_class": "MoveAction", + "robot_name": "left_arm", + "control": "arm", + "target_qpos": {"source": "initial"}, + "cfg": {}, + "description": "return home", + } + ) + + +def test_normalize_atomic_action_spec_rejects_multiple_target_fields() -> None: + with pytest.raises(ValueError, match="exactly one of target_object"): + normalize_atomic_action_spec( + { + "atomic_action_class": "MoveAction", + "robot_name": "left_arm", + "control": "arm", + "target_pose": { + "reference": "relative", + "offset": [0.0, 0.0, 0.1], + }, + "target_qpos": {"source": "initial"}, + "cfg": {}, + } + ) + + +def test_normalize_atomic_action_spec_rejects_orientation_field() -> None: + with pytest.raises(ValueError, match="Unsupported target_pose fields"): + normalize_atomic_action_spec( + { + "atomic_action_class": "MoveAction", + "robot_name": "left_arm", + "control": "arm", + "target_pose": { + "reference": "object", + "obj_name": "apple", + "offset": [0.0, 0.0, 0.1], + "orientation": "current", + }, + "cfg": {}, + } + ) + + +def test_normalize_atomic_action_spec_rejects_pickup_pose_target() -> None: + with pytest.raises(ValueError, match="PickUpAction requires control='arm'"): + normalize_atomic_action_spec( + { + "atomic_action_class": "PickUpAction", + "robot_name": "left_arm", + "control": "arm", + "target_pose": { + "reference": "relative", + "offset": [0.0, 0.0, 0.1], + }, + "cfg": {}, + } + ) + + +def test_atom_actions_module_exposes_atomic_runtime_entrypoints() -> None: + assert atom_actions.execute_atomic_action is execute_atomic_action + assert atom_actions.normalize_atomic_action_spec is normalize_atomic_action_spec + assert callable(atom_actions.execute_parallel_atomic_actions) + + +def test_object_referenced_pose_builds_move_cfg_and_pose_target(monkeypatch) -> None: + env = _FakeEnv() + capture = [] + _FakeBackendAction.capture = capture + + monkeypatch.setattr( + atom_actions, + "_make_motion_generator", + lambda env: SimpleNamespace(robot=env.robot, device=env.robot.device), + ) + monkeypatch.setattr( + atom_actions, + "_get_atomic_action_class", + lambda atomic_action_class: _FakeBackendAction, + ) + + action = execute_atomic_action( + { + "atomic_action_class": "MoveAction", + "robot_name": "left_arm", + "control": "arm", + "target_pose": { + "reference": "object", + "obj_name": "apple", + "offset": [0.1, 0.2, 0.3], + }, + "cfg": {"sample_interval": 12}, + }, + env=env, + ) + + assert action.shape == (2, 3) + assert isinstance(capture[0]["cfg"], MoveActionCfg) + assert capture[0]["cfg"].control_part == "left_arm" + assert capture[0]["cfg"].sample_interval == 12 + assert capture[0]["target"][:3, 3].tolist() == pytest.approx([0.5, 0.0, 0.4]) + + +def test_gripper_state_qpos_target_interpolates_hand_action(monkeypatch) -> None: + env = _FakeEnv() + capture = [] + _FakeBackendAction.capture = capture + + monkeypatch.setattr( + atom_actions, + "_make_motion_generator", + lambda env: SimpleNamespace(robot=env.robot, device=env.robot.device), + ) + monkeypatch.setattr( + atom_actions, + "_get_atomic_action_class", + lambda atomic_action_class: _FakeBackendAction, + ) + + action = execute_atomic_action( + { + "atomic_action_class": "MoveAction", + "robot_name": "left_arm", + "control": "hand", + "target_qpos": {"source": "gripper_state", "state": "open"}, + "cfg": {"sample_interval": 5, "post_hold_steps": 2}, + }, + env=env, + ) + + assert action.shape == (7, 3) + assert capture == [] + assert action[0].tolist() == pytest.approx([0.1, 0.2, 0.0]) + assert action[4].tolist() == pytest.approx([0.1, 0.2, 0.05]) + assert action[-1].tolist() == pytest.approx([0.1, 0.2, 0.05]) + assert env.left_arm_current_gripper_state.tolist() == pytest.approx([0.05]) + + +def test_initial_qpos_target_interpolates_arm_action(monkeypatch) -> None: + env = _FakeEnv() + capture = [] + _FakeBackendAction.capture = capture + + monkeypatch.setattr( + atom_actions, + "_make_motion_generator", + lambda env: SimpleNamespace(robot=env.robot, device=env.robot.device), + ) + monkeypatch.setattr( + atom_actions, + "_get_atomic_action_class", + lambda atomic_action_class: _FakeBackendAction, + ) + + action = execute_atomic_action( + { + "atomic_action_class": "MoveAction", + "robot_name": "right_arm", + "control": "arm", + "target_qpos": {"source": "initial"}, + "cfg": {"sample_interval": 4}, + }, + env=env, + ) + + assert action.shape == (4, 3) + assert capture == [] + assert action[0].tolist() == pytest.approx([0.3, 0.4, 0.0]) + assert action[-1].tolist() == pytest.approx([-0.3, -0.4, 0.0]) + assert env.right_arm_current_qpos.tolist() == pytest.approx([-0.3, -0.4]) + + +def test_target_object_builds_pick_up_cfg(monkeypatch) -> None: + env = _FakeEnv() + capture = [] + _FakeBackendAction.capture = capture + + monkeypatch.setattr( + atom_actions, + "_make_motion_generator", + lambda env: SimpleNamespace(robot=env.robot, device=env.robot.device), + ) + monkeypatch.setattr( + atom_actions, + "_get_atomic_action_class", + lambda atomic_action_class: _FakeBackendAction, + ) + + execute_atomic_action( + { + "atomic_action_class": "PickUpAction", + "robot_name": "left_arm", + "control": "arm", + "target_object": { + "obj_name": "apple", + "affordance": "antipodal", + }, + "cfg": { + "pre_grasp_distance": 0.07, + "sample_interval": 11, + }, + }, + env=env, + allow_grasp_annotation=True, + ) + + assert isinstance(capture[0]["cfg"], PickUpActionCfg) + assert capture[0]["cfg"].control_part == "left_arm" + assert capture[0]["cfg"].hand_control_part == "left_eef" + assert capture[0]["cfg"].pre_grasp_distance == pytest.approx(0.07) + assert capture[0]["target"].label == "apple" + + +def test_place_action_builds_place_cfg(monkeypatch) -> None: + env = _FakeEnv() + capture = [] + _FakeBackendAction.capture = capture + + monkeypatch.setattr( + atom_actions, + "_make_motion_generator", + lambda env: SimpleNamespace(robot=env.robot, device=env.robot.device), + ) + monkeypatch.setattr( + atom_actions, + "_get_atomic_action_class", + lambda atomic_action_class: _FakeBackendAction, + ) + + action = execute_atomic_action( + { + "atomic_action_class": "PlaceAction", + "robot_name": "left_arm", + "control": "arm", + "target_pose": { + "reference": "relative", + "offset": [0.0, 0.0, 0.1], + "frame": "world", + }, + "cfg": {"sample_interval": 19, "lift_height": 0.06}, + }, + env=env, + ) + + assert action.shape == (2, 3) + assert isinstance(capture[0]["cfg"], PlaceActionCfg) + assert capture[0]["cfg"].control_part == "left_arm" + assert capture[0]["cfg"].lift_height == pytest.approx(0.06) + + +def test_place_action_rejects_qpos_target(monkeypatch) -> None: + env = _FakeEnv() + monkeypatch.setattr( + atom_actions, + "_make_motion_generator", + lambda env: SimpleNamespace(robot=env.robot, device=env.robot.device), + ) + monkeypatch.setattr( + atom_actions, + "_get_atomic_action_class", + lambda atomic_action_class: _FakeBackendAction, + ) + + with pytest.raises( + ValueError, + match="PlaceAction requires control='arm' and target_pose", + ): + execute_atomic_action( + { + "atomic_action_class": "PlaceAction", + "robot_name": "left_arm", + "control": "arm", + "target_qpos": {"source": "initial"}, + "cfg": {"sample_interval": 20}, + }, + env=env, + ) diff --git a/tests/gen_sim/action_agent_pipeline/test_coacd_cache.py b/tests/gen_sim/action_agent_pipeline/test_coacd_cache.py new file mode 100644 index 00000000..8a974071 --- /dev/null +++ b/tests/gen_sim/action_agent_pipeline/test_coacd_cache.py @@ -0,0 +1,130 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import hashlib +import pickle + +import pytest +import torch + +from embodichain.gen_sim.action_agent_pipeline.generation.coacd_cache import ( + coacd_cache_path_for_mesh, + dexsim_coacd_cache_key_for_mesh, +) +from embodichain.gen_sim.action_agent_pipeline.runtime.coacd_cache_bridge import ( + ensure_grasp_collision_cache_from_env_coacd, +) + + +def test_coacd_cache_path_matches_dexsim_load_actor_key(tmp_path) -> None: + mesh_path = tmp_path / "object.obj" + mesh_path.write_text("# placeholder mesh\n", encoding="utf-8") + cache_dir = tmp_path / "cache" + + cache_path = coacd_cache_path_for_mesh( + mesh_path, + 16, + cache_dir, + ) + + expected_key = hashlib.md5( + f"{mesh_path.resolve()}|mesh_count=1".encode("utf-8") + ).hexdigest() + assert dexsim_coacd_cache_key_for_mesh(mesh_path) == expected_key + assert cache_path == cache_dir.resolve() / f"{expected_key}_16.obj" + + +def test_grasp_cache_bridge_uses_existing_env_coacd_obj(tmp_path) -> None: + pytest.importorskip("dexsim.kit.meshproc.convex_cache") + source_mesh_path = tmp_path / "source.obj" + _write_tetra_obj(source_mesh_path) + + cache_dir = tmp_path / "cache" + env_cache_path = coacd_cache_path_for_mesh( + source_mesh_path, + 4, + cache_dir, + ) + env_cache_path.parent.mkdir(parents=True, exist_ok=True) + _write_tetra_obj(env_cache_path) + + mesh_vertices = torch.tensor( + [ + [0.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 2.0], + ], + dtype=torch.float32, + ) + mesh_triangles = torch.tensor( + [ + [0, 2, 1], + [0, 1, 3], + [1, 2, 3], + [2, 0, 3], + ], + dtype=torch.int64, + ) + + result = ensure_grasp_collision_cache_from_env_coacd( + mesh_vertices=mesh_vertices, + mesh_triangles=mesh_triangles, + source_mesh_path=source_mesh_path, + max_decomposition_hulls=4, + body_scale=[2.0, 2.0, 2.0], + cache_dir=cache_dir, + ) + + assert result["status"] == "generated" + assert result["env_cache_path"] == env_cache_path.as_posix() + with open(result["grasp_cache_path"], "rb") as cache_file: + cache = pickle.load(cache_file) + assert set(cache) == {"plane_equations", "plane_equation_counts"} + assert cache["plane_equations"].shape[-1] == 4 + assert cache["plane_equation_counts"].numel() == 1 + + second_result = ensure_grasp_collision_cache_from_env_coacd( + mesh_vertices=mesh_vertices, + mesh_triangles=mesh_triangles, + source_mesh_path=source_mesh_path, + max_decomposition_hulls=4, + body_scale=[2.0, 2.0, 2.0], + cache_dir=cache_dir, + ) + assert second_result["status"] == "hit" + + +def _write_tetra_obj(path) -> None: + path.write_text( + "\n".join( + [ + "o convex_0", + "v 0.0 0.0 0.0", + "v 1.0 0.0 0.0", + "v 0.0 1.0 0.0", + "v 0.0 0.0 1.0", + "f 1 3 2", + "f 1 2 4", + "f 2 3 4", + "f 3 1 4", + "", + ] + ), + encoding="utf-8", + ) diff --git a/tests/gen_sim/action_agent_pipeline/test_demo3_semantic_grasp_integration.py b/tests/gen_sim/action_agent_pipeline/test_demo3_semantic_grasp_integration.py new file mode 100644 index 00000000..1c55ef74 --- /dev/null +++ b/tests/gen_sim/action_agent_pipeline/test_demo3_semantic_grasp_integration.py @@ -0,0 +1,262 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import argparse +import json +import os +import re +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + import torch + +pytestmark = pytest.mark.skipif( + os.environ.get("RUN_DEXSIM_GRASP_TESTS") != "1", + reason="Set RUN_DEXSIM_GRASP_TESTS=1 to run DexSim semantic grasp integration tests.", +) + + +_REPO_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_DEMO3_CONFIG_DIR = ( + _REPO_ROOT / "gym_project/action_agent_pipeline/configs/demo3_text" +) +_DEMO3_CONFIG_DIR = ( + Path( + os.environ.get( + "RUN_DEXSIM_GRASP_CONFIG_DIR", + str(_DEFAULT_DEMO3_CONFIG_DIR), + ) + ) + .expanduser() + .resolve() +) +_MIN_LIFT_M = float(os.environ.get("RUN_DEXSIM_GRASP_MIN_LIFT_M", "0.04")) +_MAX_EEF_DISTANCE_M = float( + os.environ.get("RUN_DEXSIM_GRASP_MAX_EEF_DISTANCE_M", "0.25") +) +_POST_GRASP_HOLD_STEPS = int(os.environ.get("RUN_DEXSIM_GRASP_HOLD_STEPS", "10")) +_PICK_UP_SPEC_RE = re.compile( + r'"atomic_action_class"\s*:\s*"PickUpAction".*?' + r'"robot_name"\s*:\s*"(?P[^"]+)".*?' + r'"obj_name"\s*:\s*"(?P[^"]+)"', + re.DOTALL, +) + + +def _load_demo3_gym_config() -> dict: + return json.loads( + (_DEMO3_CONFIG_DIR / "fast_gym_config.json").read_text(encoding="utf-8") + ) + + +def _configured_rigid_object_uids() -> set[str]: + return { + rigid_object["uid"] + for rigid_object in _load_demo3_gym_config().get("rigid_object", []) + } + + +def _configured_grasp_targets() -> list[tuple[str, str]]: + atom_actions_text = (_DEMO3_CONFIG_DIR / "atom_actions.txt").read_text( + encoding="utf-8" + ) + targets = [ + (match.group("robot_name"), match.group("obj_name")) + for match in _PICK_UP_SPEC_RE.finditer(atom_actions_text) + ] + rigid_object_uids = _configured_rigid_object_uids() + stale_targets = [ + (robot_name, obj_name) + for robot_name, obj_name in targets + if obj_name not in rigid_object_uids + ] + assert not stale_targets, ( + "atom_actions.txt references pick_up objects that are not present in " + f"fast_gym_config.json: stale_targets={stale_targets}, " + f"rigid_object_uids={sorted(rigid_object_uids)}." + ) + return targets + + +def _configured_grasp_target_for(*keywords: str) -> tuple[str, str]: + lower_keywords = tuple(keyword.lower() for keyword in keywords) + matches = [ + (robot_name, obj_name) + for robot_name, obj_name in _configured_grasp_targets() + if all(keyword in obj_name.lower() for keyword in lower_keywords) + ] + assert matches, ( + f"No configured grasp target matching keywords={keywords}. " + f"grasp_targets={_configured_grasp_targets()}." + ) + assert ( + len(matches) == 1 + ), f"Ambiguous grasp target matching keywords={keywords}: {matches}." + return matches[0] + + +def _write_runtime_gym_config(tmp_path: Path) -> Path: + gym_config = _load_demo3_gym_config() + gym_config["env"]["events"] = {} + gym_config["env"]["dataset"] = {} + gym_config["sensor"] = [] + + runtime_config_path = tmp_path / "demo3_semantic_grasp_gym_config.json" + runtime_config_path.write_text( + json.dumps(gym_config, indent=2), + encoding="utf-8", + ) + return runtime_config_path + + +def _make_env(tmp_path: Path): + import gymnasium + + from embodichain.lab.gym.utils.gym_utils import build_env_cfg_from_args + from embodichain.utils.utility import load_config + + # Import registers AtomicActionsAgent-v3. + from embodichain.gen_sim.action_agent_pipeline.env_adapters.tableware import ( # noqa: F401 + agent_env, + ) + + args = argparse.Namespace( + num_envs=1, + device=os.environ.get("RUN_DEXSIM_GRASP_DEVICE", "cpu"), + headless=True, + renderer=os.environ.get("RUN_DEXSIM_GRASP_RENDERER", "hybrid"), + arena_space=float(os.environ.get("RUN_DEXSIM_GRASP_ARENA_SPACE", "5.0")), + gpu_id=int(os.environ.get("RUN_DEXSIM_GRASP_GPU_ID", "0")), + gym_config=str(_write_runtime_gym_config(tmp_path)), + action_config=None, + preview=False, + filter_visual_rand=True, + filter_dataset_saving=True, + ) + env_cfg, gym_config, _ = build_env_cfg_from_args(args) + agent_config_path = _DEMO3_CONFIG_DIR / "agent_config.json" + return gymnasium.make( + id=gym_config["id"], + cfg=env_cfg, + agent_config=load_config(agent_config_path), + agent_config_path=str(agent_config_path), + task_name="Demo3_Text", + ) + + +def _object_xyz(env, obj_name: str) -> torch.Tensor: + pose = env.sim.get_rigid_object(obj_name).get_local_pose(to_matrix=True).squeeze(0) + return pose[:3, 3].detach().cpu() + + +def _arm_eef_xyz(env, robot_name: str) -> torch.Tensor: + left_pose, right_pose = env.get_current_xpos_agent() + pose = left_pose if "left" in robot_name else right_pose + return pose[:3, 3].detach().cpu() + + +def _hold_last_action(env, actions: list, steps: int) -> None: + if steps <= 0 or not actions: + return + last_action = actions[-1] + for _ in range(steps): + env.step(last_action) + + +def _assert_semantic_grasp_lifts_object( + tmp_path: Path, + *, + robot_name: str, + obj_name: str, +) -> None: + import torch + + from embodichain.gen_sim.action_agent_pipeline.runtime.atom_actions import ( + execute_parallel_atomic_actions, + ) + + gym_env = _make_env(tmp_path) + env = gym_env.unwrapped + try: + gym_env.reset() + z_before = float(_object_xyz(env, obj_name)[2]) + action_spec = { + "atomic_action_class": "PickUpAction", + "robot_name": robot_name, + "control": "arm", + "target_object": { + "obj_name": obj_name, + "affordance": "antipodal", + }, + "cfg": { + "pre_grasp_distance": 0.08, + "lift_height": 0.14, + "sample_interval": 80, + }, + } + result = execute_parallel_atomic_actions( + left_arm_action=action_spec if "left" in robot_name else None, + right_arm_action=action_spec if "right" in robot_name else None, + env=env, + return_result=True, + allow_grasp_annotation=True, + force_grasp_reannotate=bool( + int(os.environ.get("RUN_DEXSIM_GRASP_FORCE_REANNOTATE", "0")) + ), + ) + _hold_last_action(env, result["actions"], _POST_GRASP_HOLD_STEPS) + + obj_xyz = _object_xyz(env, obj_name) + eef_xyz = _arm_eef_xyz(env, robot_name) + lift = float(obj_xyz[2] - z_before) + eef_distance = float(torch.linalg.norm(obj_xyz - eef_xyz)) + + assert lift >= _MIN_LIFT_M, ( + f"{obj_name} semantic grasp did not lift enough: lift={lift:.4f}m, " + f"required={_MIN_LIFT_M:.4f}m, obj_xyz={obj_xyz.tolist()}, " + f"eef_xyz={eef_xyz.tolist()}." + ) + assert eef_distance <= _MAX_EEF_DISTANCE_M, ( + f"{obj_name} is too far from {robot_name} after grasp: " + f"distance={eef_distance:.4f}m, " + f"required<={_MAX_EEF_DISTANCE_M:.4f}m, " + f"obj_xyz={obj_xyz.tolist()}, eef_xyz={eef_xyz.tolist()}." + ) + finally: + gym_env.close() + + +def test_demo3_semantic_grasp_lifts_orange(tmp_path: Path) -> None: + robot_name, obj_name = _configured_grasp_target_for("orange", "1") + _assert_semantic_grasp_lifts_object( + tmp_path, + robot_name=robot_name, + obj_name=obj_name, + ) + + +def test_demo3_semantic_grasp_lifts_can(tmp_path: Path) -> None: + robot_name, obj_name = _configured_grasp_target_for("can") + _assert_semantic_grasp_lifts_object( + tmp_path, + robot_name=robot_name, + obj_name=obj_name, + ) diff --git a/tests/gen_sim/action_agent_pipeline/test_graph_spec_backend_atomic.py b/tests/gen_sim/action_agent_pipeline/test_graph_spec_backend_atomic.py new file mode 100644 index 00000000..122b188b --- /dev/null +++ b/tests/gen_sim/action_agent_pipeline/test_graph_spec_backend_atomic.py @@ -0,0 +1,121 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import pytest + +from embodichain.gen_sim.action_agent_pipeline.runtime.graph_compiler import ( + compile_agent_graph_spec, +) + + +class _FakeGraph: + def __init__(self, start: str, goal: str, max_transitions: int = 1000) -> None: + self.start = start + self.goal = goal + self.max_transitions = max_transitions + self.nodes = {} + self.edges = {} + + def add_node(self, node_id: str, semantic: str = ""): + self.nodes[node_id] = semantic + return self + + def add_edge( + self, + edge_id: str, + source: str, + target: str, + *, + left_arm_action=None, + right_arm_action=None, + ): + self.edges[edge_id] = { + "source": source, + "target": target, + "left_arm_action": left_arm_action, + "right_arm_action": right_arm_action, + } + return self + + +def _pick_up_spec(robot_name: str, obj_name: str) -> dict: + return { + "atomic_action_class": "PickUpAction", + "robot_name": robot_name, + "control": "arm", + "target_object": { + "obj_name": obj_name, + "affordance": "antipodal", + }, + "cfg": { + "pre_grasp_distance": 0.08, + "sample_interval": 45, + }, + } + + +def _task_graph(action: dict) -> dict: + return { + "task": "unit", + "start": "v0_start", + "goal": "v1_done", + "nodes": [ + {"id": "v0_start"}, + {"id": "v1_done"}, + ], + "edges": [ + { + "id": "e01", + "source": "v0_start", + "target": "v1_done", + "left_arm_action": action, + "right_arm_action": None, + } + ], + } + + +def test_compile_agent_graph_accepts_atomic_action_class_spec() -> None: + action = _pick_up_spec("left_arm", "apple") + graph = compile_agent_graph_spec( + _task_graph(action), + graph_cls=_FakeGraph, + ) + + assert graph.edges["e01"]["left_arm_action"] == action + + +def test_compile_agent_graph_rejects_legacy_action_schema() -> None: + task_graph = _task_graph({"action": "pick_up", "robot_name": "left_arm"}) + + with pytest.raises(ValueError, match="Legacy action schema"): + compile_agent_graph_spec( + task_graph, + graph_cls=_FakeGraph, + ) + + +def test_compile_agent_graph_rejects_extra_edge_fields() -> None: + task_graph = _task_graph(_pick_up_spec("left_arm", "apple")) + task_graph["edges"][0]["monitor"] = {"condition": "object visible"} + + with pytest.raises(ValueError, match="unsupported fields: monitor"): + compile_agent_graph_spec( + task_graph, + graph_cls=_FakeGraph, + ) diff --git a/tests/gen_sim/action_agent_pipeline/test_llm_usage.py b/tests/gen_sim/action_agent_pipeline/test_llm_usage.py new file mode 100644 index 00000000..8c85123d --- /dev/null +++ b/tests/gen_sim/action_agent_pipeline/test_llm_usage.py @@ -0,0 +1,161 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json + +import pytest + +from embodichain.gen_sim.action_agent_pipeline.utils.llm_usage import ( + LLM_USAGE_PATH_ENV, + UsageTrackedChatModel, + build_usage_summary, + configure_usage_tracking, + disable_usage_tracking, + normalize_usage, + record_langchain_usage, + scrub_usage_tracking_env, +) + + +class _FakeLangChainResponse: + usage_metadata = { + "input_tokens": 10, + "output_tokens": 4, + "total_tokens": 14, + "input_token_details": {"cache_read": 3}, + "output_token_details": {"reasoning": 2}, + } + response_metadata = { + "model_name": "gpt-test", + "id": "chatcmpl-test", + "finish_reason": "stop", + } + content = "{}" + + +class _FakeChatModel: + model_name = "gpt-test" + + def __init__(self) -> None: + self.inputs = [] + + def invoke(self, value): + self.inputs.append(value) + return _FakeLangChainResponse() + + +@pytest.fixture(autouse=True) +def _clear_usage_env(): + disable_usage_tracking() + yield + disable_usage_tracking() + + +def test_normalize_usage_handles_openai_and_langchain_shapes(): + openai_usage = { + "prompt_tokens": 11, + "completion_tokens": 5, + "total_tokens": 16, + "prompt_tokens_details": {"cached_tokens": 7}, + "completion_tokens_details": {"reasoning_tokens": 2}, + } + assert normalize_usage(openai_usage) == { + "input_tokens": 11, + "output_tokens": 5, + "total_tokens": 16, + "cached_tokens": 7, + "reasoning_tokens": 2, + } + + langchain_usage = { + "input_tokens": 10, + "output_tokens": 4, + "input_token_details": {"cache_read": 3}, + "output_token_details": {"reasoning": 2}, + } + assert normalize_usage(langchain_usage) == { + "input_tokens": 10, + "output_tokens": 4, + "total_tokens": 14, + "cached_tokens": 3, + "reasoning_tokens": 2, + } + + +def test_record_langchain_usage_writes_jsonl_and_summary(tmp_path): + usage_path = tmp_path / "llm_usage.jsonl" + configure_usage_tracking( + usage_path=usage_path, + run_id="test-run", + process_name="pytest", + reset=True, + ) + + record_langchain_usage( + _FakeLangChainResponse(), + stage="Action Agent Task Graph", + model="fallback-model", + ) + + records = [ + json.loads(line) for line in usage_path.read_text(encoding="utf-8").splitlines() + ] + assert len(records) == 1 + assert records[0]["stage"] == "action_agent_task_graph" + assert records[0]["model"] == "gpt-test" + assert records[0]["input_tokens"] == 10 + assert records[0]["output_tokens"] == 4 + + summary = build_usage_summary(usage_path) + assert summary["total"]["calls"] == 1 + assert summary["total"]["total_tokens"] == 14 + assert summary["by_stage"]["action_agent_task_graph"]["cached_tokens"] == 3 + + +def test_usage_tracked_chat_model_records_invoke(tmp_path): + usage_path = tmp_path / "llm_usage.jsonl" + configure_usage_tracking( + usage_path=usage_path, + run_id="test-run", + process_name="pytest", + reset=True, + ) + inner = _FakeChatModel() + wrapped = UsageTrackedChatModel(inner, stage="action_agent.task_graph") + + response = wrapped.invoke("hello") + + assert response.content == "{}" + assert inner.inputs == ["hello"] + record = json.loads(usage_path.read_text(encoding="utf-8").splitlines()[0]) + assert record["stage"] == "action_agent.task_graph" + assert record["total_tokens"] == 14 + + +def test_scrub_usage_tracking_env_removes_usage_keys(tmp_path): + usage_path = tmp_path / "llm_usage.jsonl" + configure_usage_tracking( + usage_path=usage_path, + run_id="test-run", + process_name="pytest", + reset=True, + ) + + cleaned = scrub_usage_tracking_env() + + assert LLM_USAGE_PATH_ENV not in cleaned diff --git a/tests/gen_sim/action_agent_pipeline/test_ur5_basket_config_generation.py b/tests/gen_sim/action_agent_pipeline/test_ur5_basket_config_generation.py new file mode 100644 index 00000000..fc475160 --- /dev/null +++ b/tests/gen_sim/action_agent_pipeline/test_ur5_basket_config_generation.py @@ -0,0 +1,1866 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +import base64 +import hashlib +import json +import struct + +import pytest +import torch + +from embodichain.gen_sim.action_agent_pipeline.generation import ( + ur5_basket_config as ur5_basket_config_generation, +) +from embodichain.gen_sim.action_agent_pipeline.cli import ( + run_agent_pipeline as run_agent_pipeline_cli, +) +from embodichain.gen_sim.action_agent_pipeline.generation.mesh_frame_normalization import ( + MESH_FRAME_NORMALIZATION_POLICY_VERSION, + MeshFrameNormalizer, +) +from embodichain.gen_sim.action_agent_pipeline.generation.ur5_basket_config import ( + TargetReplacementSpec, + generate_ur5_basket_config_from_project, +) +from embodichain.gen_sim.action_agent_pipeline.env_adapters.tableware.success import ( + evaluate_configured_success, +) + + +def test_ur5_basket_generator_uses_parallel_handoff( + tmp_path: Path, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_agent", + target_body_scale=0.6, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + + assert set(rigid_objects) == {"left_apple", "right_apple"} + assert rigid_objects["left_apple"]["body_scale"] == [0.6, 0.6, 0.6] + assert rigid_objects["right_apple"]["body_scale"] == [0.6, 0.6, 0.6] + assert rigid_objects["left_apple"]["body_type"] == "dynamic" + assert rigid_objects["right_apple"]["body_type"] == "dynamic" + assert background_objects["table"]["body_scale"] == [1.0, 1.0, 1.0] + assert background_objects["wicker_basket"]["body_scale"] == [1.0, 1.0, 1.0] + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + _assert_normalized_obj_path(rigid_objects["left_apple"]["shape"]["fpath"]) + _assert_normalized_obj_path(rigid_objects["right_apple"]["shape"]["fpath"]) + _assert_normalized_obj_path(background_objects["table"]["shape"]["fpath"]) + _assert_normalized_obj_path(background_objects["wicker_basket"]["shape"]["fpath"]) + table_top_z = ur5_basket_config_generation._mesh_config_world_zmax( + background_objects["table"] + ) + expected_robot_init_z = ( + table_top_z + + ur5_basket_config_generation._DUAL_UR5_TABLETOP_CLEARANCE + - ur5_basket_config_generation._DUAL_UR5_ARM_COMPONENT_Z + ) + assert gym_config["robot"]["init_pos"] == pytest.approx( + [2.0, 0.0, expected_robot_init_z] + ) + assert gym_config["robot"]["init_rot"] == [0.0, 0.0, -90.0] + extensions = gym_config["env"]["extensions"] + assert extensions["agent_arm_slots"]["left"] == { + "arm": "right_arm", + "eef": "right_eef", + } + assert extensions["agent_arm_slots"]["right"] == { + "arm": "left_arm", + "eef": "left_eef", + } + assert extensions["arm_aim_yaw_offset"]["left"] == pytest.approx(3.141592653589793) + assert extensions["arm_aim_yaw_offset"]["right"] == pytest.approx(0.0) + + success_terms = gym_config["env"]["extensions"]["agent_success"]["terms"] + assert {term["object"] for term in success_terms} == {"left_apple", "right_apple"} + assert {term["container"] for term in success_terms} == {"wicker_basket"} + + registry = gym_config["env"]["events"]["register_info_to_env"]["params"]["registry"] + registered_uids = {entry["entity_cfg"]["uid"] for entry in registry} + assert registered_uids == {"left_apple", "right_apple", "wicker_basket"} + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + basic_background = paths.basic_background.read_text(encoding="utf-8") + atom_actions = paths.atom_actions.read_text(encoding="utf-8") + normalized_task_prompt = " ".join(task_prompt.split()) + + assert "Generate exactly 6 nominal edges" in normalized_task_prompt + assert "Generate exactly 10 nominal edges" not in normalized_task_prompt + assert "positive-y side" in basic_background + assert "negative-y side" in basic_background + assert "negative-x side" not in basic_background + assert "positive-x side" not in basic_background + left_high_offset_spec = ( + '"robot_name":"left_arm","control":"arm","target_pose":{"reference":"object",' + '"obj_name":"wicker_basket","offset":[0.0,-0.04,0.22]' + ) + right_high_offset_spec = ( + '"robot_name":"right_arm","control":"arm","target_pose":{"reference":"object",' + '"obj_name":"wicker_basket","offset":[0.0,0.04,0.22]' + ) + assert left_high_offset_spec in task_prompt + assert right_high_offset_spec in task_prompt + assert ( + '"atomic_action_class":"PlaceAction","robot_name":"left_arm","control":"arm",' + '"target_pose":{"reference":"object","obj_name":"wicker_basket",' + '"offset":[0.0,-0.04,0.12]}' in task_prompt + ) + assert ( + '"atomic_action_class":"PlaceAction","robot_name":"right_arm","control":"arm",' + '"target_pose":{"reference":"object","obj_name":"wicker_basket",' + '"offset":[0.0,0.04,0.12]}' in task_prompt + ) + assert '"offset":[-0.04,0.0,0.22]' not in task_prompt + assert '"offset":[0.04,0.0,0.22]' not in task_prompt + assert left_high_offset_spec in atom_actions + assert right_high_offset_spec in atom_actions + assert "parallel handoff" in task_prompt + assert "parallel handoff" in basic_background + assert "parallel handoff" in atom_actions + assert len(paths.summary["normalized_meshes"]) == 4 + + handoff_edge = task_prompt.split("4. After the left gripper", maxsplit=1)[1].split( + "\n5. Place the held right target object", + maxsplit=1, + )[0] + assert ( + '"robot_name":"left_arm","control":"arm","target_qpos":{"source":"initial"}' + in handoff_edge + ) + assert ( + '"robot_name":"right_arm","control":"arm","target_pose":{"reference":"object"' + in handoff_edge + ) + assert '"state":"close"' not in handoff_edge + assert "left_arm_action: null" not in handoff_edge + assert paths.summary["mode"] == "basket_template" + + +def test_generator_normalizes_glb_meshes_and_preserves_source_rot( + tmp_path: Path, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_agent", + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + + assert background_objects["table"]["init_rot"] == [0.0, 0.0, 180.0] + assert background_objects["wicker_basket"]["init_rot"] == [0.0, 0.0, 180.0] + assert rigid_objects["right_apple"]["init_rot"] == [0.0, 0.0, 140.0] + assert rigid_objects["left_apple"]["init_rot"] == [0.0, 0.0, 160.0] + for obj_config in [ + background_objects["table"], + background_objects["wicker_basket"], + rigid_objects["right_apple"], + rigid_objects["left_apple"], + ]: + _assert_normalized_obj_path(obj_config["shape"]["fpath"]) + + source_paths = { + Path(entry["source_path"]).name for entry in paths.summary["normalized_meshes"] + } + assert source_paths == { + "table_0.glb", + "basket_3.glb", + "apple_1.glb", + "apple_2.glb", + } + + +def test_mesh_frame_normalizer_bakes_glb_scene_transform_to_obj( + tmp_path: Path, +) -> None: + mesh_path = tmp_path / "source" / "triangle.glb" + _write_minimal_glb( + mesh_path, + [(0.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)], + node_translation=(1.0, 0.0, 0.0), + ) + source_sha256 = hashlib.sha256(mesh_path.read_bytes()).hexdigest() + normalizer = MeshFrameNormalizer(output_dir=tmp_path / "normalized") + + normalized_path = normalizer.normalize_path(mesh_path) + repeated_path = normalizer.normalize_path(mesh_path) + + assert repeated_path == normalized_path + assert normalized_path.suffix == ".obj" + assert MESH_FRAME_NORMALIZATION_POLICY_VERSION not in normalized_path.name + assert len(normalized_path.name) <= 64 + obj_text = normalized_path.read_text(encoding="utf-8") + assert f"policy_version: {MESH_FRAME_NORMALIZATION_POLICY_VERSION}" in obj_text + assert f"source_sha256: {source_sha256}" in obj_text + assert "dexsim_engine_version:" in obj_text + assert ( + "transform: [[1.0,0.0,0.0,0.0],[0.0,1.0,0.0,0.0]," + "[0.0,0.0,1.0,0.0],[0.0,0.0,0.0,1.0]]" + ) in obj_text + assert "mtllib material.mtl" in obj_text + material_text = (normalized_path.parent / "material.mtl").read_text( + encoding="utf-8" + ) + material_name = _single_obj_material_name(obj_text) + assert material_name != "material_0" + assert f"newmtl {material_name}" in material_text + assert "map_Kd " not in material_text + assert _rounded_vertex_set(_obj_vertices(normalized_path)) == { + (1.0, 0.0, 0.0), + (1.0, 1.0, 0.0), + (1.0, 0.0, 1.0), + } + + +def test_mesh_frame_normalizer_extracts_embedded_base_color_texture( + tmp_path: Path, +) -> None: + mesh_path = tmp_path / "source" / "textured_triangle.glb" + texture_png = _tiny_png() + _write_minimal_glb( + mesh_path, + _default_mesh_vertices(), + embedded_base_color_png=texture_png, + ) + output_dir = tmp_path / "normalized" + + normalized_path = MeshFrameNormalizer(output_dir=output_dir).normalize_path( + mesh_path + ) + + obj_text = normalized_path.read_text(encoding="utf-8") + material_name = _single_obj_material_name(obj_text) + material_text = (output_dir / "material.mtl").read_text(encoding="utf-8") + assert f"newmtl {material_name}" in material_text + assert "Kd 1.0 1.0 1.0" in material_text + map_kd = _single_map_kd_path(material_text, material_name) + assert map_kd.startswith("textures/") + assert map_kd.endswith("_basecolor.png") + assert (output_dir / map_kd).read_bytes() == texture_png + + material_path = output_dir / "material.mtl" + texture_path = output_dir / map_kd + material_path.unlink() + texture_path.unlink() + + reused_path = MeshFrameNormalizer(output_dir=output_dir).normalize_path(mesh_path) + + assert reused_path == normalized_path + assert material_path.is_file() + assert texture_path.read_bytes() == texture_png + + +def test_mesh_frame_normalizer_recreates_material_library_for_reused_obj( + tmp_path: Path, +) -> None: + mesh_path = tmp_path / "source" / "triangle.glb" + _write_minimal_glb(mesh_path, _default_mesh_vertices()) + output_dir = tmp_path / "normalized" + normalized_path = MeshFrameNormalizer(output_dir=output_dir).normalize_path( + mesh_path + ) + material_path = normalized_path.parent / "material.mtl" + material_path.unlink() + + reused_path = MeshFrameNormalizer(output_dir=output_dir).normalize_path(mesh_path) + + assert reused_path == normalized_path + assert material_path.is_file() + material_text = material_path.read_text(encoding="utf-8") + reused_material_name = _single_obj_material_name( + reused_path.read_text(encoding="utf-8") + ) + assert f"newmtl {reused_material_name}" in material_text + + +def test_target_replacements_generate_meshes_and_replace_paths( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + calls = _patch_prompt2geometry(monkeypatch) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_agent", + target_replacements=[ + TargetReplacementSpec("apple_1", "A orange", "new1"), + TargetReplacementSpec("apple_2", "A apple", "new2"), + ], + ) + + assert calls == [ + ("A orange", project_dir / "mesh_assets" / "new1", "orange.glb"), + ("A apple", project_dir / "mesh_assets" / "new2", "apple.glb"), + ] + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + + assert set(rigid_objects) == {"left_apple", "right_apple"} + assert "wicker_basket" in background_objects + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + _assert_normalized_obj_path(rigid_objects["right_apple"]["shape"]["fpath"]) + _assert_normalized_obj_path(rigid_objects["left_apple"]["shape"]["fpath"]) + normalized_sources = { + Path(entry["source_path"]).as_posix() + for entry in paths.summary["normalized_meshes"] + } + assert ( + project_dir / "mesh_assets" / "new1" / "orange.glb" + ).as_posix() in normalized_sources + assert ( + project_dir / "mesh_assets" / "new2" / "apple.glb" + ).as_posix() in normalized_sources + assert paths.summary["target_replacements"][0]["source_uid"] == "apple_1" + assert paths.summary["target_replacements"][1]["source_uid"] == "apple_2" + + +def test_target_replacements_can_sync_runtime_names( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + _patch_prompt2geometry(monkeypatch) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_agent", + target_replacements=[ + TargetReplacementSpec("apple_2", "A orange", "new1"), + TargetReplacementSpec("apple_1", "A apple", "new2"), + ], + sync_replacement_names=True, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + + assert set(rigid_objects) == {"left_orange", "right_apple"} + assert "wicker_basket" in background_objects + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + _assert_normalized_obj_path(rigid_objects["left_orange"]["shape"]["fpath"]) + _assert_normalized_obj_path(rigid_objects["right_apple"]["shape"]["fpath"]) + + success_terms = gym_config["env"]["extensions"]["agent_success"]["terms"] + assert {term["object"] for term in success_terms} == { + "left_orange", + "right_apple", + } + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + basic_background = paths.basic_background.read_text(encoding="utf-8") + assert "the left orange and right apple into the wicker_basket" in task_prompt + assert "left_arm must only manipulate `left_orange`" in task_prompt + assert "- left_orange: the orange mesh initially" in basic_background + assert "- right_apple: the apple mesh initially" in basic_background + + +def test_pipeline_auto_replacement_uses_rotated_robot_view_order() -> None: + gym_config = { + "rigid_object": [ + {"uid": "bread_1", "init_pos": [0.0, 0.2, 0.76]}, + {"uid": "bread_2", "init_pos": [0.0, -0.1, 0.76]}, + ], + } + + assert ( + run_agent_pipeline_cli._auto_replacement_source_uid( + gym_config, + replacement_number=1, + option_name="--target_replacement1", + ) + == "bread_2" + ) + assert ( + run_agent_pipeline_cli._auto_replacement_source_uid( + gym_config, + replacement_number=2, + option_name="--target_replacement2", + ) + == "bread_1" + ) + + +def test_directory_input_prefers_merged_config_and_preserves_extra_scene_scale( + tmp_path: Path, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + background_mesh = project_dir / "mesh_assets/backgrounds/vase_0.glb" + _write_minimal_glb(background_mesh, _default_mesh_vertices()) + + merged_config_path = project_dir / "gym_config_merged.json" + source_config = json.loads( + (project_dir / "gym_config.json").read_text(encoding="utf-8") + ) + extra_scene_object = _mesh_object( + "vase_0", + "mesh_assets/backgrounds/vase_0.glb", + [0.16, -0.44, 0.77], + [0.0, 0.0, -90.0], + ) + extra_scene_object["body_scale"] = [1.2, 1.1, 0.9] + source_config["rigid_object"].append(extra_scene_object) + merged_config_path.write_text( + json.dumps(source_config, indent=2), + encoding="utf-8", + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_agent", + target_body_scale=0.8, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + + assert set(rigid_objects) == { + "left_apple", + "right_apple", + "vase_0", + } + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert "wicker_basket" in background_objects + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + assert rigid_objects["left_apple"]["body_scale"] == [0.8, 0.8, 0.8] + assert rigid_objects["right_apple"]["body_scale"] == [0.8, 0.8, 0.8] + assert rigid_objects["vase_0"]["body_scale"] == [1.2, 1.1, 0.9] + _assert_normalized_obj_path(rigid_objects["vase_0"]["shape"]["fpath"]) + + +def test_task_description_generates_relative_left_of_config( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + assert kwargs["task_description"] == "把 apple_2 放到 basket_3 左边" + return { + "moved_object": "apple_2", + "reference_object": "basket_3", + "goal_relation": "left_of", + "task_prompt_summary": "Move apple_2 to the left of basket_3.", + "basic_background_notes": "The basket is the spatial reference.", + "action_sketch": [ + "grasp apple_2", + "move to the left side of basket_3", + "release on the table", + ], + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + monkeypatch.setattr( + ur5_basket_config_generation, + "_resolve_table_mesh_world_zmax", + lambda scene_dir, table_obj: None, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_relative_agent", + task_name="AppleLeftOfBasket", + task_description="把 apple_2 放到 basket_3 左边", + target_body_scale=0.5, + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"apple_2"} + assert rigid_objects["apple_2"]["body_scale"] == [0.5, 0.5, 0.5] + assert rigid_objects["apple_2"]["body_type"] == "dynamic" + assert background_objects["apple_1"]["body_scale"] == [1.0, 1.0, 1.0] + assert background_objects["apple_1"]["body_type"] == "kinematic" + assert background_objects["table"]["body_scale"] == [1.0, 1.0, 1.0] + assert background_objects["wicker_basket"]["body_scale"] == [1.0, 1.0, 1.0] + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + assert background_objects["wicker_basket"]["init_rot"] == [0.0, 0.0, 180.0] + + success = gym_config["env"]["extensions"]["agent_success"] + assert success["op"] == "all" + axis_terms = { + (term.get("axis"), term.get("offset")) + for term in success["terms"] + if term["type"] == "object_axis_offset_near" + } + assert ("y", -0.16) in axis_terms + assert ("x", 0.0) in axis_terms + + assert "agent_grasp_pose_overrides" not in gym_config["env"]["extensions"] + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + assert "Move apple_2 to the left of basket_3." in task_prompt + assert ( + "Generate one deterministic nominal graph with exactly 4 nominal edges" + in task_prompt + ) + assert '"atomic_action_class":"PickUpAction","robot_name":"left_arm"' in task_prompt + assert '"atomic_action_class":"PlaceAction","robot_name":"left_arm"' in task_prompt + assert '"obj_name":"apple_2"' in task_prompt + assert "right_arm_action: null" in task_prompt + assert "Generate exactly 10 nominal edges" not in task_prompt + + assert _stable_summary(paths.summary) == { + "mode": "relative_placement", + "moved_object": "apple_2", + "reference_object": "wicker_basket", + "relation": "left_of", + "active_arm": "left_arm", + "release_offset": [0.0, -0.16, 0.12], + } + + +def test_task_description_generates_relative_front_of_config( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + assert kwargs["task_description"] == "用右臂把 apple_1 放到 apple_2 前边" + return { + "moved_object": "apple_1", + "reference_object": "apple_2", + "goal_relation": "front_of", + "arm": "right", + "task_prompt_summary": "Move apple_1 in front of apple_2.", + "basic_background_notes": "The apple_2 object is the spatial reference.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + monkeypatch.setattr( + ur5_basket_config_generation, + "_resolve_table_mesh_world_zmax", + lambda scene_dir, table_obj: None, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_front_relative_agent", + task_name="AppleFrontOfApple", + task_description="用右臂把 apple_1 放到 apple_2 前边", + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"apple_1"} + assert rigid_objects["apple_1"]["body_type"] == "dynamic" + assert background_objects["apple_2"]["body_type"] == "kinematic" + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + + success = gym_config["env"]["extensions"]["agent_success"] + assert success["op"] == "all" + axis_terms = { + (term.get("axis"), term.get("offset")) + for term in success["terms"] + if term["type"] == "object_axis_offset_near" + } + assert ("x", -0.16) in axis_terms + assert ("y", 0.0) in axis_terms + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + atom_actions = paths.atom_actions.read_text(encoding="utf-8") + assert '"offset":[-0.16,0.0,0.22]' in task_prompt + assert '"offset":[-0.16,0.0,0.22]' in atom_actions + + assert _stable_summary(paths.summary) == { + "mode": "relative_placement", + "moved_object": "apple_1", + "reference_object": "apple_2", + "relation": "front_of", + "active_arm": "right_arm", + "release_offset": [-0.16, 0.0, 0.12], + } + + +def test_task_description_generates_self_relative_front_left_config( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + for rel_path in ( + "mesh_assets/table/table_0.glb", + "mesh_assets/chip_bag/chip_bag_1.glb", + ): + _write_minimal_glb(project_dir / rel_path, _default_mesh_vertices()) + + gym_config = { + "id": "Image2Tabletop-1790000000-v0", + "background": [ + _mesh_object( + "table", + "mesh_assets/table/table_0.glb", + [0.0, 0.0, 0.36], + [0.0, 0.0, 180.0], + ), + ], + "rigid_object": [ + _mesh_object( + "chip_bag_1", + "mesh_assets/chip_bag/chip_bag_1.glb", + [0.18, 0.22, 0.76], + [0.0, 0.0, 25.0], + ) + ], + } + (project_dir / "gym_config.json").write_text( + json.dumps(gym_config, indent=2), + encoding="utf-8", + ) + + def fake_call_relative_task_llm(**kwargs): + return { + "moved_object": "chip_bag_1", + "reference_object": "chip_bag_1", + "goal_relation": "front_left_of", + "arm": "left", + "task_prompt_summary": "Move the chip bag front-left from its start.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_self_relative_agent", + task_description="用左臂把薯片袋子往左前移动", + target_body_scale=0.5, + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + assert set(rigid_objects) == {"chip_bag"} + initial_position = rigid_objects["chip_bag"]["init_pos"] + expected_x = round(initial_position[0] - 0.16, 6) + expected_y = round(initial_position[1] - 0.16, 6) + + success = gym_config["env"]["extensions"]["agent_success"] + assert success["op"] == "all" + axis_terms = { + (term.get("axis"), term.get("target")) + for term in success["terms"] + if term["type"] == "object_axis_near" + } + assert ("x", expected_x) in axis_terms + assert ("y", expected_y) in axis_terms + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + atom_actions = paths.atom_actions.read_text(encoding="utf-8") + assert '"reference":"absolute"' in task_prompt + assert '"reference":"absolute"' in atom_actions + assert f'"position":[{expected_x},{expected_y},' in task_prompt + + assert _stable_summary(paths.summary) == { + "mode": "relative_placement", + "moved_object": "chip_bag", + "reference_object": "chip_bag", + "relation": "front_left_of", + "active_arm": "left_arm", + "release_offset": [-0.16, -0.16, 0.12], + } + + +def test_task_description_generates_relative_front_right_config( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + return { + "moved_object": "apple_1", + "reference_object": "basket_3", + "goal_relation": "front_right_of", + "arm": "right", + "task_prompt_summary": "Move apple_1 to the front-right of basket_3.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + monkeypatch.setattr( + ur5_basket_config_generation, + "_resolve_table_mesh_world_zmax", + lambda scene_dir, table_obj: None, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_front_right_relative_agent", + task_description="用右臂把 apple_1 放到 basket_3 右前", + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + success = gym_config["env"]["extensions"]["agent_success"] + axis_terms = { + (term.get("axis"), term.get("offset")) + for term in success["terms"] + if term["type"] == "object_axis_offset_near" + } + assert ("x", -0.16) in axis_terms + assert ("y", 0.16) in axis_terms + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + assert '"offset":[-0.16,0.16,0.12]' in task_prompt + assert _stable_summary(paths.summary)["release_offset"] == [-0.16, 0.16, 0.12] + + +def test_side_relation_offsets_use_robot_view_front_back_convention() -> None: + assert ur5_basket_config_generation._side_relation_xy_offsets("front_of") == ( + -0.16, + 0.0, + ) + assert ur5_basket_config_generation._side_relation_xy_offsets("behind") == ( + 0.16, + 0.0, + ) + assert ur5_basket_config_generation._side_relation_xy_offsets("front_left_of") == ( + -0.16, + -0.16, + ) + assert ur5_basket_config_generation._side_relation_xy_offsets("back_right_of") == ( + 0.16, + 0.16, + ) + + +@pytest.mark.parametrize( + ("raw_relation", "normalized"), + [ + ("左前", "front_left_of"), + ("左后", "back_left_of"), + ("右前", "front_right_of"), + ("右后", "back_right_of"), + ], +) +def test_relative_relation_aliases_include_diagonal_chinese_directions( + raw_relation: str, + normalized: str, +) -> None: + assert ur5_basket_config_generation._normalize_relative_relation(raw_relation) == ( + normalized + ) + + +def test_task_description_on_container_is_compiled_as_inside( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + return { + "moved_object": "apple_1", + "reference_object": "basket_3", + "goal_relation": "on", + "task_prompt_summary": "Release apple_1 above basket_3.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + monkeypatch.setattr( + ur5_basket_config_generation, + "_resolve_table_mesh_world_zmax", + lambda scene_dir, table_obj: None, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_above_container_agent", + task_description="把 apple_1 放到 basket_3 上方然后松手", + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + success = gym_config["env"]["extensions"]["agent_success"] + assert success["type"] == "object_in_container" + assert success["object"] == "apple_1" + assert success["container"] == "wicker_basket" + assert paths.summary["relation"] == "inside" + assert paths.summary["active_arm"] == "right_arm" + + assert "agent_grasp_pose_overrides" not in gym_config["env"]["extensions"] + + +def test_task_description_respects_explicit_left_arm( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + return { + "moved_object": "apple_1", + "reference_object": "basket_3", + "goal_relation": "left_of", + "arm": "left", + "task_prompt_summary": "Use the left arm to move apple_1.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + monkeypatch.setattr( + ur5_basket_config_generation, + "_resolve_table_mesh_world_zmax", + lambda scene_dir, table_obj: None, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_left_arm_agent", + task_description="左臂把 apple_1 放到 basket_3 左边", + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"apple_1"} + assert rigid_objects["apple_1"]["body_type"] == "dynamic" + assert background_objects["apple_2"]["body_type"] == "kinematic" + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + assert background_objects["wicker_basket"]["init_rot"] == [0.0, 0.0, 180.0] + assert "agent_grasp_pose_overrides" not in gym_config["env"]["extensions"] + assert paths.summary["active_arm"] == "left_arm" + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + assert '"atomic_action_class":"PickUpAction","robot_name":"left_arm"' in task_prompt + assert '"obj_name":"apple_1"' in task_prompt + assert "right_arm_action: null" in task_prompt + + +def test_task_description_respects_explicit_right_arm( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + return { + "moved_object": "apple_2", + "reference_object": "basket_3", + "goal_relation": "right_of", + "arm": "right", + "task_prompt_summary": "Use the right arm to move apple_2.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_right_arm_agent", + task_description="右臂把 apple_2 放到 basket_3 右边", + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"apple_2"} + assert rigid_objects["apple_2"]["body_type"] == "dynamic" + assert background_objects["apple_1"]["body_type"] == "kinematic" + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + assert "agent_grasp_pose_overrides" not in gym_config["env"]["extensions"] + assert paths.summary["active_arm"] == "right_arm" + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + assert ( + '"atomic_action_class":"PickUpAction","robot_name":"right_arm"' in task_prompt + ) + assert '"obj_name":"apple_2"' in task_prompt + assert "left_arm_action: null" in task_prompt + + +def test_demo3_relative_placement_uses_role_aware_scene_partition( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_demo3_role_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + assert kwargs["task_description"] == "用右臂把咖啡杯子放到垫子上" + return { + "moved_object": "cup_1", + "reference_object": "pad_1", + "goal_relation": "on", + "arm": "right", + "task_prompt_summary": "Place the cup on the pad.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_demo3_relative_agent", + task_description="用右臂把咖啡杯子放到垫子上", + target_body_scale=0.8, + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"cup"} + assert rigid_objects["cup"]["body_type"] == "dynamic" + assert rigid_objects["cup"]["body_scale"] == [0.8, 0.8, 0.8] + assert background_objects["pad"]["body_type"] == "kinematic" + assert background_objects["pad"]["body_scale"] == [1.2, 1.0, 0.4] + assert background_objects["fork"]["body_type"] == "kinematic" + assert background_objects["fork"]["body_scale"] == [0.7, 0.7, 0.7] + + success = gym_config["env"]["extensions"]["agent_success"] + assert success["type"] == "object_on_object" + assert success["object"] == "cup" + assert success["support"] == "pad" + + atom_actions = paths.atom_actions.read_text(encoding="utf-8") + assert atom_actions.count('"atomic_action_class":"PickUpAction"') == 1 + assert ( + '"atomic_action_class":"PickUpAction","robot_name":"right_arm"' in atom_actions + ) + assert '"obj_name":"cup"' in atom_actions + assert _stable_summary(paths.summary)["relation"] == "on" + + +def test_task_description_allows_single_rigid_with_background_reference( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + for rel_path in ( + "mesh_assets/table/table_0.glb", + "mesh_assets/pad/pad_1.glb", + "mesh_assets/chip_bag/chip_bag_1.glb", + ): + _write_minimal_glb(project_dir / rel_path, _default_mesh_vertices()) + + gym_config = { + "id": "Image2Tabletop-1790000000-v0", + "background": [ + _mesh_object( + "table", + "mesh_assets/table/table_0.glb", + [0.0, 0.0, 0.36], + [0.0, 0.0, 180.0], + ), + _mesh_object( + "pad_1", + "mesh_assets/pad/pad_1.glb", + [-0.1, -0.15, 0.74], + [0.0, 0.0, 0.0], + ), + ], + "rigid_object": [ + _mesh_object( + "chip_bag_1", + "mesh_assets/chip_bag/chip_bag_1.glb", + [0.18, 0.22, 0.76], + [0.0, 0.0, 25.0], + ) + ], + } + (project_dir / "gym_config.json").write_text( + json.dumps(gym_config, indent=2), + encoding="utf-8", + ) + + def fake_call_relative_task_llm(**kwargs): + scene_roles = { + item["source_uid"]: item["role"] for item in kwargs["scene_summary"] + } + assert scene_roles["chip_bag_1"] == "rigid_object" + assert scene_roles["pad_1"] == "background" + return { + "moved_object": "chip_bag_1", + "reference_object": "pad_1", + "goal_relation": "on", + "arm": "left", + "task_prompt_summary": "Place the chip bag on the pad.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_single_rigid_agent", + task_description="用左臂抓薯片袋子放到垫子上", + target_body_scale=0.5, + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"chip_bag"} + assert rigid_objects["chip_bag"]["body_type"] == "dynamic" + assert rigid_objects["chip_bag"]["body_scale"] == [0.5, 0.5, 0.5] + assert background_objects["pad"]["body_type"] == "static" + + success = gym_config["env"]["extensions"]["agent_success"] + assert success["type"] == "object_on_object" + assert success["object"] == "chip_bag" + assert success["support"] == "pad" + + registry = gym_config["env"]["events"]["register_info_to_env"]["params"]["registry"] + registered_uids = {entry["entity_cfg"]["uid"] for entry in registry} + assert {"chip_bag", "pad"}.issubset(registered_uids) + + +def test_task_description_generates_dual_arm_relative_config( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + assert kwargs["task_description"] == ( + "左臂把 apple_2 放到 basket_3 左边,右臂把 apple_1 放到 basket_3 右边" + ) + return { + "placements": [ + { + "moved_object": "apple_2", + "reference_object": "basket_3", + "goal_relation": "left_of", + "arm": "left", + }, + { + "moved_object": "apple_1", + "reference_object": "basket_3", + "goal_relation": "right_of", + "arm": "right", + }, + ], + "task_prompt_summary": "Use both arms for two side placements.", + "basic_background_notes": "Both arms have explicit work.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + monkeypatch.setattr( + ur5_basket_config_generation, + "_resolve_table_mesh_world_zmax", + lambda scene_dir, table_obj: None, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_dual_relative_agent", + task_description=( + "左臂把 apple_2 放到 basket_3 左边,右臂把 apple_1 放到 basket_3 右边" + ), + target_body_scale=0.7, + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"apple_1", "apple_2"} + assert rigid_objects["apple_1"]["body_type"] == "dynamic" + assert rigid_objects["apple_2"]["body_type"] == "dynamic" + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + assert "agent_grasp_pose_overrides" not in gym_config["env"]["extensions"] + + success = gym_config["env"]["extensions"]["agent_success"] + assert success["op"] == "all" + assert len(success["terms"]) == 2 + axis_terms = { + (term["object"], term["axis"], term["offset"]) + for placement_success in success["terms"] + for term in placement_success["terms"] + if term["type"] == "object_axis_offset_near" + } + assert ("apple_2", "y", -0.16) in axis_terms + assert ("apple_1", "y", 0.16) in axis_terms + + attr_names = { + attr["name"] + for attr in gym_config["env"]["events"]["prepare_extra_attr"]["params"]["attrs"] + } + assert "grasp_pose_object" not in attr_names + + assert _stable_summary(paths.summary) == { + "mode": "dual_arm_relative_placement", + "placements": [ + { + "moved_object": "apple_2", + "reference_object": "wicker_basket", + "relation": "left_of", + "active_arm": "left_arm", + "release_offset": [0.0, -0.16, 0.12], + }, + { + "moved_object": "apple_1", + "reference_object": "wicker_basket", + "relation": "right_of", + "active_arm": "right_arm", + "release_offset": [0.0, 0.16, 0.12], + }, + ], + } + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + basic_background = paths.basic_background.read_text(encoding="utf-8") + atom_actions = paths.atom_actions.read_text(encoding="utf-8") + assert "Generate one deterministic nominal graph with exactly 6 nominal edges" in ( + task_prompt + ) + assert ( + 'left_arm_action: {"atomic_action_class":"PickUpAction","robot_name":"left_arm"' + in task_prompt + ) + assert ( + 'right_arm_action: {"atomic_action_class":"PickUpAction","robot_name":"right_arm"' + in task_prompt + ) + assert ( + '"robot_name":"right_arm","control":"hand","target_qpos":{"source":"gripper_state","state":"close"}' + in task_prompt + ) + assert '"atomic_action_class":"PlaceAction","robot_name":"left_arm"' in task_prompt + assert '"atomic_action_class":"PlaceAction","robot_name":"right_arm"' in task_prompt + assert "The inactive arm must remain null" not in task_prompt + assert "Both arms participate" in basic_background + assert "left_arm moves `apple_2`" in basic_background + assert "right_arm moves `apple_1`" in basic_background + assert ( + '"atomic_action_class":"PickUpAction","robot_name":"left_arm"' in atom_actions + ) + assert '"obj_name":"apple_2"' in atom_actions + assert ( + '"atomic_action_class":"PickUpAction","robot_name":"right_arm"' in atom_actions + ) + assert '"obj_name":"apple_1"' in atom_actions + + +def test_task_description_rejects_dual_relative_same_arm( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + return { + "placements": [ + { + "moved_object": "apple_2", + "reference_object": "basket_3", + "goal_relation": "left_of", + "arm": "left", + }, + { + "moved_object": "apple_1", + "reference_object": "basket_3", + "goal_relation": "right_of", + "arm": "left", + }, + ], + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + with pytest.raises(ValueError, match="one left arm and one right arm"): + generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "bad_dual_relative_agent", + task_description="双臂分别移动两个苹果", + ) + + +def test_task_description_dual_auto_assigns_complementary_arms( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + gym_config_path = project_dir / "gym_config.json" + source_config = json.loads(gym_config_path.read_text(encoding="utf-8")) + for obj_config in source_config["rigid_object"]: + if obj_config["uid"] == "apple_1": + obj_config["init_pos"][1] = -0.03 + gym_config_path.write_text( + json.dumps(source_config, indent=2), + encoding="utf-8", + ) + + def fake_call_relative_task_llm(**kwargs): + return { + "placements": [ + { + "moved_object": "apple_2", + "reference_object": "basket_3", + "goal_relation": "left_of", + "arm": "auto", + }, + { + "moved_object": "apple_1", + "reference_object": "basket_3", + "goal_relation": "right_of", + "arm": "auto", + }, + ], + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_dual_auto_relative_agent", + task_description="双臂分别移动两个苹果", + prewarm_coacd_cache=False, + ) + + active_arms = [placement["active_arm"] for placement in paths.summary["placements"]] + assert active_arms == ["left_arm", "right_arm"] + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + assert "agent_grasp_pose_overrides" not in gym_config["env"]["extensions"] + + +def test_task_description_on_object_uses_object_on_object_success( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + return { + "moved_object": "apple_2", + "reference_object": "apple_1", + "goal_relation": "on", + "task_prompt_summary": "Stack apple_2 on apple_1.", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_stack_agent", + task_description="把 apple_2 放到 apple_1 上方并松手", + target_body_scale=0.6, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + rigid_objects = {obj["uid"]: obj for obj in gym_config["rigid_object"]} + background_objects = {obj["uid"]: obj for obj in gym_config["background"]} + assert set(rigid_objects) == {"apple_2"} + assert rigid_objects["apple_2"]["body_scale"] == [0.6, 0.6, 0.6] + assert rigid_objects["apple_2"]["body_type"] == "dynamic" + assert background_objects["apple_1"]["body_scale"] == [1.0, 1.0, 1.0] + assert background_objects["apple_1"]["body_type"] == "kinematic" + assert background_objects["wicker_basket"]["body_scale"] == [1.0, 1.0, 1.0] + assert background_objects["wicker_basket"]["body_type"] == "kinematic" + + success = gym_config["env"]["extensions"]["agent_success"] + assert success["type"] == "object_on_object" + assert success["object"] == "apple_2" + assert success["support"] == "apple_1" + + task_prompt = paths.task_prompt.read_text(encoding="utf-8") + assert "on top of `apple_1`" in task_prompt + + +def test_task_description_rejects_unknown_llm_uid( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + def fake_call_relative_task_llm(**kwargs): + return { + "moved_object": "missing_bread", + "reference_object": "basket_3", + "goal_relation": "left_of", + } + + monkeypatch.setattr( + ur5_basket_config_generation, + "_call_relative_task_llm", + fake_call_relative_task_llm, + ) + + with pytest.raises(ValueError, match="unknown moved_object"): + generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "bad_agent", + task_description="把 missing_bread 放到 basket_3 左边", + ) + + +def test_high_tabletop_scene_adjusts_robot_height_and_light( + tmp_path: Path, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + _write_minimal_glb( + project_dir / "mesh_assets/table/table_0.glb", + [(-0.5, 0.0, 0.82), (0.5, 0.0, 0.82), (0.0, -0.82, 0.82)], + ) + + gym_config_path = project_dir / "gym_config.json" + source_config = json.loads(gym_config_path.read_text(encoding="utf-8")) + for obj_config in source_config["rigid_object"]: + obj_config["init_pos"][2] = 0.12 + gym_config_path.write_text( + json.dumps(source_config, indent=2), + encoding="utf-8", + ) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_high_table_agent", + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + expected_init_z = ( + 1.18 + + ur5_basket_config_generation._DUAL_UR5_TABLETOP_CLEARANCE + - ur5_basket_config_generation._DUAL_UR5_ARM_COMPONENT_Z + ) + assert gym_config["robot"]["init_pos"][2] == pytest.approx(expected_init_z) + assert gym_config["light"]["direct"][0]["intensity"] == 40.0 + + +def test_tabletop_z_placement_uses_normalized_mesh_bounds( + tmp_path: Path, +) -> None: + project_dir = tmp_path / "1790000000_gym_project" + _write_project(project_dir) + + paths = generate_ur5_basket_config_from_project( + project_dir, + tmp_path / "generated_z_agent", + target_body_scale=0.8, + prewarm_coacd_cache=False, + ) + + gym_config = json.loads(paths.gym_config.read_text(encoding="utf-8")) + table_config = next( + obj for obj in gym_config["background"] if obj["uid"] == "table" + ) + table_top_z = ur5_basket_config_generation._mesh_config_world_zmax(table_config) + expected_min_z = ( + table_top_z + ur5_basket_config_generation._TABLETOP_OBJECT_CLEARANCE + ) + for obj_config in [ + *[obj for obj in gym_config["background"] if obj["uid"] != "table"], + *gym_config["rigid_object"], + ]: + min_z, _ = ur5_basket_config_generation._mesh_config_world_z_bounds(obj_config) + assert min_z == pytest.approx(expected_min_z) + + +def test_table_mesh_world_zmax_reads_glb_vertices(tmp_path: Path) -> None: + scene_dir = tmp_path / "1790000000_gym_project" + mesh_path = scene_dir / "mesh_assets/table/table_0.glb" + _write_minimal_glb( + mesh_path, + [(-0.5, -0.5, 0.0), (0.5, -0.5, 1.2), (0.0, 0.5, 0.4)], + ) + table_obj = ur5_basket_config_generation._SceneObject( + source_uid="table", + source_role="background", + config=_mesh_object( + "table", + "mesh_assets/table/table_0.glb", + [0.0, 0.0, 0.1], + [0.0, 0.0, 0.0], + ), + ) + table_obj.config["body_scale"] = [1.0, 1.0, 2.0] + + assert ur5_basket_config_generation._resolve_table_mesh_world_zmax( + scene_dir, + table_obj, + ) == pytest.approx(2.5) + + +def test_object_on_object_success_predicate() -> None: + env = _FakeEnv( + { + "apple_2": [0.0, 0.0, 0.15], + "apple_1": [0.02, 0.01, 0.0], + } + ) + + success = evaluate_configured_success( + env, + { + "type": "object_on_object", + "object": "apple_2", + "support": "apple_1", + "xy_radius": 0.08, + "min_z_offset": 0.02, + "max_z_offset": 0.35, + }, + ) + + assert bool(success.item()) is True + + +def _write_project(project_dir: Path) -> None: + for rel_path in ( + "mesh_assets/table/table_0.glb", + "mesh_assets/basket/basket_3/basket_3.glb", + "mesh_assets/apple/apple_1/apple_1.glb", + "mesh_assets/apple/apple_2/apple_2.glb", + ): + mesh_path = project_dir / rel_path + _write_minimal_glb(mesh_path, _default_mesh_vertices()) + + gym_config = { + "id": "Image2Tabletop-1790000000-v0", + "background": [ + _mesh_object( + "table", + "mesh_assets/table/table_0.glb", + [0.0, 0.0, 0.36], + [0.0, 0.0, 180.0], + ) + ], + "rigid_object": [ + _mesh_object( + "basket_3", + "mesh_assets/basket/basket_3/basket_3.glb", + [0.0, 0.08, 0.75], + [0.0, 0.0, 180.0], + ), + _mesh_object( + "apple_1", + "mesh_assets/apple/apple_1/apple_1.glb", + [0.38, 0.11, 0.76], + [0.0, 0.0, 140.0], + ), + _mesh_object( + "apple_2", + "mesh_assets/apple/apple_2/apple_2.glb", + [-0.39, -0.12, 0.76], + [0.0, 0.0, 160.0], + ), + ], + } + (project_dir / "gym_config.json").write_text( + json.dumps(gym_config, indent=2), + encoding="utf-8", + ) + + +def _write_demo3_role_project(project_dir: Path) -> None: + for rel_path in ( + "mesh_assets/table/table_0.glb", + "mesh_assets/cup/cup_1/cup_1.glb", + "mesh_assets/pad/pad_1/pad_1.glb", + "mesh_assets/fork/fork_1/fork_1.glb", + ): + _write_minimal_glb(project_dir / rel_path, _default_mesh_vertices()) + + cup = _mesh_object( + "cup_1", + "mesh_assets/cup/cup_1/cup_1.glb", + [0.18, 0.22, 0.76], + [0.0, 0.0, 25.0], + ) + pad = _mesh_object( + "pad_1", + "mesh_assets/pad/pad_1/pad_1.glb", + [-0.1, -0.15, 0.74], + [0.0, 0.0, -10.0], + ) + pad["body_scale"] = [1.2, 1.0, 0.4] + fork = _mesh_object( + "fork_1", + "mesh_assets/fork/fork_1/fork_1.glb", + [0.32, -0.18, 0.75], + [0.0, 0.0, 90.0], + ) + fork["body_scale"] = [0.7, 0.7, 0.7] + + gym_config = { + "id": "Image2Tabletop-1790000000-v0", + "background": [ + _mesh_object( + "table", + "mesh_assets/table/table_0.glb", + [0.0, 0.0, 0.36], + [0.0, 0.0, 180.0], + ) + ], + "rigid_object": [cup, pad, fork], + } + (project_dir / "gym_config.json").write_text( + json.dumps(gym_config, indent=2), + encoding="utf-8", + ) + + +def _mesh_object( + uid: str, + fpath: str, + init_pos: list[float], + init_rot: list[float], +) -> dict: + return { + "uid": uid, + "shape": { + "shape_type": "Mesh", + "fpath": fpath, + "compute_uv": False, + }, + "init_pos": init_pos, + "init_rot": init_rot, + "body_scale": [1.0, 1.0, 1.0], + } + + +def _assert_normalized_obj_path(fpath: str) -> None: + path = Path(fpath) + assert path.suffix == ".obj" + assert "mesh_assets/normalized" in path.as_posix() + assert MESH_FRAME_NORMALIZATION_POLICY_VERSION not in path.name + assert len(path.name) <= 64 + assert path.is_file() + assert (path.parent / "material.mtl").is_file() + + +def _stable_summary(summary: dict) -> dict: + return { + key: value + for key, value in summary.items() + if key not in {"normalized_meshes", "coacd_cache"} + } + + +def _obj_vertices(path: Path) -> list[tuple[float, float, float]]: + vertices = [] + for line in path.read_text(encoding="utf-8").splitlines(): + if not line.startswith("v "): + continue + _, x, y, z = line.split(maxsplit=3) + vertices.append((float(x), float(y), float(z))) + return vertices + + +def _single_obj_material_name(obj_text: str) -> str: + names = { + line.split(maxsplit=1)[1].strip() + for line in obj_text.splitlines() + if line.startswith("usemtl ") + } + assert len(names) == 1 + return next(iter(names)) + + +def _single_map_kd_path(material_text: str, material_name: str) -> str: + current_material = None + texture_paths = [] + for line in material_text.splitlines(): + if line.startswith("newmtl "): + current_material = line.split(maxsplit=1)[1].strip() + continue + if current_material == material_name and line.startswith("map_Kd "): + texture_paths.append(line.split(maxsplit=1)[1].strip()) + assert len(texture_paths) == 1 + return texture_paths[0] + + +def _rounded_vertex_set( + vertices: list[tuple[float, float, float]], +) -> set[tuple[float, float, float]]: + return { + (round(vertex[0], 6), round(vertex[1], 6), round(vertex[2], 6)) + for vertex in vertices + } + + +def _default_mesh_vertices() -> list[tuple[float, float, float]]: + return [(-0.05, 0.0, 0.0), (0.05, 0.0, 0.0), (0.0, -0.04, 0.0)] + + +def _tiny_png() -> bytes: + return base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR4nGP4z8DwHwAF" + "gAJ/l7p7YwAAAABJRU5ErkJggg==" + ) + + +def _write_minimal_glb( + path: Path, + vertices: list[tuple[float, float, float]], + *, + node_translation: tuple[float, float, float] | None = None, + embedded_base_color_png: bytes | None = None, +) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if len(vertices) < 3: + raise ValueError("Minimal GLB test mesh requires at least three vertices.") + position_binary = b"".join(struct.pack(" list: + calls = [] + + def fake_run_prompt2geometry_replacement( + *, + prompt: str, + output_root: Path, + output_name: str, + ) -> dict: + output_root.mkdir(parents=True, exist_ok=True) + mesh_path = output_root / output_name + _write_minimal_glb(mesh_path, _default_mesh_vertices()) + calls.append((prompt, output_root, output_name)) + return {"scaled_mesh_path": str(mesh_path)} + + monkeypatch.setattr( + ur5_basket_config_generation, + "_run_prompt2geometry_replacement", + fake_run_prompt2geometry_replacement, + ) + return calls + + +class _FakeEnv: + num_envs = 1 + device = torch.device("cpu") + + def __init__(self, positions: dict[str, list[float]]) -> None: + self.sim = _FakeSim(positions) + + +class _FakeSim: + def __init__(self, positions: dict[str, list[float]]) -> None: + self._objects = { + uid: _FakeRigidObject(position) for uid, position in positions.items() + } + + def get_rigid_object(self, uid: str): + return self._objects[uid] + + +class _FakeRigidObject: + def __init__(self, position: list[float]) -> None: + self._position = torch.tensor(position, dtype=torch.float32) + + def get_local_pose(self, to_matrix: bool = True) -> torch.Tensor: + pose = torch.eye(4, dtype=torch.float32).unsqueeze(0) + pose[:, :3, 3] = self._position.reshape(1, 3) + return pose