diff --git a/.claude/skills/add-atomic-action b/.claude/skills/add-atomic-action
new file mode 120000
index 00000000..ee63a4bc
--- /dev/null
+++ b/.claude/skills/add-atomic-action
@@ -0,0 +1 @@
+../../skills/add-atomic-action
\ No newline at end of file
diff --git a/.claude/skills/add-functor b/.claude/skills/add-functor
new file mode 120000
index 00000000..59a2505a
--- /dev/null
+++ b/.claude/skills/add-functor
@@ -0,0 +1 @@
+../../skills/add-functor
\ No newline at end of file
diff --git a/.claude/skills/add-task-env b/.claude/skills/add-task-env
new file mode 120000
index 00000000..c06093df
--- /dev/null
+++ b/.claude/skills/add-task-env
@@ -0,0 +1 @@
+../../skills/add-task-env
\ No newline at end of file
diff --git a/.claude/skills/add-test b/.claude/skills/add-test
new file mode 120000
index 00000000..bc175531
--- /dev/null
+++ b/.claude/skills/add-test
@@ -0,0 +1 @@
+../../skills/add-test
\ No newline at end of file
diff --git a/.claude/skills/benchmark b/.claude/skills/benchmark
new file mode 120000
index 00000000..2735c494
--- /dev/null
+++ b/.claude/skills/benchmark
@@ -0,0 +1 @@
+../../skills/benchmark
\ No newline at end of file
diff --git a/.claude/skills/pr b/.claude/skills/pr
new file mode 120000
index 00000000..5167ba85
--- /dev/null
+++ b/.claude/skills/pr
@@ -0,0 +1 @@
+../../skills/pr
\ No newline at end of file
diff --git a/.claude/skills/pre-commit-check b/.claude/skills/pre-commit-check
new file mode 120000
index 00000000..b0cc815c
--- /dev/null
+++ b/.claude/skills/pre-commit-check
@@ -0,0 +1 @@
+../../skills/pre-commit-check
\ No newline at end of file
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 27483d52..2650bcd0 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -31,7 +31,7 @@ jobs:
run: |
echo "Workspace: ${GITHUB_WORKSPACE}"
ls
- pip install black==24.3.0
+ pip install black==26.3.1
black --check --diff --color ./
if [ $? -ne 0 ]; then
echo "Code style check failed, please run [black ./] before commit!"
@@ -45,24 +45,79 @@ jobs:
NVIDIA_DRIVER_CAPABILITIES: all
NVIDIA_VISIBLE_DEVICES: all
NVIDIA_DISABLE_REQUIRE: 1
+ DOCS_MAX_VERSIONS: "4" # Max number of release versions to keep
container: *container_template
steps:
- uses: actions/checkout@v4
+
+ - name: Cache Python dependencies
+ id: cache-pip
+ uses: actions/cache@v4
+ with:
+ path: ~/.cache/pip
+ key: ${{ runner.os }}-pip-docs-${{ hashFiles('docs/requirements.txt') }}
+ restore-keys: |
+ ${{ runner.os }}-pip-docs-
+
+ - name: Restore previous docs output
+ if: github.event_name == 'push'
+ uses: actions/cache@v4
+ with:
+ path: docs/build/html
+ key: docs-output-${{ github.repository }}-${{ github.ref_name }}
+ restore-keys: |
+ docs-output-${{ github.repository }}-${{ github.ref_name }}-
+ docs-output-${{ github.repository }}-
+
- name: Build docs
+ shell: bash
run: |
pip install -e . --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
pip install -r docs/requirements.txt
+ python3 docs/scripts/sync_readme.py
cd ${GITHUB_WORKSPACE}/docs
- echo "Start Building docs..."
pip uninstall pymeshlab -y
pip install pymeshlab==2023.12.post3
- make html
+
+ if [[ "${GITHUB_REF}" == refs/tags/v* ]]; then
+ VERSION="${GITHUB_REF_NAME}"
+ echo "Building docs for release tag ${VERSION}..."
+
+ # Build only this version into its own subdirectory
+ sphinx-build source build/html/${VERSION}
+
+ cd build/html
+
+ # Prune old release versions beyond the window
+ mapfile -t TAG_DIRS < <(ls -d v*/ 2>/dev/null | sort -V)
+ while [[ ${#TAG_DIRS[@]} -gt ${DOCS_MAX_VERSIONS} ]]; do
+ echo "Pruning old version: ${TAG_DIRS[0]}"
+ rm -rf "${TAG_DIRS[0]}"
+ TAG_DIRS=("${TAG_DIRS[@]:1}")
+ done
+
+ # Generate versions.json and root index.html
+ python3 ${GITHUB_WORKSPACE}/docs/scripts/generate_versions_json.py \
+ --build-dir .
+
+ else
+ echo "Building dev docs for main branch..."
+ # Build only main/ — don't touch existing version directories
+ rm -rf build/html/main
+ sphinx-build source build/html/main
+
+ cd build/html
+
+ # Generate versions.json and root index.html
+ python3 ${GITHUB_WORKSPACE}/docs/scripts/generate_versions_json.py \
+ --build-dir .
+ fi
+
- name: Upload docs artifact
- if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+ if: github.event_name == 'push'
uses: actions/upload-pages-artifact@v3
- with:
+ with:
path: ${{ github.workspace }}/docs/build/html
- retention-days: 3
test:
if: github.event_name == 'pull_request'
@@ -86,19 +141,13 @@ jobs:
pytest tests
publish:
- if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+ if: github.event_name == 'push'
needs: build
runs-on: Linux
permissions:
pages: write
- id-token: write
- env:
- NVIDIA_DRIVER_CAPABILITIES: all
- NVIDIA_VISIBLE_DEVICES: all
- NVIDIA_DISABLE_REQUIRE: 1
- container: *container_template
+ id-token: write
steps:
- - uses: actions/checkout@v4
- name: Download docs artifact
uses: actions/download-artifact@v4
with:
@@ -108,40 +157,60 @@ jobs:
uses: actions/deploy-pages@v4
- # release:
- # if: startsWith(github.ref, 'refs/tags/v')
- # runs-on: Linux
- # permissions:
- # contents: write
- # id-token: write # PyPI Trusted Publishing
-
- # container: *container_template
-
- # steps:
- # - uses: actions/checkout@v4
- # with:
- # fetch-depth: 0
-
- # - name: (Release) Install build tools
- # run: |
- # python -m pip install --upgrade pip
- # pip install build
-
- # - name: (Release) Build sdist and wheel
- # run: |
- # python -m build --wheel
-
- # # - name: (Release) Create GitHub Release (draft)
- # # uses: softprops/action-gh-release@v2
- # # with:
- # # draft: true
- # # generate_release_notes: true
- # # files: |
- # # dist/*
- # # env:
- # # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
- # - name: (Release) Publish to PyPI
- # uses: pypa/gh-action-pypi-publish@release/v1
- # with:
- # password: ${{ secrets.PYPI_API_TOKEN }}
\ No newline at end of file
+ release-build:
+ if: startsWith(github.ref, 'refs/tags/v')
+ needs: lint
+ runs-on: Linux
+
+ container: *container_template
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: (Release) Install build tools
+ run: |
+ python -m pip install --upgrade pip
+ pip install build
+
+ - name: (Release) Build sdist and wheel
+ run: |
+ python -m build
+
+ # - name: (Release) Create GitHub Release (draft)
+ # uses: softprops/action-gh-release@v2
+ # with:
+ # draft: true
+ # generate_release_notes: true
+ # files: |
+ # dist/*
+ # env:
+ # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: (Release) Upload distributions
+ uses: actions/upload-artifact@v4
+ with:
+ name: python-distributions
+ path: dist/
+
+ release-publish:
+ if: startsWith(github.ref, 'refs/tags/v')
+ needs: release-build
+ runs-on: ubuntu-latest
+ environment:
+ name: pypi
+ url: https://pypi.org/p/embodichain
+ permissions:
+ contents: read
+ id-token: write # PyPI Trusted Publishing
+
+ steps:
+ - name: (Release) Download distributions
+ uses: actions/download-artifact@v4
+ with:
+ name: python-distributions
+ path: dist/
+
+ - name: (Release) Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.gitignore b/.gitignore
index 040955d9..7405b279 100644
--- a/.gitignore
+++ b/.gitignore
@@ -198,3 +198,6 @@ wandb/
.vscode/
embodichain/VERSION
+
+# benchmark results
+scripts/benchmark/rl/reports/*
\ No newline at end of file
diff --git a/AGENTS.md b/AGENTS.md
index 117cd57f..0920a327 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -58,10 +58,11 @@ EmbodiChain/
### Formatting
-- **Formatter**: `black==24.3.0` — run before every commit.
+- **Formatter**: `black==26.3.1` — run before every commit.
```bash
black .
```
+- Use the `/pre-commit-check` skill before committing to catch all CI violations locally.
### File Headers
@@ -108,22 +109,14 @@ class MyManagerCfg:
### Functor / Manager Pattern
-Managers (observation, event, reward, randomization) use a `Functor`/`FunctorCfg` pattern:
+Managers (observation, event, reward, randomization) use a `Functor`/`FunctorCfg` pattern with two styles:
- **Function-style**: a plain function with signature `(env, env_ids, ...) -> None`.
- **Class-style**: a class inheriting `Functor`, with `__init__(cfg, env)` and `__call__(env, env_ids, ...)`.
-- Registered in a manager config via `FunctorCfg(func=..., params={...})`.
-```python
-from embodichain.lab.gym.envs.managers import Functor, FunctorCfg
-
-class my_randomizer(Functor):
- def __init__(self, cfg: FunctorCfg, env):
- super().__init__(cfg, env)
+Registered in a manager config via `FunctorCfg(func=..., params={...})`.
- def __call__(self, env, env_ids, my_param: float = 0.5):
- ...
-```
+Use the `/add-functor` skill to scaffold new functors with the correct signature and module placement.
### Docstrings
@@ -200,20 +193,10 @@ Include:
1. **Fork** the repository and create a focused branch.
2. **Keep PRs small** — one logical change per PR.
-3. **Format** the code with `black==24.3.0` before submitting.
+3. **Format** the code with `black==26.3.1` before submitting.
4. **Update documentation** for any public API changes.
5. **Add tests** that prove your fix or feature works.
-6. **Submit** using the PR template (`.github/PULL_REQUEST_TEMPLATE.md`):
- - Summarize changes and link the related issue (`Fixes #123`).
- - Specify the type of change (bug fix / enhancement / new feature / breaking change / docs).
- - Attach before/after screenshots for visual changes.
- - Complete the checklist:
- - [ ] `black .` has been run
- - [ ] Documentation updated
- - [ ] Tests added
- - [ ] Dependencies updated (if applicable)
-
-> It is recommended to open an issue and discuss the design before opening a large PR.
+6. Use the `/pr` skill to create PRs following the project's template and label conventions.
### Adding a New Robot
@@ -231,107 +214,25 @@ Also add robot documentation in `docs/source/resources/robot/` (see existing exa
### Adding a New Task Environment
-Refer to `embodichain/lab/gym/envs/tasks/` for existing examples. Tasks subclass `EmbodiedEnv` or `BaseAgentEnv` and implement `_setup_scene`, `_reset_idx`, and evaluation logic.
-
----
-
-## Unit Tests
-
-### Structure
-
-Tests live in `tests/` and mirror the source tree:
-
-```text
-tests/
-├── toolkits/
-│ └── test_pg_grasp.py
-├── gym/
-│ └── action_bank/
-│ └── test_configurable_action.py
-└── sim/
- ├── objects/
- │ ├── test_light.py
- │ └── test_rigid_object_group.py
- ├── sensors/
- │ ├── test_camera.py
- │ └── test_stereo.py
- └── planners/
- └── test_motion_generator.py
-```
-
-Place new test files at `tests//test_.py`, matching the layout of `embodichain/`.
+Use the `/add-task-env` skill to scaffold a new task with the correct file structure, `@register_env` decorator, base class, and test stub.
-### Two accepted styles
+### Adding Functors
-**pytest style** — for pure-Python logic with no test ordering dependency:
+Use the `/add-functor` skill to scaffold observation, reward, event, action, dataset, or randomization functors with the correct signature, style, and module placement.
-```python
-# ----------------------------------------------------------------------------
-# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# ...
-# ----------------------------------------------------------------------------
-
-from embodichain.my_module import my_function
-
-
-def test_expected_output():
- result = my_function(input_value)
- assert result == expected_value
-
-
-def test_edge_case():
- result = my_function(edge_input)
- assert result is not None
-```
+### Writing Tests
-**`Class` style** — when tests must run in a specific order or share `setup_method`/`teardown_method` state:
+Use the `/add-test` skill to scaffold tests with the correct file placement, style (pytest vs class), mock patterns, and project conventions.
-```python
-# ----------------------------------------------------------------------------
-# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# ...
-# ----------------------------------------------------------------------------
-
-from embodichain.my_module import MyClass
-
-
-class TestMyClass():
- def setup_method(self):
- self.obj = MyClass(param=1.0)
-
- def teardown_method(self):
- pass
-
- def test_basic_behavior(self):
- result = self.obj.run()
- assert result == expected_result
-
- def test_raises_on_bad_input(self):
- with pytest.raises(ValueError):
- self.obj.run(bad_input)
-
-### Conventions
-
-- **File header**: include the standard Apache 2.0 copyright block (same as all source files).
-- **Naming**: test files are `test_.py`; test functions/methods are `test_`.
-- **Simulation-dependent tests**: tests that require a running `SimulationManager` (GPU, sensors, robots) must initialize and teardown the sim inside `setUp`/`tearDown` or a pytest fixture. Keep them isolated from pure-logic tests.
-- **No magic numbers**: define expected values as named constants or comments explaining their origin.
-- **`if __name__ == "__main__"`**: include this block for tests that support optional visual/interactive output (pass `is_visual=True` manually when debugging).
-
-### Running tests
-
-```bash
-# Run all tests
-pytest tests/
-
-# Run a specific file
-pytest tests/toolkits/test_pg_grasp.py
+---
-# Run a specific test function
-pytest tests/toolkits/test_pg_grasp.py::test_antipodal_score_selector
+## Skills Quick Reference
-# Run with verbose output
-pytest -v tests/
-```
+| Skill | Command | Purpose |
+|-------|---------|---------|
+| Add Task Env | `/add-task-env` | Scaffold a new `EmbodiedEnv` task |
+| Add Functor | `/add-functor` | Scaffold observation/reward/event/action/dataset/randomization functors |
+| Add Test | `/add-test` | Scaffold tests following project conventions |
+| Pre-Commit Check | `/pre-commit-check` | Run all local CI checks before committing |
+| Create PR | `/pr` | Create a PR following the project template |
+| Benchmark | `/benchmark` | Write benchmark scripts for EmbodiChain modules |
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c8ce9852..af1401cf 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -29,15 +29,31 @@ We welcome pull requests for bug fixes, new features, and documentation improvem
```bash
black .
```
- > Currently, we use black==24.3.0 for formatting. Make sure to use the same version to avoid inconsistencies.
+ > Currently, we use black==26.3.1 for formatting. Make sure to use the same version to avoid inconsistencies.
4. **Submit a Pull Request**.
* Use the [Pull Request Template](.github/PULL_REQUEST_TEMPLATE.md).
* Keep PRs small and focused.
* Include a summary of the changes and link to any relevant issues (e.g., `Fixes #123`).
* Ensure all checks pass.
+
+## Contribute specific robots
+
+To contribute a new robot, please check the documentation on [Adding a New Robot](https://dexforce.github.io/EmbodiChain/guides/add_robot.html).
+
+## Contribute specific environments
+
+To contribute a new environment, please check the documentation on [Embodied Environments](https://dexforce.github.io/EmbodiChain/overview/gym/env.html) and see the tutorial below:
+- [Creating a Basic Environment](https://dexforce.github.io/EmbodiChain/tutorial/basic_env.html)
+- [Creating a Modular Environment](https://dexforce.github.io/EmbodiChain/tutorial/modular_env.html)
+
+If you want to implement your tasks in a new repo and with some customized functors and utilities, you can also use the [Task Template Repo](https://github.com/DexForce/embodichain_task_template).
+
## Using Claude Code for Contributions
+
+Setup, skills, and tips for using Claude Code
+
[Claude Code](https://docs.anthropic.com/en/docs/claude-code/overview) is an AI-powered CLI that can assist you throughout the contribution workflow — from understanding the codebase to writing, reviewing, and debugging code.
### Setup
@@ -51,6 +67,33 @@ claude
A `CLAUDE.md` file is present at the root of this repository. Claude Code reads it automatically at startup to load project conventions, structure, and style rules, so it is context-aware from the first prompt.
+### Skills
+
+Claude Code skills are built-in slash commands that automate common development tasks. They scaffold code, run checks, and enforce project conventions so you can focus on your logic instead of boilerplate. Invoke any skill by typing its command in the Claude Code prompt.
+
+| Skill | Command | Purpose |
+|-------|---------|---------|
+| Add Functor | `/add-functor` | Scaffold a new observation, reward, event, action, dataset, or randomization functor with the correct signature, style, and module placement |
+| Add Task Env | `/add-task-env` | Scaffold a new task environment with the correct file structure, `@register_env` decorator, base class, and test stub |
+| Add Test | `/add-test` | Scaffold tests with the correct file placement, style (pytest vs class), mock patterns, and project conventions |
+| Pre-Commit Check | `/pre-commit-check` | Run all local CI checks — code style, headers, annotations, exports, and docstrings — before committing |
+| Create PR | `/pr` | Create a pull request following the project template and label conventions |
+| Benchmark | `/benchmark` | Write benchmark scripts for measuring performance of solvers, samplers, and other computationally intensive components |
+
+#### When to use each skill
+
+**`/add-functor`** — Use when adding a new observation, event, reward, action, dataset, or randomization functor to an EmbodiChain environment. The skill will ask for the functor type and name, then generate the function- or class-style implementation with proper docstrings, type hints, and `__all__` exports.
+
+**`/add-task-env`** — Use when creating a new task environment, including expert demonstration tasks, RL tasks, or any `EmbodiedEnv` subclass. The skill scaffolds the task file with `_setup_scene`, `_reset_idx`, and evaluation logic, plus a test stub.
+
+**`/add-test`** — Use when writing tests for any EmbodiChain module — functors, solvers, sensors, environments, or utilities. The skill determines the correct test file location, style (pytest function vs class), and generates tests with the standard Apache 2.0 header and named constants.
+
+**`/pre-commit-check`** — Run this before committing or creating a PR. It verifies code formatting (`black`), file headers, type annotations, `__all__` exports, and docstring completeness — the same checks the CI pipeline enforces.
+
+**`/pr`** — Use after committing your changes to create a pull request. The skill checks git state, determines the PR type, drafts a description following the project template, runs formatting, creates a feature branch, and opens the PR via `gh` CLI with the correct labels.
+
+**`/benchmark`** — Use when you need to measure the performance of a module (IK solvers, grasp samplers, metrics, etc.). The skill generates a well-structured benchmark script following project conventions.
+
### Suggested workflows
**Explore the codebase before making changes**
@@ -66,7 +109,7 @@ A `CLAUDE.md` file is present at the root of this repository. Claude Code reads
```
> I want to add a new observation functor that returns the end-effector velocity.
Which existing functor should I model it after?
-> Generate the functor following the project style, with a proper docstring and type hints.
+> /add-functor
```
**Validate style and formatting before submitting**
@@ -74,13 +117,13 @@ A `CLAUDE.md` file is present at the root of this repository. Claude Code reads
```
> Review my changes in embodichain/lab/gym/envs/managers/randomization/visual.py
for style issues, missing type hints, and docstring completeness.
+> /pre-commit-check
```
**Write or update tests**
```
-> Write a pytest test for the randomize_emission_light function in
- embodichain/lab/gym/envs/managers/randomization/visual.py.
+> /add-test
```
**Understand a bug**
@@ -92,38 +135,17 @@ A `CLAUDE.md` file is present at the root of this repository. Claude Code reads
**Create a pull request**
-After you've made your changes and committed them, use the `/pr` command to create a pull request:
+After you've made your changes and committed them:
```
> /pr
```
-This will guide you through:
-1. Checking the current git state and changes
-2. Determining the PR type (bug fix, enhancement, new feature, etc.)
-3. Drafting a proper PR description following the project template
-4. Running code formatting with `black .`
-5. Creating a properly named feature branch
-6. Committing changes with a conventional commit message
-7. Pushing to remote and creating the PR via `gh` CLI
-
-The `/pr` skill ensures your PR follows the EmbodiChain contribution guidelines and populates the required checklist items.
+The `/pr` skill will guide you through checking git state, determining the PR type, drafting a description, running formatting, and creating the PR with proper labels.
### Tips
-* Always run `black .` after Claude Code generates or edits Python files — Claude Code can do this for you if you ask.
+* Always run `/pre-commit-check` after making changes — it catches the same issues the CI pipeline checks.
* Claude Code respects the `CLAUDE.md` conventions. If you notice it deviating (wrong docstring style, missing `__all__`, etc.), point it out and it will correct the output.
-* For large features, break the work into small, focused tasks and handle them one at a time.
-* Claude Code can help draft your PR description and populate the PR checklist once your changes are ready.
-
-## Contribute specific robots
-
-To contribute a new robot, please check the documentation on [Adding a New Robot](https://dexforce.github.io/EmbodiChain/guides/add_robot.html).
-
-## Contribute specific environments
-
-To contribute a new environment, please check the documentation on [Embodied Environments](https://dexforce.github.io/EmbodiChain/overview/gym/env.html) and see the tutorial below:
-- [Creating a Basic Environment](https://dexforce.github.io/EmbodiChain/tutorial/basic_env.html)
-- [Creating a Modular Environment](https://dexforce.github.io/EmbodiChain/tutorial/modular_env.html)
-
-If you want to implement your tasks in a new repo and with some customized functors and utilities, you can also use the [Task Template Repo](https://github.com/DexForce/embodichain_task_template).
\ No newline at end of file
+* For large features, break the work into small, focused tasks and handle them one at a time using the appropriate skill for each step.
+* If you add a new skill to `.claude/skills/`, make sure to also add it to the Skills table and "When to use each skill" list in this document so contributors can discover it.
\ No newline at end of file
diff --git a/README.md b/README.md
index e7c28de7..5c9cdb97 100644
--- a/README.md
+++ b/README.md
@@ -2,18 +2,18 @@

-[](LICENSE)
-[](https://dexforce.com/embodichain/index.html#/)
-[](https://dexforce.github.io/EmbodiChain/introduction.html)
-[](https://docs.python.org/3/whatsnew/3.10.html)
-[](https://github.com/DexForce/EmbodiChain/releases)
+[](LICENSE)
+[](https://dexforce.com/embodichain/index.html#/)
+[](https://dexforce.github.io/EmbodiChain/main/index.html)
+[](https://docs.python.org/3/whatsnew/3.10.html)
+[](https://github.com/DexForce/EmbodiChain/releases)
---
-EmbodiChain is an end-to-end, GPU-accelerated framework for Embodied AI. It streamlines research and development by unifying high-performance simulation, real-to-sim data pipelines, modular model architectures, and efficient training workflows. This integration enables rapid experimentation, seamless deployment of intelligent agents, and effective Sim2Real transfer for real-world robotic systems.
+EmbodiChain is an end-to-end, GPU-accelerated framework for Embodied AI. It streamlines research and development by unifying high-performance simulation, automated generative data pipelines, modular model architectures, and efficient training workflows. This integration enables rapid experimentation, seamless deployment of intelligent agents, and effective Sim2Real transfer for real-world robotic systems.
> [!NOTE]
> EmbodiChain is in Alpha and under active development:
-> * More features will be continually added in the coming months. You can find more details in the [roadmap](https://dexforce.github.io/EmbodiChain/resources/roadmap.html).
+> * More features will be continually added in the coming months. You can find more details in the [roadmap](https://dexforce.github.io/EmbodiChain/main/resources/roadmap.html).
> * Since this is an early release, we welcome feedback (bug reports, feature requests, etc.) via GitHub Issues.
@@ -36,9 +36,9 @@ The figure below illustrates the overall architecture of EmbodiChain:
To get started with EmbodiChain, follow these steps:
-- [Installation Guide](https://dexforce.github.io/EmbodiChain/quick_start/install.html)
-- [Quick Start Tutorial](https://dexforce.github.io/EmbodiChain/tutorial/index.html)
-- [API Reference](https://dexforce.github.io/EmbodiChain/api_reference/index.html)
+- [Installation Guide](https://dexforce.github.io/EmbodiChain/main/quick_start/install.html)
+- [Quick Start Tutorial](https://dexforce.github.io/EmbodiChain/main/tutorial/index.html)
+- [API Reference](https://dexforce.github.io/EmbodiChain/main/api_reference/index.html)
## Contribution Guide
diff --git a/VERSION b/VERSION
index b1e80bb2..0ea3a944 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.3
+0.2.0
diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json
index 02a302d1..6da5f735 100644
--- a/configs/agents/rl/basic/cart_pole/train_config.json
+++ b/configs/agents/rl/basic/cart_pole/train_config.json
@@ -1,11 +1,10 @@
-{
+{
"trainer": {
"exp_name": "cart_pole_ppo",
"gym_config": "configs/agents/rl/basic/cart_pole/gym_config.json",
"seed": 42,
"device": "cuda:0",
"headless": true,
- "enable_rt": false,
"gpu_id": 0,
"num_envs": 64,
"iterations": 1000,
@@ -22,30 +21,57 @@
"interval_step": 1,
"params": {
"name": "main_cam",
- "resolution": [640, 480],
- "eye": [-1.4, 1.4, 2.5],
- "target": [0, 0, 0.7],
- "up": [0, 0, 1],
- "intrinsics": [600, 600, 320, 240],
+ "resolution": [
+ 640,
+ 480
+ ],
+ "eye": [
+ -1.4,
+ 1.4,
+ 2.5
+ ],
+ "target": [
+ 0,
+ 0,
+ 0.7
+ ],
+ "up": [
+ 0,
+ 0,
+ 1
+ ],
+ "intrinsics": [
+ 600,
+ 600,
+ 320,
+ 240
+ ],
"save_path": "./outputs/videos/eval"
}
}
}
- }
+ },
+ "renderer": "fast-rt"
},
"policy": {
"name": "actor_critic",
"actor": {
"type": "mlp",
"network_cfg": {
- "hidden_sizes": [256, 256],
+ "hidden_sizes": [
+ 256,
+ 256
+ ],
"activation": "relu"
}
},
"critic": {
"type": "mlp",
"network_cfg": {
- "hidden_sizes": [256, 256],
+ "hidden_sizes": [
+ 256,
+ 256
+ ],
"activation": "relu"
}
}
@@ -64,4 +90,4 @@
"max_grad_norm": 0.5
}
}
-}
+}
\ No newline at end of file
diff --git a/configs/agents/rl/basic/cart_pole/train_config_grpo.json b/configs/agents/rl/basic/cart_pole/train_config_grpo.json
index 4da5cab7..86ac34f2 100644
--- a/configs/agents/rl/basic/cart_pole/train_config_grpo.json
+++ b/configs/agents/rl/basic/cart_pole/train_config_grpo.json
@@ -5,7 +5,6 @@
"seed": 42,
"device": "cuda:0",
"headless": true,
- "enable_rt": false,
"gpu_id": 0,
"num_envs": 64,
"iterations": 1000,
@@ -23,23 +22,47 @@
"interval_step": 1,
"params": {
"name": "main_cam",
- "resolution": [640, 480],
- "eye": [-1.4, 1.4, 2.5],
- "target": [0, 0, 0.7],
- "up": [0, 0, 1],
- "intrinsics": [600, 600, 320, 240],
+ "resolution": [
+ 640,
+ 480
+ ],
+ "eye": [
+ -1.4,
+ 1.4,
+ 2.5
+ ],
+ "target": [
+ 0,
+ 0,
+ 0.7
+ ],
+ "up": [
+ 0,
+ 0,
+ 1
+ ],
+ "intrinsics": [
+ 600,
+ 600,
+ 320,
+ 240
+ ],
"save_path": "./outputs/videos/eval"
}
}
}
- }
+ },
+ "renderer": "hybrid"
},
"policy": {
"name": "actor_only",
"actor": {
"type": "mlp",
"network_cfg": {
- "hidden_sizes": [256, 256],
+ "hidden_sizes": [
+ 256,
+ 256
+ ],
"activation": "relu"
}
}
@@ -55,7 +78,7 @@
"ent_coef": 0.01,
"kl_coef": 0.0,
"group_size": 4,
- "eps": 1e-8,
+ "eps": 1e-08,
"reset_every_rollout": true,
"max_grad_norm": 0.5,
"truncate_at_first_done": true
diff --git a/configs/agents/rl/push_cube/gym_config.json b/configs/agents/rl/push_cube/gym_config.json
index 4e8cec4d..a97cc65d 100644
--- a/configs/agents/rl/push_cube/gym_config.json
+++ b/configs/agents/rl/push_cube/gym_config.json
@@ -71,33 +71,33 @@
"reaching_reward": {
"func": "reaching_behind_object",
"mode": "add",
- "weight": 0.1,
+ "weight": 0.03,
"params": {
"object_cfg": {
"uid": "cube"
},
"target_pose_key": "goal_pose",
- "behind_offset": 0.015,
+ "behind_offset": 0.03,
"height_offset": 0.015,
- "distance_scale": 5.0,
+ "distance_scale": 8.0,
"part_name": "arm"
}
},
- "place_reward": {
- "func": "incremental_distance_to_target",
+ "goal_distance_reward": {
+ "func": "distance_to_target",
"mode": "add",
- "weight": 1.0,
+ "weight": 0.8,
"params": {
"source_entity_cfg": {
"uid": "cube"
},
"target_pose_key": "goal_pose",
- "tanh_scale": 10.0,
- "positive_weight": 2.0,
- "negative_weight": 0.5,
+ "exponential": true,
+ "sigma": 0.12,
"use_xy_only": true
}
},
+
"action_penalty": {
"func": "action_smoothness_penalty",
"mode": "add",
@@ -175,9 +175,9 @@
"body_type": "dynamic",
"init_pos": [-0.6, -0.4, 0.05],
"attrs": {
- "mass": 10.0,
- "static_friction": 3.0,
- "dynamic_friction": 2.0,
+ "mass": 2.0,
+ "static_friction": 1.0,
+ "dynamic_friction": 0.8,
"linear_damping": 2.0,
"angular_damping": 2.0,
"contact_offset": 0.003,
diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json
index d44aa0b3..11b0972d 100644
--- a/configs/agents/rl/push_cube/train_config.json
+++ b/configs/agents/rl/push_cube/train_config.json
@@ -1,11 +1,10 @@
-{
+{
"trainer": {
"exp_name": "push_cube_ppo",
"gym_config": "configs/agents/rl/push_cube/gym_config.json",
"seed": 42,
"device": "cuda:0",
"headless": true,
- "enable_rt": false,
"gpu_id": 0,
"num_envs": 64,
"iterations": 1000,
@@ -13,9 +12,9 @@
"enable_eval": true,
"num_eval_envs": 16,
"num_eval_episodes": 3,
- "eval_freq": 2,
- "save_freq": 200,
- "use_wandb": false,
+ "eval_freq": 100,
+ "save_freq": 100,
+ "use_wandb": true,
"wandb_project_name": "embodichain-push_cube",
"events": {
"eval": {
@@ -30,25 +29,32 @@
"target": [0, 0, 0],
"up": [0, 0, 1],
"intrinsics": [600, 600, 320, 240],
- "save_path": "./outputs/videos/eval"
+ "save_path": "./outputs/videos_ppo1/eval"
}
}
}
- }
+ },
+ "renderer": "hybrid"
},
"policy": {
"name": "actor_critic",
"actor": {
"type": "mlp",
"network_cfg": {
- "hidden_sizes": [256, 256],
+ "hidden_sizes": [
+ 256,
+ 256
+ ],
"activation": "relu"
}
},
"critic": {
"type": "mlp",
"network_cfg": {
- "hidden_sizes": [256, 256],
+ "hidden_sizes": [
+ 256,
+ 256
+ ],
"activation": "relu"
}
}
@@ -67,4 +73,4 @@
"max_grad_norm": 0.5
}
}
-}
+}
\ No newline at end of file
diff --git a/configs/agents/rl/push_cube/train_config_grpo.json b/configs/agents/rl/push_cube/train_config_grpo.json
new file mode 100644
index 00000000..df5f6681
--- /dev/null
+++ b/configs/agents/rl/push_cube/train_config_grpo.json
@@ -0,0 +1,65 @@
+{
+ "trainer": {
+ "exp_name": "push_cube_grpo",
+ "gym_config": "configs/agents/rl/push_cube/gym_config.json",
+ "seed": 42,
+ "device": "cuda:0",
+ "headless": true,
+ "gpu_id": 0,
+ "num_envs": 64,
+ "iterations": 1000,
+ "buffer_size": 1024,
+ "enable_eval": true,
+ "num_eval_envs": 16,
+ "num_eval_episodes": 3,
+ "eval_freq": 200,
+ "save_freq": 200,
+ "use_wandb": false,
+ "wandb_project_name": "embodichain-push_cube",
+ "events": {
+ "eval": {
+ "record_camera": {
+ "func": "record_camera_data_async",
+ "mode": "interval",
+ "interval_step": 1,
+ "params": {
+ "name": "main_cam",
+ "resolution": [640, 480],
+ "eye": [-1.4, 1.4, 2.0],
+ "target": [0, 0, 0],
+ "up": [0, 0, 1],
+ "intrinsics": [600, 600, 320, 240],
+ "save_path": "./outputs/videos/eval"
+ }
+ }
+ }
+ }
+ },
+ "policy": {
+ "name": "actor_only",
+ "actor": {
+ "type": "mlp",
+ "network_cfg": {
+ "hidden_sizes": [256, 256],
+ "activation": "relu"
+ }
+ }
+ },
+ "algorithm": {
+ "name": "grpo",
+ "cfg": {
+ "learning_rate": 0.0001,
+ "n_epochs": 10,
+ "batch_size": 8192,
+ "gamma": 0.99,
+ "clip_coef": 0.2,
+ "ent_coef": 0.01,
+ "kl_coef": 0.0,
+ "group_size": 4,
+ "eps": 1e-8,
+ "reset_every_rollout": true,
+ "max_grad_norm": 0.5,
+ "truncate_at_first_done": true
+ }
+ }
+}
diff --git a/configs/gym/pour_water/gym_config_simple.json b/configs/gym/pour_water/gym_config_simple.json
index ca45e80b..bcce5bc4 100644
--- a/configs/gym/pour_water/gym_config_simple.json
+++ b/configs/gym/pour_water/gym_config_simple.json
@@ -203,7 +203,7 @@
"mode": "modify",
"name": "robot/qpos",
"params": {
- "joint_ids": [12, 13, 14, 15]
+ "joint_ids": [6, 13]
}
}
},
@@ -227,7 +227,8 @@
"use_videos": true
}
}
- }
+ },
+ "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"]
},
"robot": {
"uid": "CobotMagic",
diff --git a/configs/language/README.md b/configs/language/README.md
new file mode 100644
index 00000000..c69148c2
--- /dev/null
+++ b/configs/language/README.md
@@ -0,0 +1,275 @@
+# Language Support for VLA Training
+
+This directory contains configuration and examples for the hierarchical language support feature in EmbodiChain, enabling Vision-Language-Action (VLA) model training with Online Data Streaming (ODS).
+
+## Overview
+
+The language support feature adds hierarchical language descriptions to the rollout buffer, organized at three abstraction levels:
+
+1. **Task Level**: High-level goal or overall task description
+2. **Subtask Level**: Intermediate step descriptions
+3. **Primitive Level**: Low-level action descriptions
+
+This hierarchical structure enables VLA models to learn from multi-scale language representations, similar to human task understanding.
+
+## Features
+
+- **Multiple Language Sources**: Support for file-based, environment-based, template-based, and LLM-generated language
+- **Hierarchical Structure**: Organize instructions at multiple abstraction levels
+- **Flexible Storage**: Support for tokens, embeddings, or hybrid storage modes
+- **Dynamic Chunk Sizes**: Works with variable-length trajectory chunks
+- **Curriculum Learning**: Gradually increase language complexity during training
+- **Token Agnostic**: Works with various tokenizers (GPT, BERT, etc.)
+
+## Quick Start
+
+### 1. Prepare Language Configuration
+
+Create a YAML file with task descriptions:
+
+```yaml
+# tasks.yaml
+pick_and_place:
+ task:
+ - "Pick up the red block and place it in the blue basket."
+
+ subtask:
+ - "Move the gripper to the red block."
+ - "Grasp the red block."
+ - "Lift the block and move to the blue basket."
+ - "Release the block into the basket."
+
+ primitive:
+ - "Close gripper."
+ - "Move up."
+ - "Move right."
+ - "Open gripper."
+```
+
+### 2. Configure ODS Engine
+
+```python
+from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg
+
+language_cfg = {
+ "mode": "tokens",
+ "hierarchy_levels": ["task", "subtask", "primitive"],
+ "max_tokens": 512,
+ "tokenizer": "gpt2",
+ "language_source": "file",
+ "language_config_path": "configs/language/tasks.yaml",
+ "max_instructions_per_level": 5,
+}
+
+engine_cfg = OnlineDataEngineCfg(
+ buffer_size=16,
+ max_episode_steps=300,
+ state_dim=14,
+ gym_config={...},
+ language_cfg=language_cfg,
+)
+
+engine = OnlineDataEngine(engine_cfg)
+engine.start()
+```
+
+### 3. Use Language Data in Training
+
+```python
+from embodichain.agents.datasets.online_data import OnlineDataset
+from torch.utils.data import DataLoader
+
+dataset = OnlineDataset(engine, chunk_size=64, batch_size=8)
+loader = DataLoader(dataset, batch_size=None)
+
+for batch in loader:
+ obs = batch["obs"]
+ actions = batch["actions"]
+ language = batch["language"]
+
+ # Access language at different hierarchy levels
+ task_tokens = language["task_level_tokens"]
+ subtask_tokens = language["subtask_level_tokens"]
+ primitive_tokens = language["primitive_level_tokens"]
+
+ # Train your VLA model
+ # loss = vla_model(obs, language, actions)
+```
+
+## Configuration Options
+
+### Language Configuration
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `mode` | str | "tokens" | Storage mode: 'tokens', 'embeddings', or 'hybrid' |
+| `hierarchy_levels` | list | ["task", "subtask", "primitive"] | Hierarchy levels to store |
+| `max_tokens` | int | 512 | Maximum sequence length per instruction |
+| `tokenizer` | str | "gpt2" | Tokenizer identifier |
+| `pad_token_id` | int | 0 | Token ID used for padding |
+| `max_instructions_per_level` | int | 3 | Maximum number of instructions per level |
+| `embedding_dim` | int | 768 | Dimension for embeddings (if mode='embeddings') |
+| `language_source` | str | "env" | Source of language: 'env', 'file', 'llm', 'template' |
+| `language_config_path` | str | None | Path to language config file (if source='file') |
+
+### Language Sources
+
+#### File-Based (`language_source: "file"`)
+Load language descriptions from YAML or JSON files. Best for static task descriptions.
+
+```python
+language_cfg = {
+ "language_source": "file",
+ "language_config_path": "configs/language/tasks.yaml",
+}
+```
+
+#### Environment-Based (`language_source: "env"`)
+Generate language descriptions from the environment. The environment should implement:
+- `get_task_language(task_id, context) -> HierarchicalLanguageData`
+- Or have a `task_description` attribute
+
+```python
+language_cfg = {
+ "language_source": "env",
+}
+```
+
+#### Template-Based (`language_source: "template"`)
+Use templates with variable substitution for structured tasks.
+
+```python
+language_cfg = {
+ "language_source": "template",
+ "templates": {
+ "pick_and_place": {
+ "task": "Pick up the {color} {object} and place it {location}.",
+ "subtasks": [...],
+ }
+ },
+ "variables": {"color": "red", "object": "block", "location": "in basket"},
+}
+```
+
+#### LLM-Based (`language_source: "llm"`)
+Generate descriptions using an LLM (e.g., GPT-4, Claude).
+
+```python
+language_cfg = {
+ "language_source": "llm",
+ "model": "gpt-4",
+ "api_key": "your-api-key",
+}
+```
+
+## Buffer Structure
+
+When language support is enabled, the rollout buffer includes the following fields:
+
+### Per-Hierarchy-Level Fields
+
+For each level in `hierarchy_levels` (e.g., "task", "subtask", "primitive"):
+
+- `{level}_tokens`: `[batch_size, max_episode_steps, max_instructions, max_tokens]`
+- `{level}_attention_mask`: `[batch_size, max_episode_steps, max_instructions, max_tokens]`
+- `{level}_count`: `[batch_size, max_episode_steps]`
+
+### Global Fields
+
+- `instruction_counts`: `[batch_size, max_episode_steps, 3]` - Counts per hierarchy level
+- `change_points`: `[batch_size, max_episode_steps, max_instructions]` - Timesteps where language changes
+- `hierarchy_depth`: `[batch_size, max_episode_steps]` - Current depth of hierarchy (1-3)
+- `instruction_types`: `[batch_size, max_episode_steps, max_instructions]` - Instruction type IDs
+
+## Advanced Usage
+
+### Custom Language Provider
+
+```python
+from embodichain.lab.gym.envs.managers import LanguageProvider, HierarchicalLanguageData
+
+class MyLanguageProvider(LanguageProvider):
+ def get_language(self, task_id, context=None):
+ # Generate custom language data
+ return HierarchicalLanguageData(
+ task_level=[...],
+ subtask_level=[...],
+ primitive_level=[...],
+ )
+
+ def get_available_tasks(self):
+ return ["task1", "task2"]
+```
+
+### Language Augmentation
+
+```python
+from embodichain.lab.gym.envs.managers import LanguageAugmentationCfg
+
+augmentation_cfg = LanguageAugmentationCfg(
+ synonym_replacement=0.1,
+ template_variation=True,
+ augmentation_prob=0.5,
+)
+```
+
+### Curriculum Learning
+
+```python
+from embodichain.lab.gym.envs.managers import LanguageCurriculumCfg
+
+curriculum_cfg = LanguageCurriculumCfg(
+ enabled=True,
+ stage_duration=1000,
+ stages=[
+ # Simple language first
+ LanguageCurriculumCfg.CurriculumStage(
+ max_words=10,
+ max_sentences=1,
+ max_hierarchy_depth=1,
+ ),
+ # Then more complex
+ LanguageCurriculumCfg.CurriculumStage(
+ max_words=50,
+ max_sentences=3,
+ max_hierarchy_depth=3,
+ ),
+ ],
+)
+```
+
+## Examples
+
+See `usage_example.py` for complete examples of:
+- File-based language loading
+- Environment-based language generation
+- Template-based language
+- Dynamic chunk sizes with language
+- Custom environments with language
+
+## Files
+
+- `tasks_example.yaml` - Example task descriptions in YAML format
+- `usage_example.py` - Complete usage examples
+- `README.md` - This file
+
+## API Reference
+
+### Core Classes
+
+- `LanguageCfg` - Configuration for language data
+- `LanguageManager` - Manages tokenization and language data
+- `LanguageData` - Single-level language data
+- `HierarchicalLanguageData` - Multi-level hierarchical language data
+- `LanguageProvider` - Abstract base for language sources
+- `FileBasedLanguageProvider` - Load from YAML/JSON files
+- `LLMBasedLanguageProvider` - Generate with LLM
+- `EnvBasedLanguageProvider` - Generate from environment
+- `TemplateBasedLanguageProvider` - Template-based generation
+
+## Notes
+
+- Language data is broadcast across all timesteps in an episode
+- Tokenization happens in the simulation subprocess for efficiency
+- Shared memory ensures zero-copy data transfer to training process
+- Compatible with all existing ODS features (dynamic chunks, etc.)
diff --git a/configs/language/tasks_example.yaml b/configs/language/tasks_example.yaml
new file mode 100644
index 00000000..70fe6f5f
--- /dev/null
+++ b/configs/language/tasks_example.yaml
@@ -0,0 +1,108 @@
+# Example language configuration file for VLA training
+# This file demonstrates the hierarchical language structure for task descriptions
+
+pick_and_place:
+ task:
+ - "Pick up the red block and place it in the blue basket."
+
+ subtask:
+ - "Move the gripper to the red block."
+ - "Grasp the red block."
+ - "Lift the block and move to the blue basket."
+ - "Release the block into the basket."
+
+ primitive:
+ - "Close gripper."
+ - "Move up."
+ - "Move right."
+ - "Open gripper."
+
+ change_points: [0, 10, 20, 30]
+
+stack_blocks:
+ task:
+ - "Stack the red block on top of the green block."
+
+ subtask:
+ - "Locate the red block and green block."
+ - "Move the gripper to the red block."
+ - "Grasp the red block."
+ - "Lift the red block."
+ - "Position the red block above the green block."
+ - "Lower the red block onto the green block."
+ - "Release the red block."
+
+ primitive:
+ - "Close gripper."
+ - "Move up."
+ - "Move forward."
+ - "Move left."
+ - "Open gripper."
+
+ change_points: [0, 5, 10, 15, 20, 25, 30]
+
+pour_liquid:
+ task:
+ - "Pour water from the cup into the bowl."
+
+ subtask:
+ - "Approach the cup with the gripper."
+ - "Grasp the cup securely."
+ - "Lift the cup."
+ - "Tilt the cup over the bowl."
+ - "Wait for liquid to pour."
+ - "Return the cup to upright position."
+
+ primitive:
+ - "Close gripper."
+ - "Move up."
+ - "Rotate wrist."
+ - "Wait."
+ - "Rotate wrist back."
+
+ change_points: [0, 5, 10, 15, 25, 30]
+
+button_press:
+ task:
+ - "Press the red button to activate the mechanism."
+
+ subtask:
+ - "Locate the red button."
+ - "Move the end-effector to the button."
+ - "Apply downward force to press the button."
+ - "Release and retract."
+
+ primitive:
+ - "Move forward."
+ - "Move down."
+ - "Apply force."
+ - "Move up."
+ - "Move backward."
+
+ change_points: [0, 5, 10, 15, 20]
+
+door_open:
+ task:
+ - "Open the cabinet door and place the object inside."
+
+ subtask:
+ - "Approach the cabinet handle."
+ - "Grasp the cabinet handle."
+ - "Pull the door open."
+ - "Pick up the object."
+ - "Move the object into the cabinet."
+ - "Release the object."
+ - "Close the cabinet door."
+
+ primitive:
+ - "Move forward."
+ - "Close gripper."
+ - "Move backward."
+ - "Move down."
+ - "Close gripper."
+ - "Move forward."
+ - "Open gripper."
+ - "Move backward."
+ - "Push forward."
+
+ change_points: [0, 5, 10, 15, 20, 25, 30, 35]
diff --git a/configs/language/usage_example.py b/configs/language/usage_example.py
new file mode 100644
index 00000000..2a695b71
--- /dev/null
+++ b/configs/language/usage_example.py
@@ -0,0 +1,327 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""
+Example: Using Language Support for VLA Training with Online Data Streaming
+
+This example demonstrates how to configure and use the hierarchical language
+support for Vision-Language-Action (VLA) model training.
+"""
+
+import torch
+from torch.utils.data import DataLoader
+
+from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg
+from embodichain.agents.datasets.online_data import OnlineDataset
+from embodichain.lab.gym.envs.managers import (
+ LanguageCfg,
+ LanguageManager,
+ HierarchicalLanguageData,
+)
+
+
+# Example 1: Basic ODS with Language Support (File-based)
+def example_ods_with_language_file():
+ """Set up ODS with language descriptions loaded from a YAML file."""
+
+ # Language configuration
+ language_cfg = {
+ "mode": "tokens",
+ "hierarchy_levels": ["task", "subtask", "primitive"],
+ "max_tokens": 512,
+ "tokenizer": "gpt2",
+ "language_source": "file",
+ "language_config_path": "configs/language/tasks_example.yaml",
+ "max_instructions_per_level": 5,
+ }
+
+ # ODS engine configuration
+ engine_cfg = OnlineDataEngineCfg(
+ buffer_size=16,
+ max_episode_steps=300,
+ state_dim=14,
+ gym_config={
+ "id": "EmbodiedEnv-v1",
+ "env": {"robot": {...}},
+ # ... other env config
+ },
+ language_cfg=language_cfg, # Enable language support
+ )
+
+ # Create and start the engine
+ engine = OnlineDataEngine(engine_cfg)
+ engine.start()
+
+ # Create dataset with language support
+ dataset = OnlineDataset(engine, chunk_size=64, batch_size=8)
+
+ # Create DataLoader
+ loader = DataLoader(
+ dataset,
+ batch_size=None, # Batch mode (dataset handles batching)
+ num_workers=0,
+ collate_fn=OnlineDataset.passthrough_collate_fn,
+ )
+
+ # Training loop
+ for batch in loader:
+ # Access different data modalities
+ obs = batch["obs"] # Vision and proprioception
+ actions = batch["actions"] # Robot actions
+ language = batch["language"] # Hierarchical language data
+
+ # Access language at different hierarchy levels
+ task_tokens = language[
+ "task_level_tokens"
+ ] # [batch, chunk, max_instr, max_tokens]
+ task_mask = language["task_level_attention_mask"]
+
+ subtask_tokens = language["subtask_level_tokens"]
+ subtask_mask = language["subtask_level_attention_mask"]
+
+ primitive_tokens = language["primitive_level_tokens"]
+ primitive_mask = language["primitive_level_attention_mask"]
+
+ # Use for VLA training
+ # train_step(obs, language, actions)
+
+
+# Example 2: Environment-Based Language Generation
+def example_env_based_language():
+ """Set up ODS with language generated by the environment."""
+
+ language_cfg = {
+ "mode": "tokens",
+ "hierarchy_levels": ["task", "subtask"],
+ "max_tokens": 256,
+ "tokenizer": "gpt2",
+ "language_source": "env", # Environment generates language
+ }
+
+ engine_cfg = OnlineDataEngineCfg(
+ buffer_size=16,
+ max_episode_steps=300,
+ state_dim=14,
+ gym_config={...},
+ language_cfg=language_cfg,
+ )
+
+ engine = OnlineDataEngine(engine_cfg)
+ engine.start()
+
+ # Your environment should implement:
+ # - get_task_language(task_id, context) -> HierarchicalLanguageData
+ # - Or have a task_description attribute
+
+
+# Example 3: Template-Based Language
+def example_template_based_language():
+ """Set up ODS with template-based language generation."""
+
+ language_cfg = {
+ "mode": "tokens",
+ "hierarchy_levels": ["task", "subtask"],
+ "max_tokens": 256,
+ "tokenizer": "gpt2",
+ "language_source": "template",
+ "templates": {
+ "pick_and_place": {
+ "task": "Pick up the {color} {object} and place it {location}.",
+ "subtasks": [
+ "Move to the {color} {object}.",
+ "Grasp the {color} {object}.",
+ "Move {location}.",
+ "Release the {object}.",
+ ],
+ }
+ },
+ "variables": {
+ "color": "red",
+ "object": "block",
+ "location": "in the blue basket",
+ },
+ }
+
+ engine_cfg = OnlineDataEngineCfg(
+ buffer_size=16,
+ max_episode_steps=300,
+ state_dim=14,
+ gym_config={...},
+ language_cfg=language_cfg,
+ )
+
+ engine = OnlineDataEngine(engine_cfg)
+ engine.start()
+
+
+# Example 4: Using Language Manager Directly
+def example_language_manager():
+ """Use LanguageManager to tokenize and manage language data."""
+
+ cfg = LanguageCfg(
+ mode="tokens",
+ hierarchy_levels=["task", "subtask", "primitive"],
+ max_tokens=512,
+ tokenizer="gpt2",
+ )
+
+ # Create a simple mock environment
+ class MockEnv:
+ task_name = "pick_and_place"
+ task_description = "Pick up the red block and place it in the basket."
+
+ env = MockEnv()
+ manager = LanguageManager(cfg, env)
+
+ # Create hierarchical language data
+ language_data = manager.create_hierarchical_language_data(
+ task_texts="Pick up the red block and place it in the basket.",
+ subtask_texts=[
+ "Move to the red block.",
+ "Grasp the red block.",
+ "Move to the basket.",
+ "Release the block.",
+ ],
+ primitive_texts=[
+ "Close gripper.",
+ "Move up.",
+ "Move right.",
+ "Open gripper.",
+ ],
+ change_points=[0, 10, 20, 30],
+ )
+
+ # Convert to buffer format
+ buffer_format = language_data.to_buffer_format(cfg)
+
+ # Access tokenized data
+ task_tokens = buffer_format["task_level_tokens"] # [max_instructions, max_tokens]
+ task_mask = buffer_format["task_level_attention_mask"]
+
+
+# Example 5: Dynamic Chunk Size with Language
+def example_dynamic_chunk_language():
+ """Use dynamic chunk sizes with language support."""
+
+ from embodichain.agents.datasets.sampler import UniformChunkSampler
+
+ language_cfg = {
+ "mode": "tokens",
+ "hierarchy_levels": ["task"],
+ "max_tokens": 256,
+ "tokenizer": "gpt2",
+ "language_source": "file",
+ "language_config_path": "configs/language/tasks_example.yaml",
+ }
+
+ engine_cfg = OnlineDataEngineCfg(
+ buffer_size=16,
+ max_episode_steps=300,
+ state_dim=14,
+ gym_config={...},
+ language_cfg=language_cfg,
+ )
+
+ engine = OnlineDataEngine(engine_cfg)
+ engine.start()
+
+ # Dynamic chunk size sampler
+ chunk_sampler = UniformChunkSampler(low=32, high=96)
+
+ # Dataset with dynamic chunk size
+ dataset = OnlineDataset(
+ engine,
+ chunk_size=chunk_sampler,
+ batch_size=8,
+ )
+
+ loader = DataLoader(
+ dataset,
+ batch_size=None,
+ collate_fn=OnlineDataset.passthrough_collate_fn,
+ )
+
+ for batch in loader:
+ # Batch shape is [batch_size, chunk_size, ...]
+ # Chunk dimension varies each iteration
+ print(f"Batch chunk size: {batch.shape[1]}")
+
+ # Language tokens are broadcast across all timesteps
+ language = batch["language"]
+ task_tokens = language[
+ "task_level_tokens"
+ ] # [batch_size, chunk_size, max_instr, max_tokens]
+
+
+# Example 6: Custom Environment with Language
+def example_custom_env_with_language():
+ """Example of a custom environment implementing language generation."""
+
+ from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
+
+ class MyTaskEnv(EmbodiedEnv):
+ """Custom environment that provides language descriptions."""
+
+ def __init__(self, cfg, **kwargs):
+ super().__init__(cfg, **kwargs)
+ self.task_name = "my_custom_task"
+
+ def get_task_language(self, task_id, context=None):
+ """Generate hierarchical language for the current task."""
+ return self.language_manager.create_hierarchical_language_data(
+ task_texts="Complete the custom manipulation task.",
+ subtask_texts=[
+ "Approach the object.",
+ "Grasp the object.",
+ "Move to target location.",
+ "Release the object.",
+ ],
+ primitive_texts=[
+ "Move forward.",
+ "Close gripper.",
+ "Move up.",
+ "Move right.",
+ "Open gripper.",
+ ],
+ )
+
+ # Configuration with language
+ env_cfg = EmbodiedEnvCfg(
+ robot={...},
+ language={
+ "mode": "tokens",
+ "hierarchy_levels": ["task", "subtask", "primitive"],
+ "max_tokens": 512,
+ "tokenizer": "gpt2",
+ "language_source": "env", # Environment will generate language
+ },
+ init_rollout_buffer=True,
+ )
+
+ env = MyTaskEnv(env_cfg)
+
+
+if __name__ == "__main__":
+ print("Language Support Examples for VLA Training")
+ print("=" * 50)
+ print("\nAvailable examples:")
+ print("1. example_ods_with_language_file() - File-based language")
+ print("2. example_env_based_language() - Environment-based language")
+ print("3. example_template_based_language() - Template-based language")
+ print("4. example_language_manager() - Direct LanguageManager usage")
+ print("5. example_dynamic_chunk_language() - Dynamic chunk sizes")
+ print("6. example_custom_env_with_language() - Custom environment")
+ print("\nRun any example function to see it in action.")
diff --git a/conftest.py b/conftest.py
new file mode 100644
index 00000000..00987125
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,24 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+# Make the scripts/ directory importable so tests can do:
+# from benchmark.rl.metrics import ...
+sys.path.insert(0, str(Path(__file__).parent / "scripts"))
diff --git a/docs/Makefile b/docs/Makefile
index 864eb2a7..9ded7fad 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -14,8 +14,20 @@ help:
.PHONY: help Makefile
+# Sync README.md -> introduction.rst before building
+.PHONY: sync-readme
+sync-readme:
+ @python3 "$(CURDIR)/scripts/sync_readme.py"
+
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
+%: Makefile sync-readme
@rm -rf "$(BUILDDIR)"
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+# Build current version only (for local development / PR verification)
+.PHONY: current-docs
+current-docs: sync-readme
+ @rm -rf "$(BUILDDIR)/html"
+ @$(SPHINXBUILD) -W --keep-going "$(SOURCEDIR)" "$(BUILDDIR)/html" $(SPHINXOPTS) $(O)
+ @python3 "$(CURDIR)/scripts/generate_versions_json.py" --build-dir "$(BUILDDIR)/html"
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 53d9dd9d..87db2c49 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -8,4 +8,4 @@ sphinx-autosummary-accessors
sphinxcontrib-bibtex
sphinx-design
sphinx_autodoc_typehints
-sphinx-multiversion
\ No newline at end of file
+pypandoc_binary
\ No newline at end of file
diff --git a/docs/scripts/build_versions.py b/docs/scripts/build_versions.py
new file mode 100644
index 00000000..dbbd7224
--- /dev/null
+++ b/docs/scripts/build_versions.py
@@ -0,0 +1,97 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Helper script for filtering versions to maintain buffer size."""
+
+import re
+from pathlib import Path
+
+
+def parse_version(tag: str) -> tuple[int, int, int]:
+ """Parse a version tag like 'v1.2.3' into a tuple (1, 2, 3)."""
+ match = re.match(r"^v(\d+)\.(\d+)\.(\d+)$", tag)
+ if not match:
+ return (0, 0, 0)
+ return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
+
+
+def filter_versions(
+ all_versions: list[str],
+ buffer_size: int,
+ main_branch: str = "main",
+) -> list[str]:
+ """Filter versions to maintain buffer size.
+
+ Keeps the latest (buffer_size - 1) release versions plus the main branch.
+
+ Args:
+ all_versions: List of all available version references
+ buffer_size: Total number of versions to keep (releases + main)
+ main_branch: Name of the main branch
+
+ Returns:
+ List of versions to keep
+ """
+ # Separate releases from branches
+ releases = [v for v in all_versions if re.match(r"^v\d+\.\d+\.\d+$", v)]
+ branches = [v for v in all_versions if v not in releases]
+
+ # Sort releases by version (newest first)
+ releases.sort(key=parse_version, reverse=True)
+
+ # Keep latest (buffer_size - 1) releases
+ releases_to_keep = releases[: (buffer_size - 1)]
+
+ # Always include main branch if it exists
+ versions_to_keep = releases_to_keep.copy()
+ if main_branch in branches:
+ versions_to_keep.append(main_branch)
+
+ return versions_to_keep
+
+
+def main():
+ """CLI entry point for version filtering."""
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description="Filter versions for multi-version docs"
+ )
+ parser.add_argument(
+ "--versions",
+ nargs="+",
+ required=True,
+ help="List of all available versions",
+ )
+ parser.add_argument(
+ "--buffer-size",
+ type=int,
+ default=5,
+ help="Total number of versions to keep (releases + main)",
+ )
+ parser.add_argument(
+ "--main-branch",
+ default="main",
+ help="Name of the main branch",
+ )
+ args = parser.parse_args()
+
+ filtered = filter_versions(args.versions, args.buffer_size, args.main_branch)
+ print(" ".join(filtered))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/scripts/generate_versions_json.py b/docs/scripts/generate_versions_json.py
new file mode 100644
index 00000000..d4905565
--- /dev/null
+++ b/docs/scripts/generate_versions_json.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+"""Generate versions.json and root index.html for the docs version selector."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import re
+from pathlib import Path
+
+
+def parse_version(tag: str) -> tuple[int, int, int]:
+ """Parse a version tag like 'v1.2.3' into a tuple (1, 2, 3)."""
+ match = re.match(r"^v(\d+)\.(\d+)\.(\d+)$", tag)
+ if not match:
+ return (0, 0, 0)
+ return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Generate versions.json and root index.html for multi-version docs"
+ )
+ parser.add_argument(
+ "--build-dir",
+ default="build/html",
+ help="Path to build/html directory (default: build/html)",
+ )
+ parser.add_argument(
+ "--output",
+ default=None,
+ help="Output path for versions.json (default: /versions.json)",
+ )
+ parser.add_argument(
+ "--latest",
+ default=None,
+ help="Name of the latest stable version (default: auto-detected from tags, falls back to main)",
+ )
+ args = parser.parse_args()
+
+ html_dir = Path(args.build_dir)
+ output = Path(args.output) if args.output else html_dir / "versions.json"
+
+ if not html_dir.exists():
+ print(f"Error: Build directory '{html_dir}' does not exist.")
+ raise SystemExit(1)
+
+ versions: list[dict[str, str]] = []
+
+ # Collect tag versions (vX.Y.Z directories), sorted newest-first
+ tag_dirs = sorted(
+ [d for d in html_dir.glob("v*") if d.is_dir()],
+ key=lambda d: parse_version(d.name),
+ reverse=True,
+ )
+ for d in tag_dirs:
+ name = d.name
+ versions.append({"name": name, "url": f"./{name}/index.html", "type": "tag"})
+
+ # Collect main (dev branch)
+ if (html_dir / "main").is_dir():
+ versions.append({"name": "main", "url": "./main/index.html", "type": "branch"})
+
+ # Determine latest: explicit arg > newest tag > main
+ if args.latest:
+ latest = args.latest
+ elif versions:
+ tag_names = [v["name"] for v in versions if v["type"] == "tag"]
+ latest = tag_names[0] if tag_names else "main"
+ else:
+ latest = "main"
+
+ manifest = {
+ "latest": latest,
+ "versions": versions,
+ }
+
+ # Write versions.json
+ output.parent.mkdir(parents=True, exist_ok=True)
+ output.write_text(json.dumps(manifest, indent=2))
+ print(f"Generated {output} with {len(versions)} versions (latest: {latest})")
+
+ # Write root index.html redirect
+ index_path = html_dir / "index.html"
+ index_content = (
+ "\n"
+ "\n"
+ f" EmbodiChain Docs\n"
+ f' \n'
+ "\n"
+ )
+ index_path.write_text(index_content)
+ print(f"Generated {index_path} (redirects to ./{latest}/index.html)")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/scripts/sync_readme.py b/docs/scripts/sync_readme.py
new file mode 100644
index 00000000..ca784513
--- /dev/null
+++ b/docs/scripts/sync_readme.py
@@ -0,0 +1,239 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+"""Synchronize README.md to docs/source/introduction.rst.
+
+Uses pypandoc for Markdown-to-RST conversion, then post-processes the output
+to fix Sphinx-specific formatting issues.
+
+Usage:
+ python docs/scripts/sync_readme.py # Overwrite introduction.rst
+ python docs/scripts/sync_readme.py --check # Exit 1 if stale
+"""
+
+from __future__ import annotations
+
+import argparse
+import re
+import sys
+from pathlib import Path
+
+__all__ = ["convert_readme_to_rst", "postprocess_rst"]
+
+# Resolve paths relative to this script
+REPO_ROOT = Path(__file__).resolve().parents[2]
+README_PATH = REPO_ROOT / "README.md"
+RST_PATH = REPO_ROOT / "docs" / "source" / "introduction.rst"
+
+# Prefix to make repo-root-relative paths work from docs/source/
+_DOCS_PATH_PREFIX = "../../"
+
+
+def _fix_image_path(path: str) -> str:
+ """Prefix a repo-root-relative image path for use from docs/source/.
+
+ Args:
+ path: Image path from pandoc output (repo-root-relative).
+
+ Returns:
+ Path adjusted for the RST file location in docs/source/.
+ """
+ if path.startswith(("http://", "https://")):
+ return path
+ return _DOCS_PATH_PREFIX + path
+
+
+def convert_readme_to_rst(readme_content: str) -> str:
+ """Convert Markdown content to RST via pypandoc.
+
+ Args:
+ readme_content: Raw Markdown text from README.md.
+
+ Returns:
+ Raw RST string from pandoc (before post-processing).
+ """
+ import pypandoc
+
+ return pypandoc.convert_text(readme_content, "rst", format="md")
+
+
+def postprocess_rst(rst: str, readme_content: str) -> str:
+ """Fix pandoc RST output for Sphinx compatibility.
+
+ Applies these transformations:
+ 1. Strip badge substitution references and definitions.
+ 2. Convert ``[!NOTE]`` blockquote to ``.. NOTE::`` directive.
+ 3. Convert ``.. raw:: html`` centered-image blocks to ``.. image::``.
+ 4. Replace ``.. code:: bibtex`` with ``.. code-block:: bibtex``.
+ 5. Convert ``.. figure::`` (with caption) to ``.. image::``.
+
+ Args:
+ rst: Raw RST from pandoc.
+ readme_content: Original Markdown (used to extract image paths).
+
+ Returns:
+ Cleaned RST suitable for Sphinx.
+ """
+ # Extract image paths from README tags for centered HTML blocks
+ readme_images = re.findall(r']*src="([^"]+)"[^>]*>', readme_content)
+
+ lines = rst.split("\n")
+ result_lines: list[str] = []
+ i = 0
+
+ while i < len(lines):
+ line = lines[i]
+
+ # --- 1. Strip badge substitution reference lines ---
+ if re.match(r"^\|.*\|", line):
+ i += 1
+ continue
+
+ # --- 1b. Strip badge substitution definitions at the bottom ---
+ if re.match(r"^\.\. \|\w[\w ]*\w\| image::", line):
+ i += 1
+ while i < len(lines) and lines[i].startswith(" "):
+ i += 1
+ continue
+
+ # --- 2. Convert [!NOTE] blockquote to .. NOTE:: ---
+ if re.match(r"^\s+\[!NOTE\]", line):
+ note_match = re.match(r"^\s+\[!NOTE\]\s*(.*)", line)
+ note_text = note_match.group(1) if note_match else ""
+ note_text = note_text.replace("\\*", "*")
+ note_lines: list[str] = []
+ if note_text:
+ note_lines.append(note_text)
+ i += 1
+ while i < len(lines) and lines[i].startswith(" ") and lines[i].strip():
+ cleaned = lines[i].strip().replace("\\*", "*")
+ note_lines.append(cleaned)
+ i += 1
+ result_lines.append(".. NOTE::")
+ for nl in note_lines:
+ result_lines.append(f" {nl}")
+ continue
+
+ # --- 3. Convert .. raw:: html centered blocks to .. image:: ---
+ if line.strip() == ".. raw:: html":
+ # Look ahead (skipping blank lines) for
+ j = i + 1
+ while j < len(lines) and lines[j].strip() == "":
+ j += 1
+ if j < len(lines) and "
raw block
+ i = j + 1 # skip past
line
+ while i < len(lines):
+ if "
" in lines[i]:
+ i += 1
+ # Skip any trailing .. raw:: html for
+ while i < len(lines) and (
+ lines[i].strip() == ""
+ or lines[i].strip() == ".. raw:: html"
+ or "" in lines[i]
+ ):
+ i += 1
+ break
+ i += 1
+ # Insert images from README source
+ for img_src in readme_images:
+ result_lines.append(f".. image:: {_fix_image_path(img_src)}")
+ result_lines.append(" :align: center")
+ result_lines.append("") # blank line after directive
+ continue
+ elif j < len(lines) and "" in lines[j]:
+ i = j + 1
+ continue
+
+ # --- 4. Replace .. code:: bibtex with .. code-block:: bibtex ---
+ if re.match(r"^\.\. code:: bibtex\s*$", line):
+ result_lines.append(".. code-block:: bibtex")
+ i += 1
+ continue
+
+ # --- 5. Convert .. figure:: with caption to .. image:: ---
+ if re.match(r"^\.\. figure::", line):
+ path_match = re.match(r"^\.\. figure:: (.+)", line)
+ if path_match:
+ img_path = path_match.group(1).strip()
+ result_lines.append(f".. image:: {_fix_image_path(img_path)}")
+ i += 1
+ # Skip :alt:, blank line, and caption lines
+ while i < len(lines):
+ if lines[i].startswith(" :"):
+ i += 1
+ continue
+ if lines[i].strip() == "":
+ i += 1
+ continue
+ if lines[i].startswith(" "):
+ i += 1
+ continue
+ break
+ continue
+
+ result_lines.append(line)
+ i += 1
+
+ # Clean up excessive blank lines
+ text = "\n".join(result_lines)
+ text = re.sub(r"\n{3,}", "\n\n", text)
+ return text.strip() + "\n"
+
+
+def main() -> None:
+ """CLI entry point for syncing README.md to introduction.rst."""
+ parser = argparse.ArgumentParser(
+ description="Sync README.md to docs/source/introduction.rst"
+ )
+ parser.add_argument(
+ "--check",
+ action="store_true",
+ help="Check if introduction.rst is up-to-date (exit 1 if stale)",
+ )
+ args = parser.parse_args()
+
+ if not README_PATH.exists():
+ print(f"Error: {README_PATH} not found", file=sys.stderr)
+ sys.exit(1)
+
+ readme_content = README_PATH.read_text(encoding="utf-8")
+ raw_rst = convert_readme_to_rst(readme_content)
+ final_rst = postprocess_rst(raw_rst, readme_content)
+
+ if args.check:
+ if not RST_PATH.exists():
+ print(
+ f"Error: {RST_PATH} does not exist. Run without --check to generate.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+ current = RST_PATH.read_text(encoding="utf-8")
+ if current != final_rst:
+ print(
+ f"Error: {RST_PATH} is out of sync with README.md. "
+ "Run 'python docs/scripts/sync_readme.py' to update.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+ print(f"OK: {RST_PATH} is up-to-date.")
+ else:
+ RST_PATH.parent.mkdir(parents=True, exist_ok=True)
+ RST_PATH.write_text(final_rst, encoding="utf-8")
+ print(f"Synced: {README_PATH} -> {RST_PATH}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/source/_static/version-redirect.js b/docs/source/_static/version-redirect.js
new file mode 100644
index 00000000..effe08cf
--- /dev/null
+++ b/docs/source/_static/version-redirect.js
@@ -0,0 +1,36 @@
+/**
+ * Version redirect script for multi-version documentation.
+ * Redirects to the latest stable release version, or falls back to 'main'.
+ */
+
+(function() {
+ 'use strict';
+
+ // Try to fetch versions.json (generated by generate_versions_json.py)
+ fetch('versions.json')
+ .then(response => {
+ if (!response.ok) {
+ throw new Error('versions.json not found');
+ }
+ return response.json();
+ })
+ .then(data => {
+ // Get the latest version from the JSON
+ const latestVersion = data.latest || data.versions?.[0]?.name || 'main';
+
+ const currentPath = window.location.pathname;
+
+ // If we're at root, redirect to latest version
+ if (currentPath === '/' || currentPath.endsWith('/index.html') || currentPath.endsWith('/')) {
+ window.location.href = latestVersion + '/';
+ }
+ })
+ .catch(error => {
+ console.warn('Version redirect failed:', error.message);
+ // Fallback to main on error
+ const currentPath = window.location.pathname;
+ if (currentPath === '/' || currentPath.endsWith('/index.html') || currentPath.endsWith('/')) {
+ window.location.href = 'main/';
+ }
+ });
+})();
diff --git a/docs/source/_templates/index.html b/docs/source/_templates/index.html
new file mode 100644
index 00000000..f1351f20
--- /dev/null
+++ b/docs/source/_templates/index.html
@@ -0,0 +1,8 @@
+
+
+
+ Redirecting to the latest EmbodiChain documentation
+
+
+
+
diff --git a/docs/source/_templates/versioning.html b/docs/source/_templates/versioning.html
new file mode 100644
index 00000000..a6cb2726
--- /dev/null
+++ b/docs/source/_templates/versioning.html
@@ -0,0 +1,56 @@
+
+
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rl.algo.rst b/docs/source/api_reference/embodichain/embodichain.agents.rl.algo.rst
index d5a1be05..35b11ab4 100644
--- a/docs/source/api_reference/embodichain/embodichain.agents.rl.algo.rst
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rl.algo.rst
@@ -3,6 +3,11 @@
.. automodule:: embodichain.agents.rl.algo
+Overview
+--------
+
+Algorithm registry and algorithm-construction helpers for RL training.
+
.. rubric:: Functions
@@ -10,4 +15,9 @@
build_algo
get_registered_algo_names
+
+.. automodule:: embodichain.agents.rl.algo
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rl.buffer.rst b/docs/source/api_reference/embodichain/embodichain.agents.rl.buffer.rst
index 0a178379..a79f3706 100644
--- a/docs/source/api_reference/embodichain/embodichain.agents.rl.buffer.rst
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rl.buffer.rst
@@ -3,4 +3,35 @@
.. automodule:: embodichain.agents.rl.buffer
+Overview
+--------
+
+The ``buffer`` package provides rollout and replay buffer structures used by
+RL algorithms.
+
+.. rubric:: Submodules
+
+.. autosummary::
+
+ standard_buffer
+ utils
+
+.. currentmodule:: embodichain.agents.rl.buffer
+
+Rollout Buffer Classes
+----------------------
+
+.. automodule:: embodichain.agents.rl.buffer.standard_buffer
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Buffer Utilities
+----------------
+
+.. automodule:: embodichain.agents.rl.buffer.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
\ No newline at end of file
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rl.collector.rst b/docs/source/api_reference/embodichain/embodichain.agents.rl.collector.rst
new file mode 100644
index 00000000..4fd639ed
--- /dev/null
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rl.collector.rst
@@ -0,0 +1,33 @@
+embodichain.agents.rl.collector
+================================
+
+.. automodule:: embodichain.agents.rl.collector
+
+Overview
+--------
+
+Collectors are responsible for interacting with vectorized environments and
+assembling rollout data into a preallocated ``TensorDict`` layout.
+
+.. rubric:: Classes
+
+.. autosummary::
+
+ BaseCollector
+ SyncCollector
+
+.. currentmodule:: embodichain.agents.rl.collector
+
+BaseCollector
+-------------
+
+.. autoclass:: BaseCollector
+ :members:
+ :show-inheritance:
+
+SyncCollector
+-------------
+
+.. autoclass:: SyncCollector
+ :members:
+ :show-inheritance:
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rl.models.rst b/docs/source/api_reference/embodichain/embodichain.agents.rl.models.rst
index d74efb22..6de1449a 100644
--- a/docs/source/api_reference/embodichain/embodichain.agents.rl.models.rst
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rl.models.rst
@@ -3,6 +3,11 @@
.. automodule:: embodichain.agents.rl.models
+Overview
+--------
+
+Policy-network registration and model construction APIs for RL agents.
+
.. rubric:: Functions
@@ -13,4 +18,9 @@
get_policy_class
get_registered_policy_names
register_policy
+
+.. automodule:: embodichain.agents.rl.models
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rl.rst b/docs/source/api_reference/embodichain/embodichain.agents.rl.rst
index 2fa64a6e..7dda1a38 100644
--- a/docs/source/api_reference/embodichain/embodichain.agents.rl.rst
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rl.rst
@@ -3,6 +3,12 @@ embodichain.agents.rl
.. automodule:: embodichain.agents.rl
+Overview
+--------
+
+The ``embodichain.agents.rl`` package contains algorithm registries, rollout
+collection logic, policy/model builders, and training entry points.
+
.. rubric:: Submodules
.. autosummary::
@@ -10,6 +16,7 @@ embodichain.agents.rl
algo
buffer
+ collector
models
train
utils
@@ -30,6 +37,14 @@ Rollout Buffer
:undoc-members:
:show-inheritance:
+Collectors
+----------
+
+.. automodule:: embodichain.agents.rl.collector
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
Policy Models
-------------
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rl.train.rst b/docs/source/api_reference/embodichain/embodichain.agents.rl.train.rst
index 4376c750..7fb189eb 100644
--- a/docs/source/api_reference/embodichain/embodichain.agents.rl.train.rst
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rl.train.rst
@@ -3,6 +3,11 @@
.. automodule:: embodichain.agents.rl.train
+Overview
+--------
+
+Training entry points and command-line helpers for launching RL experiments.
+
.. rubric:: Functions
@@ -11,4 +16,9 @@
main
parse_args
train_from_config
+
+.. automodule:: embodichain.agents.rl.train
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rl.utils.rst b/docs/source/api_reference/embodichain/embodichain.agents.rl.utils.rst
index 1f2706a5..b00828a3 100644
--- a/docs/source/api_reference/embodichain/embodichain.agents.rl.utils.rst
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rl.utils.rst
@@ -3,4 +3,42 @@
.. automodule:: embodichain.agents.rl.utils
+Overview
+--------
+
+The ``utils`` package contains helper utilities for RL configuration,
+data conversion, and training orchestration.
+
+.. rubric:: Submodules
+
+.. autosummary::
+
+ config
+ helper
+ trainer
+
+Configuration Helpers
+---------------------
+
+.. automodule:: embodichain.agents.rl.utils.config
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+General Helpers
+---------------
+
+.. automodule:: embodichain.agents.rl.utils.helper
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Trainer Utilities
+-----------------
+
+.. automodule:: embodichain.agents.rl.utils.trainer
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
\ No newline at end of file
diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rst b/docs/source/api_reference/embodichain/embodichain.agents.rst
index b5942c7e..6b1e5589 100644
--- a/docs/source/api_reference/embodichain/embodichain.agents.rst
+++ b/docs/source/api_reference/embodichain/embodichain.agents.rst
@@ -48,6 +48,7 @@ Reinforcement Learning
algo
buffer
+ collector
models
train
utils
diff --git a/docs/source/api_reference/embodichain/embodichain.data.rst b/docs/source/api_reference/embodichain/embodichain.data.rst
new file mode 100644
index 00000000..9d8b0984
--- /dev/null
+++ b/docs/source/api_reference/embodichain/embodichain.data.rst
@@ -0,0 +1,51 @@
+embodichain.data
+================
+
+.. automodule:: embodichain.data
+
+Data Package Overview
+---------------------
+
+The ``embodichain.data`` package centralizes dataset resolution and asset download
+helpers used by simulation tasks and training pipelines.
+
+.. rubric:: Submodules
+
+.. autosummary::
+
+ constants
+ dataset
+ download
+ enum
+
+Constants
+---------
+
+.. automodule:: embodichain.data.constants
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Dataset Resolution
+------------------
+
+.. automodule:: embodichain.data.dataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Asset Download CLI
+------------------
+
+.. automodule:: embodichain.data.download
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Enums
+-----
+
+.. automodule:: embodichain.data.enum
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst
new file mode 100644
index 00000000..181086c3
--- /dev/null
+++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst
@@ -0,0 +1,89 @@
+embodichain.lab.sim.atomic_actions
+==================================
+
+.. automodule:: embodichain.lab.sim.atomic_actions
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ Affordance
+ InteractionPoints
+ ObjectSemantics
+ ActionCfg
+ AtomicAction
+ MoveActionCfg
+ MoveAction
+ PickUpActionCfg
+ PickUpAction
+ PlaceActionCfg
+ PlaceAction
+ AtomicActionEngine
+
+.. currentmodule:: embodichain.lab.sim.atomic_actions
+
+Core
+----
+
+.. autoclass:: Affordance
+ :members:
+ :show-inheritance:
+
+.. autoclass:: InteractionPoints
+ :members:
+ :show-inheritance:
+
+.. autoclass:: ObjectSemantics
+ :members:
+ :show-inheritance:
+
+.. autoclass:: ActionCfg
+ :members:
+ :exclude-members: __init__, copy, replace, to_dict, validate
+
+.. autoclass:: AtomicAction
+ :members:
+ :show-inheritance:
+
+Actions
+-------
+
+.. autoclass:: MoveActionCfg
+ :members:
+ :exclude-members: __init__, copy, replace, to_dict, validate
+ :show-inheritance:
+
+.. autoclass:: MoveAction
+ :members:
+ :show-inheritance:
+
+.. autoclass:: PickUpActionCfg
+ :members:
+ :exclude-members: __init__, copy, replace, to_dict, validate
+ :show-inheritance:
+
+.. autoclass:: PickUpAction
+ :members:
+ :show-inheritance:
+
+.. autoclass:: PlaceActionCfg
+ :members:
+ :exclude-members: __init__, copy, replace, to_dict, validate
+ :show-inheritance:
+
+.. autoclass:: PlaceAction
+ :members:
+ :show-inheritance:
+
+Engine & Registry
+-----------------
+
+.. autoclass:: AtomicActionEngine
+ :members:
+ :show-inheritance:
+
+.. autofunction:: register_action
+
+.. autofunction:: unregister_action
+
+.. autofunction:: get_registered_actions
diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.robots.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.robots.rst
index d6428af3..c3457108 100644
--- a/docs/source/api_reference/embodichain/embodichain.lab.sim.robots.rst
+++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.robots.rst
@@ -3,4 +3,23 @@
.. automodule:: embodichain.lab.sim.robots
+Overview
+--------
+
+This module exposes robot-specific configuration presets for simulation scenes.
+
+.. rubric:: Classes
+
+.. autosummary::
+
+ CobotMagicCfg
+
+.. currentmodule:: embodichain.lab.sim.robots
+
+.. autoclass:: CobotMagicCfg
+ :members:
+ :inherited-members:
+ :show-inheritance:
+ :exclude-members: __init__, copy, replace, to_dict, validate
+
\ No newline at end of file
diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.rst
index 2a21fcf0..412f570d 100644
--- a/docs/source/api_reference/embodichain/embodichain.lab.sim.rst
+++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.rst
@@ -3,21 +3,30 @@
.. automodule:: embodichain.lab.sim
- .. rubric:: Submodules
-
- .. autosummary::
- :toctree: .
-
- sim_manager
- cfg
- common
- material
- shapes
- objects
- sensors
- planners
- solvers
- utility
+Overview
+--------
+
+The ``sim`` package provides simulation-core APIs including scene/object
+management, materials, sensors, planning/IK utilities, and action helpers.
+
+.. rubric:: Submodules
+
+.. autosummary::
+ :toctree: .
+
+ sim_manager
+ cfg
+ common
+ material
+ shapes
+ objects
+ robots
+ sensors
+ solvers
+ planners
+ atomic_actions
+ types
+ utility
.. currentmodule:: embodichain.lab.sim
@@ -35,8 +44,8 @@ Simulation Manager
:show-inheritance:
:exclude-members: __init__, copy, replace, to_dict, validate
-Configurations
-------------------
+Configuration
+-------------
.. automodule:: embodichain.lab.sim.cfg
:members:
@@ -44,8 +53,8 @@ Configurations
:show-inheritance:
:exclude-members: __init__, copy, replace, to_dict, validate
-Common Conponents
-------------------
+Common Components
+-----------------
.. automodule:: embodichain.lab.sim.common
:members:
@@ -53,7 +62,7 @@ Common Conponents
:show-inheritance:
Materials
-------------------
+---------
.. automodule:: embodichain.lab.sim.material
:members:
@@ -61,7 +70,7 @@ Materials
:show-inheritance:
Shapes
-------------------
+------
.. automodule:: embodichain.lab.sim.shapes
:members:
@@ -69,6 +78,14 @@ Shapes
:show-inheritance:
:exclude-members: __init__, copy, replace, to_dict, validate
+Atomic Actions
+--------------
+
+.. automodule:: embodichain.lab.sim.atom_actions
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
Objects
-------
@@ -85,6 +102,14 @@ Sensors
embodichain.lab.sim.sensors
+Robot Configurations
+--------------------
+
+.. automodule:: embodichain.lab.sim.robots
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
Solvers
-------
@@ -101,6 +126,21 @@ Planners
embodichain.lab.sim.planners
+Atomic Actions
+--------------
+
+.. toctree::
+ :maxdepth: 1
+
+ embodichain.lab.sim.atomic_actions
+Shared Types
+------------
+
+.. automodule:: embodichain.lab.sim.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
Utility
-------
diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.types.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.types.rst
index 5b1c4bd8..f01bae1f 100644
--- a/docs/source/api_reference/embodichain/embodichain.lab.sim.types.rst
+++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.types.rst
@@ -3,4 +3,27 @@
.. automodule:: embodichain.lab.sim.types
+Overview
+--------
+
+Shared tensor/type aliases used across simulation, environment, and policy
+interfaces.
+
+.. rubric:: Type Aliases
+
+.. autosummary::
+
+ Array
+ Device
+ EnvObs
+ EnvAction
+
+.. autodata:: Array
+
+.. autodata:: Device
+
+.. autodata:: EnvObs
+
+.. autodata:: EnvAction
+
\ No newline at end of file
diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.utility.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.utility.rst
index f64d3ce3..2e45ea5d 100644
--- a/docs/source/api_reference/embodichain/embodichain.lab.sim.utility.rst
+++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.utility.rst
@@ -3,21 +3,73 @@ embodichain.lab.sim.utility
.. automodule:: embodichain.lab.sim.utility
-Utility Functions
------------------
+Overview
+--------
-This module contains utility functions for simulation, mesh processing, and URDF handling.
+This package contains helper utilities for simulation state conversion,
+mesh/geometry handling, configuration transforms, keyboard interaction, and
+action/solver adaptation.
.. rubric:: Submodules
.. autosummary::
+ action_utils
+ atom_action_utils
+ cfg_utils
+ gizmo_utils
+ import_utils
+ io_utils
+ keyboard_utils
sim_utils
mesh_utils
- urdf_utils
+ solver_utils
+ tensor
.. currentmodule:: embodichain.lab.sim.utility
+Action Utilities
+~~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.action_utils
+ :members:
+
+Atomic Action Utilities
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.atom_action_utils
+ :members:
+
+Configuration Utilities
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.cfg_utils
+ :members:
+
+Gizmo Utilities
+~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.gizmo_utils
+ :members:
+
+Import Utilities
+~~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.import_utils
+ :members:
+
+I/O Utilities
+~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.io_utils
+ :members:
+
+Keyboard Utilities
+~~~~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.keyboard_utils
+ :members:
+
Simulation Utils
~~~~~~~~~~~~~~~~
@@ -29,3 +81,15 @@ Mesh Utils
.. automodule:: embodichain.lab.sim.utility.mesh_utils
:members:
+
+Solver Utilities
+~~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.solver_utils
+ :members:
+
+Tensor Utilities
+~~~~~~~~~~~~~~~~
+
+.. automodule:: embodichain.lab.sim.utility.tensor
+ :members:
diff --git a/docs/source/api_reference/embodichain/embodichain.utils.rst b/docs/source/api_reference/embodichain/embodichain.utils.rst
index 490962ce..c4d131a1 100644
--- a/docs/source/api_reference/embodichain/embodichain.utils.rst
+++ b/docs/source/api_reference/embodichain/embodichain.utils.rst
@@ -3,13 +3,16 @@
.. automodule:: embodichain.utils
- .. Rubric:: Submodules
+ .. rubric:: Submodules
.. autosummary::
warp
+ cfg
configclass
+ device_utils
file
+ img_utils
logger
math
module_utils
diff --git a/docs/source/api_reference/index.rst b/docs/source/api_reference/index.rst
index fa3112ae..f73a7480 100644
--- a/docs/source/api_reference/index.rst
+++ b/docs/source/api_reference/index.rst
@@ -1,7 +1,16 @@
API Reference
=============
-This page provides detailed documentation for all EmbodiChain modules and classes.
+This section provides the API-level documentation for EmbodiChain's public Python
+modules.
+
+Use this reference when you need:
+
+* module-level overviews and responsibilities,
+* public classes, functions, and configuration objects,
+* links into specialized subpackages (simulation, gym environments, RL, and utilities).
+
+The pages are organized from high-level package namespaces to concrete submodules.
Core Framework
--------------
@@ -14,6 +23,7 @@ The following modules are available in the core ``embodichain`` framework:
:toctree: embodichain
agents
+ data
lab
toolkits
utils
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 59145215..a0b23064 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -41,7 +41,6 @@
"sphinx_design",
"myst_parser", # if you prefer Markdown pages
"sphinx_copybutton",
- "sphinx_multiversion",
]
# Napoleon settings if using Google/NumPy docstring style:
napoleon_google_docstring = True
@@ -65,17 +64,45 @@
exclude_patterns = []
+# -- Version selector sidebar ---------------------------------------------------
+html_sidebars = {
+ "**": [
+ "navbar-logo.html",
+ "versioning.html",
+ "search-field.html",
+ "sbt-sidebar-nav.html",
+ ]
+}
+
+
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_book_theme"
html_static_path = ["_static"]
+# Don't include version-redirect.js automatically - we add it manually to root
+html_js_files = []
# html_logo = "_static/logo_e.png"
-# -- sphinx-multiversion configuration -------------------------------------------------
-# Only build tags that look like v1.0.0 or branches like main/dev
-smv_tag_whitelist = r"^v\d+\.\d+\.\d+$"
-smv_branch_whitelist = r"^(main|dev)$"
-smv_remote_whitelist = r"^origin$"
-smv_released_pattern = r"^tags/v\d+\.\d+\.\d+$"
-smv_outputdir_format = "{ref.name}"
+# Configure HTML base URL for better local previewing
+# Use empty string to use relative paths from the build directory
+html_baseurl = ""
+
+# HTML context for better path handling
+html_context = {
+ "github_user": "dexforce",
+ "github_repo": "EmbodiChain",
+ "github_version": "main",
+ "doc_path": "docs/source",
+}
+
+html_theme_options = {
+ "title": "EmbodiChain",
+ "logo_only": False,
+ "show_toc_level": 2,
+ "collapse_navigation": True,
+ "sticky_navigation": True,
+ "navigation_depth": 4,
+ "includehidden": True,
+ "prev_next_buttons_location": "bottom",
+}
diff --git a/docs/source/features/agents.md b/docs/source/features/agents.md
index 7cb2356d..89602c93 100644
--- a/docs/source/features/agents.md
+++ b/docs/source/features/agents.md
@@ -164,3 +164,12 @@ embodichain/agents/
│ └── prompt/ # Prompt templates (LangChain)
└── prompts/ # Agent prompt templates
```
+
+---
+
+## See Also
+
+- [Online Data Streaming](online_data.md) — Streaming live simulation data for training
+- [RL Architecture](../overview/rl/index.rst) — RL training pipeline and algorithms
+- [Atomic Actions Tutorial](../tutorial/atomic_actions.rst) — Action primitives used by the CodeAgent
+- [Supported Tasks](../resources/task/index.rst) — Available task environments
diff --git a/docs/source/features/interaction/preview_asset.md b/docs/source/features/interaction/preview_asset.md
index 4dc2c4be..df3aa040 100644
--- a/docs/source/features/interaction/preview_asset.md
+++ b/docs/source/features/interaction/preview_asset.md
@@ -75,7 +75,7 @@ asset.set_root_pose(pos=[0, 0, 1.0], rot=[0, 0, 0])
| `--fix_base` | Fix the base of articulations | `True` |
| `--sim_device` | Simulation device | `cpu` |
| `--headless` | Run without rendering window | `False` |
-| `--enable_rt` | Enable ray tracing | `False` |
+| `--renderer` | Renderer backend: `hybrid`, `fast-rt` or `rt` | `hybrid` |
| `--preview` | Enter interactive embed mode after loading | `False` |
## Examples
diff --git a/docs/source/features/interaction/window.md b/docs/source/features/interaction/window.md
index 6c512186..e19b0da0 100644
--- a/docs/source/features/interaction/window.md
+++ b/docs/source/features/interaction/window.md
@@ -9,6 +9,7 @@ The simulation window comes with a set of default controls that enable users to
| Events | Description |
|---------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| **Raycast Information Display** | Press the right mouse button to select a point and the 'C' key to print the raycast distance and hit position of a surface (world coordinates) to the console. Useful for debugging and checking the position of objects in the simulation. |
+| **Viewer recording (toggle)** | Press **`r`** to **start** recording what the interactive viewer shows, and press **`r`** again to **stop** and save as MP4 videos. Recording uses a hidden camera that follows the live viewer camera pose, so the exported videos match the on-screen view. Useful for debugging and recording the demos.|
> **Note:** We will add more interaction features in future releases. Stay tuned for updates!
diff --git a/docs/source/features/online_data.md b/docs/source/features/online_data.md
index c186aef6..dccd38d1 100644
--- a/docs/source/features/online_data.md
+++ b/docs/source/features/online_data.md
@@ -143,3 +143,11 @@ It shows item mode, batch mode, and dynamic chunk sizes. Run it with:
```bash
python examples/agents/datasets/online_dataset_demo.py
```
+
+---
+
+## See Also
+
+- [EmbodiAgent](agents.md) — Hierarchical agent that uses online data for training
+- [RL Architecture](../overview/rl/index.rst) — RL training pipeline
+- [Data Generation Tutorial](../tutorial/data_generation.rst) — Generating offline datasets
diff --git a/docs/source/features/toolkits/grasp_generator.rst b/docs/source/features/toolkits/grasp_generator.rst
index ba77e77b..7eea272a 100644
--- a/docs/source/features/toolkits/grasp_generator.rst
+++ b/docs/source/features/toolkits/grasp_generator.rst
@@ -24,7 +24,7 @@ The Code Explained
Configuring the simulation
--------------------------
-Command-line arguments are parsed with ``argparse`` to select the number of parallel environments, the compute device, and optional rendering features such as ray tracing and headless mode.
+Command-line arguments are parsed with ``argparse`` to select the number of parallel environments, the compute device, and optional rendering features such as renderer backend and headless mode.
.. literalinclude:: ../../../../scripts/tutorials/grasp/grasp_generator.py
:language: python
@@ -185,7 +185,7 @@ You can customize the run with additional arguments:
.. code-block:: bash
- python scripts/tutorials/grasp/grasp_generator.py --num_envs --device --enable_rt --headless
+ python scripts/tutorials/grasp/grasp_generator.py --num_envs --device --renderer --headless
After confirming the grasp region in the browser, the script will compute a grasp pose, print the elapsed time, and then wait for you to press **Enter** before executing the full grasp trajectory in the simulation. Press **Enter** again to exit once the motion is complete.
diff --git a/docs/source/features/toolkits/urdf_assembly.md b/docs/source/features/toolkits/urdf_assembly.md
index 76f48ddb..dd504956 100644
--- a/docs/source/features/toolkits/urdf_assembly.md
+++ b/docs/source/features/toolkits/urdf_assembly.md
@@ -18,7 +18,7 @@ The tool provides a programmatic way to:
```python
from pathlib import Path
import numpy as np
-from embedichain.toolkits.urdf_assembly import URDFAssemblyManager
+from embodichain.toolkits.urdf_assembly import URDFAssemblyManager
# Initialize the assembly manager
manager = URDFAssemblyManager()
@@ -201,6 +201,72 @@ Get all attached sensors.
manager.get_attached_sensors() -> dict
```
+##### Component name prefixes (`component_prefix`)
+
+`URDFAssemblyManager` uses `component_prefix` to configure name prefixes for
+each supported component type. This attribute is a list of 2-tuples:
+
+- Form: `[(component_name, prefix), ...]`
+- The default value is:
+
+ ```python
+ [
+ ("chassis", None),
+ ("legs", None),
+ ("torso", None),
+ ("head", None),
+ ("left_arm", "left_"),
+ ("right_arm", "right_"),
+ ("left_hand", "left_"),
+ ("right_hand", "right_"),
+ ("arm", None),
+ ("hand", None),
+ ]
+ ```
+
+You can configure it in a *patch-style* manner via the property:
+
+```python
+# Only override prefixes for existing components; do not introduce
+# new component names.
+manager.component_prefix = [
+ ("left_arm", "L_"),
+ ("right_arm", "R_"),
+ ("left_hand", "L_"),
+ ("right_hand", "R_"),
+]
+```
+
+Semantics:
+
+- Only components that already exist in the default configuration (e.g. `chassis/torso/left_arm/...`) may be overridden; new component names are not allowed.
+- Components not listed in `new_prefixes` keep their original prefix.
+- If `new_prefixes` contains an unknown component name, a `ValueError` is raised indicating that new component types cannot be introduced.
+
+##### Name casing policy (`name_case`)
+
+`URDFAssemblyManager` supports a global name casing policy that controls how
+link and joint names are normalized during assembly. This is configured on
+the manager instance after construction:
+
+```python
+manager = URDFAssemblyManager()
+manager.name_case = {
+ "joint": "upper", # or "lower" / "none"
+ "link": "lower", # or "upper" / "none"
+}
+
+Semantics:
+
+- Valid keys: `"joint"`, `"link"`.
+- Valid values: `"upper"`, `"lower"`, `"none"`.
+- Default behavior matches the legacy implementation:
+ - joints are normalized to **UPPERCASE**,
+ - links are normalized to **lowercase**.
+- This policy is propagated to the internal component and connection managers,
+ and is also included in the assembly signature. Changing `name_case` will
+ therefore force a rebuild of the assembled URDF.
+
## Using with URDFCfg for Robot Creation
The URDF Assembly Tool can be used directly with `URDFCfg` to create robots with multiple components in the simulation. This is the recommended approach when building robots from assembled URDF files.
@@ -210,7 +276,7 @@ The URDF Assembly Tool can be used directly with `URDFCfg` to create robots with
The `URDFCfg` class provides a convenient way to define multi-component robots:
```python
-from embedichain.lab.sim.cfg import RobotCfg, URDFCfg
+from embodichain.lab.sim.cfg import RobotCfg, URDFCfg
cfg = RobotCfg(
uid="my_robot",
@@ -232,6 +298,27 @@ cfg = RobotCfg(
)
```
+When using `URDFCfg` to build multi-component robots, you can pass custom
+component prefixes to the internal `URDFAssemblyManager` via
+`URDFCfg.component_prefix`. Its semantics are identical to
+`URDFAssemblyManager.component_prefix`:
+
+- Each element is a `(component_name, prefix)` tuple.
+- Only prefixes for components that exist in the default configuration may be overridden; no new component names can be added.
+- Components not explicitly listed keep their original prefix.
+
+Example:
+
+```python
+urdf_cfg = URDFCfg(
+ components=[...],
+)
+urdf_cfg.component_prefix = [
+ ("left_arm", "L_"),
+ ("right_arm", "R_"),
+]
+```
+
### Complete Example
Here's a complete example from `scripts/tutorials/sim/create_robot.py`:
@@ -241,14 +328,14 @@ import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
-from embedichain.lab.sim import SimulationManager, SimulationManagerCfg
-from embedichain.lab.sim.objects import Robot
-from embedichain.lab.sim.cfg import (
+from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.sim.objects import Robot
+from embodichain.lab.sim.cfg import (
JointDrivePropertiesCfg,
RobotCfg,
URDFCfg,
)
-from embedichain.data import get_data_path
+from embodichain.data import get_data_path
def create_robot(sim):
@@ -269,7 +356,6 @@ def create_robot(sim):
# Define transformation for hand attachment
hand_transform = np.eye(4)
hand_transform[:3, :3] = R.from_rotvec([90, 0, 0], degrees=True).as_matrix()
- hand_transform[2, 3] = 0.02 # 2cm offset along z-axis
# Create robot configuration
cfg = RobotCfg(
@@ -300,6 +386,86 @@ def create_robot(sim):
return robot
+# Initialize simulation and create robot
+sim = SimulationManager(SimulationManagerCfg(headless=True, num_envs=4))
+robot = create_robot(sim)
+print(f"Robot created with {robot.dof} joints")
+```
+
+```python
+import numpy as np
+import torch
+from scipy.spatial.transform import Rotation as R
+
+from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.sim.objects import Robot
+from embodichain.lab.sim.cfg import (
+ JointDrivePropertiesCfg,
+ RobotCfg,
+ URDFCfg,
+)
+from embodichain.data import get_data_path
+
+
+def create_robot(sim):
+ """Create and configure a robot with arm and hand components."""
+
+ # Get URDF paths for robot components
+ arm_urdf_path = get_data_path("Rokae/SR5/SR5.urdf")
+ hand_urdf_path = get_data_path(
+ "BrainCoHandRevo1/BrainCoLeftHand/BrainCoLeftHand.urdf"
+ )
+
+ # Define transformation for hand attachment
+ hand_transform = np.eye(4)
+ hand_transform[:3, :3] = R.from_rotvec([90, 0, 0], degrees=True).as_matrix()
+
+ left_arm_base_xpos = np.eye(4)
+ left_arm_base_xpos[1, 3] = 0.3
+
+ right_arm_base_xpos = np.eye(4)
+ right_arm_base_xpos[1, 3] = -0.3
+
+ # Create robot configuration
+ cfg = RobotCfg(
+ uid="dual_sr5",
+ urdf_cfg=URDFCfg(
+ components=[
+ {
+ "component_type": "left_arm",
+ "urdf_path": arm_urdf_path,
+ "transform": left_arm_base_xpos,
+ },
+ {
+ "component_type": "right_arm",
+ "urdf_path": arm_urdf_path,
+ "transform": right_arm_base_xpos,
+ },
+ {
+ "component_type": "left_hand",
+ "urdf_path": hand_urdf_path,
+ "transform": hand_transform,
+ },
+ {
+ "component_type": "right_hand",
+ "urdf_path": hand_urdf_path,
+ "transform": hand_transform,
+ },
+ ],
+ component_prefix=[("left_arm", "L_"), ("right_arm", "R_"), ("left_hand", "left_"), ("right_hand", "right_")],
+ name_case={
+ "joint": "lower",
+ "link": "lower",
+ }
+ ),
+ )
+
+ # Add robot to simulation
+ robot: Robot = sim.add_robot(cfg=cfg)
+
+ return robot
+
+
# Initialize simulation and create robot
sim = SimulationManager(SimulationManagerCfg(headless=True, num_envs=4))
robot = create_robot(sim)
diff --git a/docs/source/guides/add_robot.rst b/docs/source/guides/add_robot.rst
index d58740a1..f437fd0b 100644
--- a/docs/source/guides/add_robot.rst
+++ b/docs/source/guides/add_robot.rst
@@ -1,563 +1,54 @@
-.. _tutorial_add_robot:
+.. _guide_add_robot:
-Adding a New Robot
-==================
+Adding a New Robot — Quick Reference
+=====================================
-.. currentmodule:: embodichain.lab.sim.robots
+This guide provides a checklist and key reference for adding a new robot to EmbodiChain. For the full step-by-step walkthrough with code examples, see :doc:`/tutorial/add_robot`.
-This tutorial guides you through adding a new robot to EmbodiChain. You'll learn the file structure, key components, and patterns used for robot definitions.
+Checklist
+---------
-EmbodiChain supports two approaches for defining robots:
+1. **Prepare the URDF** — Place your URDF file (and associated meshes) in the robot assets directory.
+2. **Create the config class** — Inherit from ``RobotCfg``, implement ``from_dict`` and ``_build_default_cfgs``.
+3. **Define control parts** — Group joints into logical sets (e.g., ``arm``, ``gripper``).
+4. **Configure IK solver** — Choose ``OPWSolverCfg``, ``SRSSolverCfg``, or a generic ``SolverCfg``.
+5. **Set drive properties** — Configure stiffness, damping, and max effort per joint group.
+6. **Implement** ``build_pk_serial_chain`` — Required for PyTorch-Kinematics IK support.
+7. **Register in** ``embodichain/lab/sim/robots/__init__.py``.
+8. **Add documentation** — Create ``docs/source/resources/robot/my_robot.md`` and update ``resources/robot/index.rst``.
+9. **Test** — Add a ``__main__`` block or use the ``preview-asset`` CLI to verify.
-1. **Single-file approach**: For simpler robots (like ``CobotMagic``)
-2. **Package approach**: For complex robots with multiple variants (like ``DexforceW1``)
+Approaches
+----------
-Choose the approach based on your robot's complexity.
+- **Single-file** (simple robots): One ``my_robot.py`` with everything.
+- **Package** (complex robots): Directory with ``types.py``, ``params.py``, ``utils.py``, ``cfg.py``, ``__init__.py``.
----
-
-Prerequisites
-~~~~~~~~~~~~~~
-
-Before adding a new robot, ensure you have:
-
-- URDF file(s) for your robot
-- Robot's kinematic parameters (DH parameters or joint limits)
-- Understanding of your robot's joint structure and control parts
-
----
-
-Approach 1: Single-File Robot (Simple Robots)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Use this approach for robots with a single variant and straightforward configuration.
-
-File: ``embodichain/lab/sim/robots/my_robot.py``
-
-.. dropdown:: Complete Example: CobotMagic-style Robot
- :icon: code
-
- .. literalinclude:: ../../../embodichain/lab/sim/robots/cobotmagic.py
- :language: python
- :linenos:
-
-Step-by-Step Guide
-------------------
-
-1. **Create the configuration class** inheriting from ``RobotCfg``:
-
- .. code-block:: python
-
- from __future__ import annotations
-
- from typing import Dict, List, Any
- import numpy as np
-
- from embodichain.lab.sim.cfg import (
- RobotCfg,
- URDFCfg,
- JointDrivePropertiesCfg,
- RigidBodyAttributesCfg,
- )
- from embodichain.lab.sim.solvers import SolverCfg, OPWSolverCfg
- from embodichain.lab.sim.utility.cfg_utils import merge_robot_cfg
- from embodichain.data import get_data_path
- from embodichain.utils import configclass
-
- @configclass
- class MyRobotCfg(RobotCfg):
- urdf_cfg: URDFCfg = None
- control_parts: Dict[str, List[str]] | None = None
- solver_cfg: Dict[str, "SolverCfg"] | None = None
-
-2. **Implement the ``from_dict`` class method** for flexible initialization:
-
- .. code-block:: python
-
- @classmethod
- def from_dict(cls, init_dict: Dict[str, Any]) -> "MyRobotCfg":
- cfg = cls()
- default_cfgs = cls()._build_default_cfgs()
- for key, value in default_cfgs.items():
- setattr(cfg, key, value)
- cfg = merge_robot_cfg(cfg, init_dict)
- return cfg
-
-3. **Define ``_build_default_cfgs``** with your robot's defaults:
-
- .. code-block:: python
-
- @staticmethod
- def _build_default_cfgs() -> Dict[str, Any]:
- # URDF path
- urdf_path = get_data_path("MyRobot/my_robot.urdf")
-
- # URDF configuration (for multi-component robots)
- urdf_cfg = URDFCfg(
- components=[
- {
- "component_type": "arm",
- "urdf_path": urdf_path,
- "transform": np.eye(4), # 4x4 transform matrix
- },
- ]
- )
-
- # Control parts - group joints for control
- control_parts = {
- "arm": [
- "JOINT1", "JOINT2", "JOINT3",
- "JOINT4", "JOINT5", "JOINT6",
- ],
- "gripper": ["JOINT7", "JOINT8"],
- }
-
- # Solver configuration for IK
- solver_cfg = {
- "arm": OPWSolverCfg(
- end_link_name="link6",
- root_link_name="base_link",
- tcp=np.array([...]), # Tool center point transform
- ),
- }
-
- # Drive properties - joint physics parameters
- drive_pros = JointDrivePropertiesCfg(
- stiffness={
- "JOINT[1-6]": 7e4, # Regex pattern for joints 1-6
- "JOINT[7-8]": 3e2,
- },
- damping={
- "JOINT[1-6]": 1e3,
- "JOINT[7-8]": 3e1,
- },
- max_effort={
- "JOINT[1-6]": 3e6,
- "JOINT[7-8]": 3e3,
- },
- )
-
- return {
- "uid": "MyRobot",
- "urdf_cfg": urdf_cfg,
- "control_parts": control_parts,
- "solver_cfg": solver_cfg,
- "drive_pros": drive_pros,
- "attrs": RigidBodyAttributesCfg(
- mass=0.1,
- static_friction=0.95,
- dynamic_friction=0.9,
- linear_damping=0.7,
- angular_damping=0.7,
- ),
- }
-
-4. **Implement ``build_pk_serial_chain``** for PyTorch-Kinematics:
-
- .. code-block:: python
-
- def build_pk_serial_chain(
- self, device: torch.device = torch.device("cpu"), **kwargs
- ) -> Dict[str, "pk.SerialChain"]:
- from embodichain.lab.sim.utility.solver_utils import (
- create_pk_chain,
- create_pk_serial_chain,
- )
-
- urdf_path = get_data_path("MyRobot/my_robot.urdf")
- chain = create_pk_chain(urdf_path, device)
-
- arm_chain = create_pk_serial_chain(
- chain=chain,
- end_link_name="link6",
- root_link_name="base_link"
- ).to(device=device)
-
- return {"arm": arm_chain}
-
-5. **Register in** ``embodichain/lab/sim/robots/__init__.py``:
-
- .. code-block:: python
-
- from .my_robot import MyRobotCfg
-
----
-
-Approach 2: Package-Based Robot (Complex Robots)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Use this approach for robots with multiple variants (e.g., different arm types, versions, or configurations).
-
-File Structure
+Key Parameters
--------------
-For complex robots, create a package directory:
-
-.. code-block::
-
- robots/
- └── my_robot/
- ├── __init__.py # Exports the main config class
- ├── types.py # Enums for robot variants
- ├── params.py # Kinematics parameters
- ├── utils.py # Manager classes and builders
- └── cfg.py # Main configuration class
-
-Step-by-Step Guide
------------------
-
-1. **types.py** - Define enums for robot variants:
-
- .. code-block:: python
-
- from enum import Enum
-
- class MyRobotVersion(Enum):
- V010 = "v010"
- V020 = "v020"
-
- class MyRobotArmKind(Enum):
- STANDARD = "standard"
- EXTENDED = "extended"
-
- class MyRobotSide(Enum):
- LEFT = "left"
- RIGHT = "right"
-
-2. **params.py** - Define kinematics parameters:
-
- .. code-block:: python
-
- from dataclasses import dataclass
- import numpy as np
- from typing import Optional
-
- @dataclass
- class MyRobotArmKineParams:
- arm_side: MyRobotSide
- arm_kind: MyRobotArmKind
- version: MyRobotVersion
-
- dh_params: np.ndarray = None # DH parameters (N x 4)
- qpos_limits: np.ndarray = None # Joint limits (N x 2)
- link_lengths: np.ndarray = None # Link lengths
- T_b_ob: np.ndarray = None # Base to origin transform
- T_e_oe: np.ndarray = None # End-effector transform
-
-3. **utils.py** - Manager classes and builder functions:
-
- .. code-block:: python
-
- class ArmManager:
- """Manages arm URDF and configuration."""
- pass
-
- def build_my_robot_assembly_urdf_cfg(...):
- """Build URDF assembly from components."""
- pass
-
- def build_my_robot_cfg(...):
- """Build complete robot configuration."""
- pass
-
-4. **cfg.py** - Main configuration class:
-
- .. code-block:: python
-
- @configclass
- class MyRobotCfg(RobotCfg):
- version: MyRobotVersion = MyRobotVersion.V010
- arm_kind: MyRobotArmKind = MyRobotArmKind.STANDARD
-
- @classmethod
- def from_dict(cls, init_dict: Dict) -> "MyRobotCfg":
- # Implementation similar to single-file approach
- pass
-
-5. **__init__.py** - Export the config:
-
- .. code-block:: python
-
- from .cfg import MyRobotCfg
-
-6. **Register in** ``robots/__init__.py``:
-
- .. code-block:: python
-
- from .my_robot import *
-
----
-
-Key Configuration Parameters
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Regardless of the approach, your robot config needs these core parameters:
-
-+---------------------+------------------------+----------------------------------+
-| Parameter | Type | Description |
-+=====================+========================+==================================+
-| ``uid`` | str | Unique robot identifier |
-+---------------------+------------------------+----------------------------------+
-| ``urdf_cfg`` | URDFCfg | URDF file and components |
-+---------------------+------------------------+----------------------------------+
-| ``control_parts`` | Dict[str, List[str]] | Joint groups for control |
-+---------------------+------------------------+----------------------------------+
-| ``solver_cfg`` | Dict[str, SolverCfg] | IK solver configurations |
-+---------------------+------------------------+----------------------------------+
-| ``drive_pros`` | JointDrivePropertiesCfg | Joint stiffness, damping, force |
-+---------------------+------------------------+----------------------------------+
-| ``attrs`` | RigidBodyAttributesCfg | Mass, friction, damping |
-+---------------------+------------------------+----------------------------------+
-
-URDF Configuration
------------------
-
-The ``URDFCfg`` allows composing robots from multiple URDF files:
-
-.. code-block:: python
-
- urdf_cfg = URDFCfg(
- components=[
- {
- "component_type": "arm",
- "urdf_path": arm_urdf,
- "transform": np.eye(4),
- },
- {
- "component_type": "gripper",
- "urdf_path": gripper_urdf,
- "transform": gripper_transform,
- },
- ]
- )
-
-Control Parts
--------------
-
-Group joints logically for different control modes:
-
-.. code-block:: python
-
- control_parts = {
- "arm": ["JOINT1", "JOINT2", "JOINT3", "JOINT4", "JOINT5", "JOINT6"],
- "gripper": ["JOINT7", "JOINT8"],
- }
-
-Use regex patterns for flexible matching:
-- ``"JOINT[1-6]"`` matches JOINT1 through JOINT6
-- ``"(LEFT|RIGHT)_ARM.*"`` matches all arm joints
-
-Drive Properties
-----------------
-
-Configure joint physics behavior:
-
-.. code-block:: python
-
- drive_pros = JointDrivePropertiesCfg(
- stiffness={
- "ARM_JOINTS": 1e4, # High stiffness for arm joints
- "GRIPPER_JOINTS": 3e2, # Lower stiffness for gripper
- },
- damping={
- "ARM_JOINTS": 1e3,
- "GRIPPER_JOINTS": 3e1,
- },
- max_effort={
- "ARM_JOINTS": 1e5,
- "GRIPPER_JOINTS": 1e3,
- },
- )
-
-IK Solver Configuration
------------------------
-
-Choose the appropriate solver for your robot:
-
-- **OPWSolverCfg**: For 6-axis industrial arms (like CobotMagic)
-- **SRSSolverCfg**: For robots with specific kinematics (like DexforceW1)
-- **SolverCfg**: Generic solver configuration
-
-.. code-block:: python
-
- solver_cfg = {
- "arm": OPWSolverCfg(
- end_link_name="link6",
- root_link_name="base_link",
- tcp=np.array([...]), # Tool center point
- ),
- }
-
----
-
-Using Your Robot
-~~~~~~~~~~~~~~~~
-
-After adding the robot, use it in your code:
-
-.. code-block:: python
-
- from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
- from embodichain.lab.sim.robots import MyRobotCfg
-
- # Create simulation
- sim_cfg = SimulationManagerCfg(headless=False, num_envs=2)
- sim = SimulationManager(sim_cfg)
-
- # Create robot config
- robot_cfg = MyRobotCfg.from_dict({
- "uid": "my_robot",
- })
-
- # Add robot to simulation
- robot = sim.add_robot(cfg=robot_cfg)
-
----
-
-Testing Your Robot
-~~~~~~~~~~~~~~~~~~
-
-Add a test block at the bottom of your robot config file:
-
-.. code-block:: python
-
- if __name__ == "__main__":
- from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
-
- sim_cfg = SimulationManagerCfg(headless=True, num_envs=2)
- sim = SimulationManager(sim_cfg)
-
- robot_cfg = MyRobotCfg.from_dict({"uid": "my_robot"})
- robot = sim.add_robot(cfg=robot_cfg)
-
- print("Robot added successfully!")
-
----
-
-Best Practices
-~~~~~~~~~~~~~~
-
-1. **Use the** ``@configclass`` **decorator** for all config classes
-2. **Provide** ``from_dict`` **method** for flexible initialization
-3. **Use regex patterns** for joint names in drive properties
-4. **Keep kinematics parameters** separate in ``params.py`` for complex robots
-5. **Include** ``build_pk_serial_chain`` **method** for IK support
-6. **Add** ``to_dict`` **and** ``save_to_file`` **methods** for serialization
-7. **Test with** ``__main__`` **block** before integrating
-8. **Add robot documentation** in ``docs/source/resources/robot/`` for user reference
-
----
-
-Adding Robot Documentation
-~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-When adding a new robot, create documentation in ``docs/source/resources/robot/`` to help users understand and use your robot.
-
-File Location
--------------
-
-Create a markdown file: ``docs/source/resources/robot/my_robot.md``
-
-Recommended Structure
----------------------
-
-.. code-block:: markdown
-
- # MyRobot
-
- Brief description of the robot and its manufacturer.
-
-
-
-
MyRobot
-
-
- ## Key Features
-
- - Feature 1
- - Feature 2
- - Feature 3
-
- ---
-
- ## Robot Parameters
-
- | Parameter | Description |
- |-----------|-------------|
- | Joints | Number of joints |
- | DOF | Degrees of freedom |
- | ... | ... |
-
- ---
-
- ## Quick Initialization Example
-
- ```python
- from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
- from embodichain.lab.sim.robots import MyRobotCfg
-
- config = SimulationManagerCfg(headless=False, sim_device="cpu", num_envs=2)
- sim = SimulationManager(config)
-
- robot = sim.add_robot(cfg=MyRobotCfg.from_dict({}))
- ```
-
- ---
-
- ## Configuration Parameters
-
- ### Main Configuration Items
-
- - **uid**: Unique identifier
- - **urdf_cfg**: URDF configuration
- - **control_parts**: Control groups
- - **solver_cfg**: IK solver configuration
- - **drive_pros**: Joint drive properties
- - **attrs**: Physical attributes
-
- ### Custom Usage Example
-
- ```python
- custom_cfg = {
- "uid": "my_robot",
- # Add parameters
- }
- cfg = MyRobotCfg.from_dict(custom_cfg)
- robot = sim.add_robot(cfg=cfg)
- ```
-
- ---
-
- ## References
-
- - Manufacturer product page
- - URDF file paths
- - Related documentation
-
-Register the Robot in Index
----------------------------
-
-After creating the robot documentation, add it to the index file at ``docs/source/resources/robot/index.rst``:
-
-.. code-block:: rst
-
- .. toctree::
- :maxdepth: 1
-
- Dexforce W1
- CobotMagic
- MyRobot # Add your robot here
-
----
-
-Next Steps
-~~~~~~~~~~
-
-After adding your robot:
-
-- Add robot documentation in ``docs/source/resources/robot/``
-- Update ``docs/source/resources/robot/index.rst`` to include the new robot
-- Add task environments that use your robot
-- Configure sensors (cameras, force sensors)
-- Implement custom IK solvers if needed
-- Add motion planning support
++---------------------+----------------------------+----------------------------------+
+| Parameter | Type | Description |
++=====================+============================+==================================+
+| ``uid`` | str | Unique robot identifier |
++---------------------+----------------------------+----------------------------------+
+| ``urdf_cfg`` | URDFCfg | URDF file and components |
++---------------------+----------------------------+----------------------------------+
+| ``control_parts`` | Dict[str, List[str]] | Joint groups for control |
++---------------------+----------------------------+----------------------------------+
+| ``solver_cfg`` | Dict[str, SolverCfg] | IK solver configurations |
++---------------------+----------------------------+----------------------------------+
+| ``drive_pros`` | JointDrivePropertiesCfg | Joint stiffness, damping, force |
++---------------------+----------------------------+----------------------------------+
+
+.. tip::
+
+ See the :doc:`full tutorial ` for complete code examples of both approaches.
+
+See Also
+--------
+
+- :doc:`/tutorial/add_robot` — Full step-by-step tutorial
+- :doc:`/tutorial/robot` — Using robots in simulation
+- :doc:`/overview/sim/solvers/index` — IK solver reference
+- :doc:`/resources/robot/index` — Existing robot documentation
diff --git a/docs/source/guides/cli.md b/docs/source/guides/cli.md
index debb0078..639183ca 100644
--- a/docs/source/guides/cli.md
+++ b/docs/source/guides/cli.md
@@ -64,7 +64,7 @@ python -m embodichain preview-asset \
| ``--fix_base`` | ``True`` | Fix the base of articulations |
| ``--sim_device`` | ``cpu`` | Simulation device |
| ``--headless`` | ``False`` | Run without rendering window |
-| ``--enable_rt`` | ``False`` | Enable ray tracing |
+| ``--renderer`` | ``hybrid`` | Renderer backend: ``legacy``, ``hybrid``, ``fast-rt``, or ``rt`` |
| ``--preview`` | ``False`` | Enter interactive embed mode after loading |
### Preview Mode
@@ -108,7 +108,7 @@ python -m embodichain run-env --gym_config config.json --headless
| ``--num_envs`` | ``1`` | Number of parallel environments |
| ``--device`` | ``cpu`` | Device (``cpu`` or ``cuda``) |
| ``--headless`` | ``False`` | Run in headless mode |
-| ``--enable_rt`` | ``False`` | Use RTX rendering backend |
+| ``--renderer`` | ``hybrid`` | Renderer backend: ``legacy``, ``hybrid``, ``fast-rt`` or ``rt`` |
| ``--arena_space`` | ``5.0`` | Arena space size |
| ``--gpu_id`` | ``0`` | GPU ID to use |
| ``--preview`` | ``False`` | Enter interactive preview mode |
diff --git a/docs/source/guides/configuration.md b/docs/source/guides/configuration.md
new file mode 100644
index 00000000..c031b891
--- /dev/null
+++ b/docs/source/guides/configuration.md
@@ -0,0 +1,293 @@
+# Configuration Guide
+
+EmbodiChain uses a declarative configuration system built on Python dataclasses. This guide explains the key patterns: `@configclass`, `FunctorCfg`, and JSON configuration files.
+
+---
+
+## The `@configclass` Decorator
+
+All configuration objects use the `@configclass` decorator, which is similar to Python's `@dataclass` with additional validation and serialization support.
+
+```python
+from embodichain.utils import configclass
+from dataclasses import MISSING
+
+
+@configclass
+class MyManagerCfg:
+ param_a: float = 1.0
+ param_b: str = MISSING # Required — must be set by caller
+ param_c: int = 10
+```
+
+- **Optional parameters** have default values.
+- **Required parameters** use `MISSING` as the default — callers must provide them.
+- All parameters are typed for IDE auto-completion and static analysis.
+
+---
+
+## Configuration Hierarchy
+
+EmbodiChain configs form a nested hierarchy:
+
+```
+EmbodiedEnvCfg
+├── sim_cfg: SimulationManagerCfg
+│ ├── render_cfg: RenderCfg
+│ ├── physics_config: PhysicsCfg
+│ └── gpu_memory_config: GPUMemoryCfg
+├── robot: RobotCfg
+│ ├── urdf_cfg: URDFCfg
+│ ├── drive_pros: JointDrivePropertiesCfg
+│ └── solver_cfg: Dict[str, SolverCfg]
+├── sensor: List[SensorCfg]
+├── events: EventCfg
+├── observations: ObservationCfg
+├── rewards: RewardCfg
+├── actions: ActionTermCfg
+├── dataset: DatasetFunctorCfg
+└── extensions: Dict[str, Any]
+```
+
+Each sub-config can be set independently, allowing fine-grained control over the environment.
+
+---
+
+## Functor Configuration
+
+Functors are configured through specialized config classes that inherit from `FunctorCfg`. The base class has three fields:
+
+```python
+@configclass
+class FunctorCfg:
+ func: Callable | Functor = MISSING # The function or class to call
+ params: dict[str, Any] = dict() # Keyword arguments
+ extra: dict[str, Any] = dict() # Optional metadata
+```
+
+### Specialized Config Classes
+
+| Config Class | Extra Fields | Used By |
+|---|---|---|
+| `ObservationCfg` | `mode`, `name` | ObservationManager |
+| `EventCfg` | `mode`, `interval_step`, `is_global` | EventManager |
+| `RewardCfg` | `weight`, `mode` | RewardManager |
+| `ActionTermCfg` | `mode` | ActionManager |
+| `DatasetFunctorCfg` | `mode` | DatasetManager |
+
+### Python Config Example
+
+```python
+from embodichain.utils import configclass
+from embodichain.lab.gym.envs.managers.cfg import (
+ ObservationCfg,
+ RewardCfg,
+ EventCfg,
+ SceneEntityCfg,
+)
+from embodichain.lab.gym.envs.managers.observations import get_object_pose
+
+
+@configclass
+class MyObsCfg:
+ object_pose: ObservationCfg = ObservationCfg(
+ func=get_object_pose,
+ mode="add",
+ name="object/pose",
+ params={"entity_cfg": SceneEntityCfg(uid="my_cube")},
+ )
+
+
+@configclass
+class MyRewardCfg:
+ distance: RewardCfg = RewardCfg(
+ func="distance_between_objects",
+ weight=0.5,
+ params={
+ "source_entity_cfg": SceneEntityCfg(uid="cube"),
+ "target_entity_cfg": SceneEntityCfg(uid="target"),
+ },
+ )
+
+
+@configclass
+class MyEventCfg:
+ randomize_light: EventCfg = EventCfg(
+ func="randomize_light",
+ mode="interval",
+ interval_step=5,
+ params={"light_uid": "main_light"},
+ )
+```
+
+---
+
+## JSON Configuration
+
+For RL training and data generation, EmbodiChain uses JSON config files. The JSON config mirrors the Python config structure but uses string names instead of direct function references.
+
+### Environment Config (`gym_config.json`)
+
+```json
+{
+ "max_episodes": 100,
+ "max_episode_steps": 600,
+ "env": {
+ "num_envs": 4,
+ "sim_cfg": {
+ "sim_device": "cuda:0",
+ "headless": true
+ },
+ "robot": {
+ "uid": "robot",
+ "urdf_cfg": {"fpath": "robots/my_robot/my_robot.urdf"}
+ },
+ "control_parts": ["arm"],
+ "sensor": [
+ {
+ "uid": "cam_high",
+ "type": "StereoCamera",
+ "height": 540,
+ "width": 960
+ }
+ ],
+ "actions": {
+ "delta_qpos": {
+ "func": "DeltaQposTerm",
+ "params": {"scale": 0.1}
+ }
+ },
+ "events": {
+ "randomize_table": {
+ "func": "randomize_visual_material",
+ "mode": "interval",
+ "interval_step": 10,
+ "params": {"uid": "table"}
+ }
+ },
+ "observations": {
+ "obj_pose": {
+ "func": "get_object_pose",
+ "mode": "add",
+ "name": "object/pose",
+ "params": {"entity_cfg": {"uid": "cube"}}
+ }
+ },
+ "rewards": {
+ "distance": {
+ "func": "distance_between_objects",
+ "weight": 0.5,
+ "params": {
+ "source_entity_cfg": {"uid": "cube"},
+ "target_entity_cfg": {"uid": "target"}
+ }
+ }
+ },
+ "dataset": {
+ "lerobot": {
+ "func": "LeRobotRecorder",
+ "mode": "save",
+ "params": {
+ "save_path": "/path/to/output",
+ "robot_meta": {"robot_type": "DexforceW1"},
+ "use_videos": true
+ }
+ }
+ },
+ "extensions": {
+ "success_threshold": 0.1
+ }
+ }
+}
+```
+
+### RL Training Config (`train_config.json`)
+
+```json
+{
+ "trainer": {
+ "exp_name": "push_cube",
+ "seed": 42,
+ "device": "cuda:0",
+ "iterations": 500,
+ "buffer_size": 1024
+ },
+ "env": {
+ "id": "PushCubeRL",
+ "cfg": {
+ "num_envs": 4,
+ "actions": {
+ "delta_qpos": {
+ "func": "DeltaQposTerm",
+ "params": {"scale": 0.1}
+ }
+ }
+ }
+ },
+ "policy": {
+ "name": "actor_critic",
+ "actor": {
+ "type": "mlp",
+ "network_cfg": {"hidden_sizes": [256, 256], "activation": "relu"}
+ },
+ "critic": {
+ "type": "mlp",
+ "network_cfg": {"hidden_sizes": [256, 256], "activation": "relu"}
+ }
+ },
+ "algorithm": {
+ "name": "ppo",
+ "cfg": {
+ "learning_rate": 0.0001,
+ "n_epochs": 10,
+ "batch_size": 64,
+ "gamma": 0.99,
+ "gae_lambda": 0.95,
+ "clip_coef": 0.2
+ }
+ }
+}
+```
+
+---
+
+## String-Based Function Resolution
+
+In JSON configs, functor functions are specified by name (string). EmbodiChain resolves these strings at runtime by searching registered modules. For example:
+
+- `"distance_between_objects"` resolves to `embodichain.lab.gym.envs.managers.rewards.distance_between_objects`
+- `"DeltaQposTerm"` resolves to `embodichain.lab.gym.envs.managers.actions.DeltaQposTerm`
+- `"get_object_pose"` resolves to `embodichain.lab.gym.envs.managers.observations.get_object_pose`
+
+When writing custom functors, make sure they are imported in the module's `__init__.py` so the resolver can find them.
+
+---
+
+## `SceneEntityCfg` in JSON
+
+When referencing scene entities in JSON, use a dictionary with a `uid` key:
+
+```json
+{"uid": "my_cube"}
+```
+
+This is automatically converted to a `SceneEntityCfg` object at runtime.
+
+---
+
+## Tips
+
+1. **Start from an existing config.** Copy a config file from `configs/gym/` and modify it for your task.
+2. **Use Python configs for development.** They provide IDE auto-completion and type checking.
+3. **Use JSON configs for experiments.** They are easier to version, diff, and share.
+4. **Validate configs early.** Run your environment with a short episode count to catch config errors before long training runs.
+5. **Keep config pairs together.** For action-bank tasks, version `gym_config.json` and `action_config.json` together.
+
+---
+
+## See Also
+
+- [Custom Functors Guide](custom_functors.md) — How to write observation, reward, event, and action functors
+- [Embodied Environments](../overview/gym/env.md) — Full environment configuration reference
+- [Tutorial: Modular Environment](../tutorial/modular_env.rst) — Complete example using config-driven setup
+- [Tutorial: RL Training](../tutorial/rl.rst) — RL training configuration walkthrough
diff --git a/docs/source/guides/custom_functors.md b/docs/source/guides/custom_functors.md
new file mode 100644
index 00000000..383754f1
--- /dev/null
+++ b/docs/source/guides/custom_functors.md
@@ -0,0 +1,390 @@
+# Writing Custom Functors
+
+Functors are the building blocks of EmbodiChain's manager system. They define how observations are computed, rewards are calculated, events are triggered, actions are preprocessed, and datasets are recorded.
+
+This guide explains the two functor styles (function and class), how to register them in manager configs, and provides examples for each functor type.
+
+---
+
+## Functor Basics
+
+Every functor is configured through a `FunctorCfg` object with three fields:
+
+| Field | Type | Description |
+|-------|------|-------------|
+| `func` | `Callable \| Functor` | The function or class to call. **Required.** |
+| `params` | `dict` | Keyword arguments passed to the function. |
+| `extra` | `dict` | Optional metadata (e.g., observation shapes). |
+
+The `func` field can be:
+- A **function** (callable) — receives the environment as the first argument, plus any `params` as keyword arguments.
+- A **class** inheriting from `Functor` — instantiated with `(cfg, env)`, then called via `__call__`.
+
+---
+
+## Function-Style Functors
+
+Function-style functors are plain Python functions. They are stateless and easy to write. Use them when your functor is a simple computation that doesn't need to maintain state between calls.
+
+### General Pattern
+
+```python
+def my_functor(env, obs, **kwargs) -> torch.Tensor:
+ """Compute something from the environment state.
+
+ Args:
+ env: The environment instance.
+ obs: The current observation dictionary.
+ **kwargs: Additional parameters from FunctorCfg.params.
+
+ Returns:
+ A tensor of shape (num_envs, ...).
+ """
+ # Access environment state
+ value = compute_value(env)
+
+ return value
+```
+
+The exact signature depends on the functor type (see below).
+
+### Example: Observation Functor
+
+Observation functors receive `(env, obs)` plus any params. They must return a tensor.
+
+```python
+from __future__ import annotations
+import torch
+from embodichain.lab.gym.envs import EmbodiedEnv
+from embodichain.lab.gym.envs.managers.observations import EnvObs
+from embodichain.lab.sim.cfg import SceneEntityCfg
+
+
+def get_object_height(
+ env: EmbodiedEnv,
+ obs: EnvObs,
+ entity_cfg: SceneEntityCfg,
+) -> torch.Tensor:
+ """Get the Z-coordinate (height) of an object.
+
+ Args:
+ env: The environment instance.
+ obs: The current observation dictionary.
+ entity_cfg: Scene entity configuration with the object UID.
+
+ Returns:
+ Tensor of shape (num_envs, 1) with the object height.
+ """
+ obj = env.sim.get_rigid_object(entity_cfg.uid)
+ pose = obj.get_local_pose(to_matrix=True) # (num_envs, 4, 4)
+ height = pose[:, 2, 3:4] # Extract Z from translation
+ return height
+```
+
+Register it in your environment config:
+
+```python
+from embodichain.lab.gym.envs.managers.cfg import ObservationCfg, SceneEntityCfg
+from embodichain.utils import configclass
+
+
+@configclass
+class MyObsCfg:
+ obj_height: ObservationCfg = ObservationCfg(
+ func=get_object_height,
+ mode="add",
+ name="object/height",
+ params={"entity_cfg": SceneEntityCfg(uid="my_cube")},
+ )
+```
+
+Or in JSON:
+
+```json
+"observations": {
+ "obj_height": {
+ "func": "get_object_height",
+ "mode": "add",
+ "name": "object/height",
+ "params": {"entity_cfg": {"uid": "my_cube"}}
+ }
+}
+```
+
+### Example: Reward Functor
+
+Reward functors receive `(env, obs, action, info)` plus any params. They return a tensor of shape `(num_envs,)`.
+
+```python
+import torch
+from embodichain.lab.gym.envs import EmbodiedEnv
+from embodichain.lab.sim.cfg import SceneEntityCfg
+
+
+def target_height_reward(
+ env: EmbodiedEnv,
+ obs: dict,
+ action,
+ info: dict,
+ entity_cfg: SceneEntityCfg = None,
+ target_height: float = 0.5,
+) -> torch.Tensor:
+ """Reward for lifting an object to a target height.
+
+ Returns:
+ Negative distance to the target height. Shape (num_envs,).
+ """
+ obj = env.sim.get_rigid_object(entity_cfg.uid)
+ pose = obj.get_local_pose(to_matrix=True)
+ current_height = pose[:, 2, 3]
+ return -torch.abs(current_height - target_height)
+```
+
+Register it:
+
+```python
+from embodichain.lab.gym.envs.managers.cfg import RewardCfg
+from embodichain.utils import configclass
+
+
+@configclass
+class MyRewardCfg:
+ lift_reward: RewardCfg = RewardCfg(
+ func=target_height_reward,
+ weight=1.0,
+ params={
+ "entity_cfg": SceneEntityCfg(uid="my_cube"),
+ "target_height": 0.5,
+ },
+ )
+```
+
+---
+
+## Class-Style Functors
+
+Class-style functors inherit from `Functor` and implement `__init__(cfg, env)` and `__call__(...)`. Use them when you need to:
+
+- Maintain state across calls (e.g., caching, counters)
+- Perform expensive initialization once
+- Implement a `reset()` method for per-episode cleanup
+
+### General Pattern
+
+```python
+from embodichain.lab.gym.envs.managers import Functor
+from embodichain.lab.gym.envs.managers.cfg import FunctorCfg
+
+
+class MyFunctor(Functor):
+ """A stateful functor."""
+
+ def __init__(self, cfg: FunctorCfg, env):
+ super().__init__(cfg, env)
+ # Initialize state, buffers, etc.
+ self._counter = 0
+
+ def reset(self, env_ids=None):
+ """Called on environment reset."""
+ self._counter = 0
+
+ def __call__(self, env, obs, **kwargs):
+ """Called every step."""
+ self._counter += 1
+ # Compute and return result
+```
+
+### Example: Observation Functor with Caching
+
+```python
+from __future__ import annotations
+import torch
+from embodichain.lab.gym.envs import EmbodiedEnv
+from embodichain.lab.gym.envs.managers import Functor
+from embodichain.lab.gym.envs.managers.cfg import FunctorCfg, ObservationCfg
+from embodichain.lab.sim.cfg import SceneEntityCfg
+
+
+class get_object_mass(Functor):
+ """Get the mass of a rigid object, with caching.
+
+ Caches the result to avoid repeated queries to the physics engine.
+ Cache is cleared on environment reset.
+ """
+
+ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv):
+ super().__init__(cfg, env)
+ self._cache = {}
+
+ def reset(self, env_ids=None):
+ self._cache.clear()
+
+ def __call__(
+ self,
+ env: EmbodiedEnv,
+ obs,
+ entity_cfg: SceneEntityCfg,
+ ) -> torch.Tensor:
+ uid = entity_cfg.uid
+ if uid in self._cache:
+ return self._cache[uid].clone()
+
+ obj = env.sim.get_rigid_object(uid)
+ mass = obj.get_mass() # (num_envs, 1)
+
+ self._cache[uid] = mass.clone()
+ return mass
+```
+
+### Example: Action Functor
+
+Action functors inherit from `ActionTerm` and implement `process_action`. They transform raw policy actions into robot control commands.
+
+```python
+from __future__ import annotations
+import torch
+from embodichain.lab.gym.envs.managers.actions import ActionTerm
+from embodichain.lab.gym.envs.managers.cfg import ActionTermCfg
+
+
+class DeltaQposTerm(ActionTerm):
+ """Delta joint position: current_qpos + scale * action -> target qpos.
+
+ The policy outputs a position offset, which is added to the current
+ joint positions to get the target.
+ """
+
+ def __init__(self, cfg: ActionTermCfg, env):
+ super().__init__(cfg, env)
+ self._scale = cfg.params.get("scale", 1.0)
+
+ @property
+ def input_key(self) -> str:
+ return "qpos"
+
+ @property
+ def action_dim(self) -> int:
+ return len(self._env.active_joint_ids)
+
+ def process_action(self, action: torch.Tensor) -> torch.Tensor:
+ return action * self._scale + self._env.robot.get_qpos()
+```
+
+Register it in JSON config:
+
+```json
+"actions": {
+ "delta_qpos": {
+ "func": "DeltaQposTerm",
+ "params": {"scale": 0.1}
+ }
+}
+```
+
+---
+
+## Functor Signature Reference
+
+Each functor type has a specific call signature:
+
+### Observation Functors
+
+```python
+def my_obs_functor(env, obs, **params) -> torch.Tensor
+```
+
+- `env`: The environment instance.
+- `obs`: The current observation dictionary.
+- Additional params from `ObservationCfg.params`.
+- Returns: tensor of shape `(num_envs, ...)`.
+
+Config class: `ObservationCfg` with `mode` (`"add"` or `"modify"`) and `name`.
+
+### Reward Functors
+
+```python
+def my_reward_functor(env, obs, action, info, **params) -> torch.Tensor
+```
+
+- `env`: The environment instance.
+- `obs`: The current observation dictionary.
+- `action`: The action taken this step.
+- `info`: The info dictionary.
+- Additional params from `RewardCfg.params`.
+- Returns: tensor of shape `(num_envs,)`.
+
+Config class: `RewardCfg` with `weight` and `mode` (`"add"` or `"replace"`).
+
+### Event Functors
+
+```python
+def my_event_functor(env, env_ids, **params) -> None
+```
+
+- `env`: The environment instance.
+- `env_ids`: The environment IDs affected by this event.
+- Additional params from `EventCfg.params`.
+- Returns: `None` (events modify the environment in-place).
+
+Config class: `EventCfg` with `mode` (`"startup"`, `"reset"`, or `"interval"`) and `interval_step`.
+
+### Action Functors
+
+```python
+class MyActionTerm(ActionTerm):
+ def process_action(self, action: torch.Tensor) -> torch.Tensor
+```
+
+- `action`: Raw action from the policy, shape `(num_envs, action_dim)`.
+- Returns: transformed action tensor.
+
+Config class: `ActionTermCfg` with `mode` (`"pre"` or `"post"`).
+
+### Dataset Functors
+
+Dataset functors handle recording and saving. In most cases you should use the built-in `LeRobotRecorder` rather than writing a custom one.
+
+Config class: `DatasetFunctorCfg` with `mode` (`"save"`).
+
+---
+
+## Using `SceneEntityCfg` in Params
+
+Many functors need to reference scene objects (robots, rigid objects, sensors). Instead of passing string UIDs directly, use `SceneEntityCfg`:
+
+```python
+from embodichain.lab.sim.cfg import SceneEntityCfg
+
+params = {
+ "entity_cfg": SceneEntityCfg(uid="my_cube"),
+}
+```
+
+The manager automatically resolves `SceneEntityCfg` objects to the actual simulation entities at runtime.
+
+---
+
+## File Placement
+
+| Functor Type | Recommended Location |
+|---|---|
+| Observation | `embodichain/lab/gym/envs/managers/observations.py` |
+| Reward | `embodichain/lab/gym/envs/managers/rewards.py` |
+| Event | `embodichain/lab/gym/envs/managers/events.py` or `embodichain/lab/gym/envs/managers/randomization/` |
+| Action | `embodichain/lab/gym/envs/managers/actions.py` |
+| Dataset | `embodichain/lab/gym/envs/managers/datasets.py` |
+
+For task-specific functors, place them in the task module file (e.g., alongside the task environment class).
+
+Remember to:
+- Add the functor to `__all__` in the module.
+- Add the Apache 2.0 license header.
+- Use type annotations with `from __future__ import annotations`.
+
+---
+
+## See Also
+
+- [Configuration Guide](configuration.md) — How to set up `@configclass` configs and JSON files
+- [Embodied Environments](../overview/gym/env.md) — Full environment architecture
+- [Tutorial: Modular Environment](../tutorial/modular_env.rst) — Using functors in a complete environment
diff --git a/docs/source/guides/index.rst b/docs/source/guides/index.rst
index e5c0f2de..f44ad5a0 100644
--- a/docs/source/guides/index.rst
+++ b/docs/source/guides/index.rst
@@ -1,10 +1,14 @@
How-to Guides
-=========
+=============
+
+Practical guides for common tasks in EmbodiChain.
.. toctree::
:maxdepth: 1
:hidden:
+ custom_functors
+ configuration
add_robot
cli
diff --git a/docs/source/index.rst b/docs/source/index.rst
index c3a47f2f..4bae98ae 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,7 +1,9 @@
EmbodiChain Documentation
=========================
-Welcome to the EmbodiChain!
+EmbodiChain is a GPU-accelerated robotics simulation framework for embodied AI research. It provides tools for building generating and processing simulation assets and scenes, creating robot learning environments, generating expert demonstration data, training policies with imitation learning and reinforcement learning, and deploying models into real world.
+
+The framework is built on top of `DexSim `_, a high-performance physics and rendering engine, designed for Embodied AI research and production use.
Table of Contents
=================
@@ -59,4 +61,3 @@ Table of Contents
:titlesonly:
api_reference/index
-
diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst
index 3f2254d4..d437b4fb 100644
--- a/docs/source/introduction.rst
+++ b/docs/source/introduction.rst
@@ -1,59 +1,73 @@
-.. EmbodiChain documentation master file, created by
- sphinx-quickstart on Tue Nov 19 11:00:25 2024.
- You can adapt this file completely to your liking, but it should at least
- contain the root `toctree` directive.
-
EmbodiChain
-======================================
+===========
.. image:: ../../assets/imgs/teaser.jpg
- :alt: teaser
-
----
-EmbodiChain is an end-to-end, GPU-accelerated framework for Embodied AI. It streamlines research and development by unifying high-performance simulation, real-to-sim data pipelines, modular model architectures, and efficient training workflows. This integration enables rapid experimentation, seamless deployment of intelligent agents, and effective Sim2Real transfer for real-world robotic systems.
+EmbodiChain is an end-to-end, GPU-accelerated framework for Embodied AI.
+It streamlines research and development by unifying high-performance
+simulation, automated generative data pipelines, modular model
+architectures, and efficient training workflows. This integration
+enables rapid experimentation, seamless deployment of intelligent
+agents, and effective Sim2Real transfer for real-world robotic systems.
.. NOTE::
- EmbodiChain is in Alpha and under active development:
-
- * More features will be continually added in the coming months. You can find more details in the `roadmap `_.
- * Since this is an early release, we welcome feedback (bug reports, feature requests, etc.) via GitHub Issues.
-
+ EmbodiChain is in Alpha and under active development: * More
+ features will be continually added in the coming months. You can find
+ more details in the
+ `roadmap `__.
+ * Since this is an early release, we welcome feedback (bug reports,
+ feature requests, etc.) via GitHub Issues.
Key Features
------------
-* 🚀 **High-Fidelity GPU Simulation**: Realistic physics for rigid & deformable objects, advanced ray-traced sensors, all GPU-accelerated for high-throughput batch simulation.
-* 🤖 **Unified Robot Learning Environment**: Standardized interfaces for Imitation Learning, Reinforcement Learning, and more.
-* 📊 **Scalable Data Pipeline**: Automated data collection, efficient processing, and large-scale generation for model training.
-* ⚡ **Efficient Training & Evaluation**: Online data streaming, parallel environment rollouts, and modern training paradigms.
-* 🧩 **Modular & Extensible**: Easily integrate new robots, environments, and learning algorithms.
+- 🚀 **High-Fidelity GPU Simulation**: Realistic physics for rigid &
+ deformable objects, advanced ray-traced sensors, all GPU-accelerated
+ for high-throughput batch simulation.
+- 🤖 **Unified Robot Learning Environment**: Standardized interfaces for
+ Imitation Learning, Reinforcement Learning, and more.
+- 📊 **Scalable Data Pipeline**: Automated data collection, efficient
+ processing, and large-scale generation for model training.
+- ⚡ **Efficient Training & Evaluation**: Online data streaming,
+ parallel environment rollouts, and modern training paradigms.
+- 🧩 **Modular & Extensible**: Easily integrate new robots,
+ environments, and learning algorithms.
The figure below illustrates the overall architecture of EmbodiChain:
.. image:: ../../assets/imgs/frameworks.jpg
- :alt: frameworks
+ :align: center
Getting Started
---------------
To get started with EmbodiChain, follow these steps:
-* `Installation Guide `_
-* `Quick Start Tutorial `_
-* `API Reference `_
+- `Installation
+ Guide `__
+- `Quick Start
+ Tutorial `__
+- `API
+ Reference `__
+Contribution Guide
+------------------
+
+We welcome contributions! Please see the
+`CONTRIBUTING.md `__ file in this repository for
+guidelines on how to get started.
Citation
--------
-If you find EmbodiChain helpful for your research, please consider citing our work:
+If you find EmbodiChain helpful for your research, please consider
+citing our work:
.. code-block:: bibtex
@misc{EmbodiChain,
author = {EmbodiChain Developers},
- title = {EmbodiChain: An end-to-end, GPU-accelerated, and modular platform for building generalized Embodied Intelligence.},
+ title = {EmbodiChain: An end-to-end, GPU-accelerated, and modular platform for building generalized Embodied Intelligence},
month = {November},
year = {2025},
url = {https://github.com/DexForce/EmbodiChain}
@@ -68,15 +82,14 @@ If you find EmbodiChain helpful for your research, please consider citing our wo
month = {October},
year = {2025},
journal = {TechRxiv}
- }
+ }
.. code-block:: bibtex
@inproceedings{Sim2RealVLA,
- title = {Sim2Real {VLA}: Zero-Shot Generalization of Synthesized Skills to Realistic Manipulation},
- author = {Runyi Zhao, Sheng Xu, Ruixing Jin, Yueci Deng, Yunxin Tai, Kui Jia, Guiliang Liu},
- booktitle = {The Fourteenth International Conference on Learning Representations, ICLR},
- year = {2026},
- url = {https://openreview.net/forum?id=H4SyKHjd4c}
+ title = {Sim2Real {VLA}: Zero-Shot Generalization of Synthesized Skills to Realistic Manipulation},
+ author = {Runyi Zhao, Sheng Xu, Ruixing Jin, Yueci Deng, Yunxin Tai, Kui Jia, Guiliang Liu},
+ booktitle = {The Fourteenth International Conference on Learning Representations, ICLR},
+ year = {2026},
+ url = {https://openreview.net/forum?id=H4SyKHjd4c}
}
-
diff --git a/docs/source/overview/gym/action_functors.md b/docs/source/overview/gym/action_functors.md
index 670fa078..225424da 100644
--- a/docs/source/overview/gym/action_functors.md
+++ b/docs/source/overview/gym/action_functors.md
@@ -5,6 +5,10 @@
This page lists all available action terms that can be used with the Action Manager. Action terms are configured using {class}`~cfg.ActionTermCfg` and are responsible for processing raw actions from the policy and converting them to the format expected by the robot (e.g., qpos, qvel, qf).
+````{tip}
+**Using an AI coding agent?** Use the **`/add-functor`** skill to scaffold a new action term with the correct class structure, `ActionTermCfg` registration, and module placement in `actions.py`.
+````
+
## Joint Position Control
```{list-table} Joint Position Action Terms
diff --git a/docs/source/overview/gym/dataset_functors.md b/docs/source/overview/gym/dataset_functors.md
index a418bc6e..c043ee68 100644
--- a/docs/source/overview/gym/dataset_functors.md
+++ b/docs/source/overview/gym/dataset_functors.md
@@ -5,6 +5,10 @@
This page lists all available dataset functors that can be used with the Dataset Manager. Dataset functors are configured using {class}`~cfg.DatasetFunctorCfg` and are responsible for collecting and saving episode data during environment interaction.
+````{tip}
+**Using an AI coding agent?** Use the **`/add-functor`** skill to scaffold a new dataset functor with the correct signature, `DatasetFunctorCfg` registration, and module placement in `datasets.py`.
+````
+
## Recording Functors
```{list-table} Dataset Recording Functors
diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md
index fa7c9bc9..88f44fb9 100644
--- a/docs/source/overview/gym/env.md
+++ b/docs/source/overview/gym/env.md
@@ -229,6 +229,16 @@ In JSON config, use the ``actions`` section:
## Creating a Custom Task
+````{tip}
+**Using an AI coding agent?** The following skills can scaffold boilerplate for you:
+
+- **`/add-task-env`** — Generate a new task environment with the correct file structure, `@register_env` decorator, base class methods, `__init__.py` update, and test stub.
+- **`/add-functor`** — Add observation, reward, event, or randomization functors with the correct signature and module placement.
+- **`/add-test`** — Write tests following project conventions (pytest or class style, mock patterns, correct file placement).
+- **`/pre-commit-check`** — Run all local CI checks (black, headers, `__all__`, type annotations) before committing.
+
+````
+
### For Reinforcement Learning Tasks
Inherit from {class}`~envs.EmbodiedEnv` and implement the task-specific logic. Configure the Action Manager via ``actions`` in your config:
@@ -295,6 +305,8 @@ For a complete example of a modular environment setup, please refer to the {ref}
- {ref}`tutorial_modular_env` - Advanced modular environment setup
- {ref}`tutorial_rl` - Reinforcement learning training guide
- {doc}`/api_reference/embodichain/embodichain.lab.gym.envs` - Complete API reference for EmbodiedEnv and configurations
+- {doc}`/guides/custom_functors` - How to write custom functors
+- {doc}`/guides/configuration` - Configuration system guide
```{toctree}
:maxdepth: 1
diff --git a/docs/source/overview/gym/event_functors.md b/docs/source/overview/gym/event_functors.md
index 2ddbb19f..46ed991a 100644
--- a/docs/source/overview/gym/event_functors.md
+++ b/docs/source/overview/gym/event_functors.md
@@ -5,6 +5,10 @@
This page lists all available event functors that can be used with the Event Manager. Event functors are configured using {class}`~cfg.EventCfg` and can be triggered at different stages: ``startup``, ``reset``, or ``interval``.
+````{tip}
+**Using an AI coding agent?** Use the **`/add-functor`** skill to scaffold a new event or randomization functor with the correct signature (`env, env_ids, ...`), function or class style, and module placement. Use **`/add-test`** to generate mock-based tests.
+````
+
## Physics Randomization
```{list-table} Physics Randomization Functors
diff --git a/docs/source/overview/gym/observation_functors.md b/docs/source/overview/gym/observation_functors.md
index bf2b7915..bb67cce6 100644
--- a/docs/source/overview/gym/observation_functors.md
+++ b/docs/source/overview/gym/observation_functors.md
@@ -5,6 +5,10 @@
This page lists all available observation functors that can be used with the Observation Manager. Observation functors are configured using {class}`~cfg.ObservationCfg` and can operate in two modes: ``modify`` (update existing observations) or ``add`` (add new observations).
+````{tip}
+**Using an AI coding agent?** Use the **`/add-functor`** skill to scaffold a new observation functor with the correct signature (`env, obs, entity_cfg, ...`), module placement in `observations.py`, and `__all__` export. Use **`/add-test`** to generate mock-based tests.
+````
+
## Pose Computations
```{list-table} Pose Computation Functors
diff --git a/docs/source/overview/gym/reward_functors.md b/docs/source/overview/gym/reward_functors.md
index ad0255fd..fb91cbf0 100644
--- a/docs/source/overview/gym/reward_functors.md
+++ b/docs/source/overview/gym/reward_functors.md
@@ -5,6 +5,10 @@
This page lists all available reward functors that can be used with the Reward Manager. Reward functors are configured using {class}`~cfg.RewardCfg` and return scalar reward tensors that are weighted and summed to form the total environment reward.
+````{tip}
+**Using an AI coding agent?** Use the **`/add-functor`** skill to scaffold a new reward functor with the correct signature (`env, obs, action, info, ...`), module placement in `rewards.py`, and `__all__` export. Use **`/add-test`** to generate mock-based tests.
+````
+
## Distance-Based Rewards
```{list-table} Distance-Based Reward Functors
diff --git a/docs/source/overview/rl/index.rst b/docs/source/overview/rl/index.rst
index cac282f4..df2fd29e 100644
--- a/docs/source/overview/rl/index.rst
+++ b/docs/source/overview/rl/index.rst
@@ -79,3 +79,11 @@ See also
config.md
train_script.md
multi_gpu.md
+
+See Also
+--------
+
+- :doc:`/tutorial/rl` — Step-by-step RL training tutorial
+- :doc:`/overview/gym/env` — EmbodiedEnv configuration and Action Manager
+- :doc:`/features/online_data` — Online data streaming pipeline
+- :doc:`/resources/task/index` — Available RL task environments
diff --git a/docs/source/overview/sim/atomic_actions.md b/docs/source/overview/sim/atomic_actions.md
new file mode 100644
index 00000000..979df571
--- /dev/null
+++ b/docs/source/overview/sim/atomic_actions.md
@@ -0,0 +1,241 @@
+# Atomic Actions
+
+```{currentmodule} embodichain.lab.sim.atomic_actions
+```
+
+Atomic actions are the building blocks for automated robot motion generation. Each action encapsulates a complete, self-contained motion primitive — such as picking up an object or moving to a pose — that can be chained together to form complex manipulation workflows.
+
+## Design Overview
+
+The module is organized into three layers:
+
+```
+AtomicActionEngine ← orchestrates a sequence of actions
+ │
+ ├── AtomicAction(s) ← each action plans one motion primitive
+ │ │
+ │ └── MotionGenerator ← low-level trajectory planner (IK + trajectory optimization)
+ │
+ └── SemanticAnalyzer ← resolves object labels → ObjectSemantics
+```
+
+Each action receives a target (object semantics or a pose tensor), runs its planning pipeline,
+and returns a joint trajectory. The engine threads the end state of each action as the start
+state of the next, then concatenates all trajectories into one contiguous sequence:
+
+```
+ObjectSemantics ──► AffordanceEstimation ──► AtomicAction.execute()
+(label + geometry │
+ + affordance ├─ IK solve
+ + entity) ├─ Motion plan
+ └─ Gripper interpolation
+ │
+AtomicActionEngine ◄─────────────── PlanResult ───────┘
+(sequences actions, accumulates
+ full-robot trajectory)
+```
+
+### Core Concepts
+
+**`ObjectSemantics`** describes an interaction target. It bundles:
+- `geometry` — mesh data (vertices, triangles) used for grasp annotation
+- `affordance` — *how* to interact with the object (e.g. antipodal grasp poses)
+- `entity` — a live reference to the simulation object, so actions can read its current pose
+
+**`Affordance`** is a data class that encodes a specific interaction capability. The built-in affordance types are:
+
+| Class | Use case |
+|---|---|
+| `AntipodalAffordance` | Parallel-jaw grasping via antipodal point pairs |
+| `InteractionPoints` | Contact-based interactions (push, poke, touch) |
+
+**`AtomicAction`** is the abstract base class for all motion primitives. Every action must implement:
+- `execute(target, start_qpos)` — plan and return a joint trajectory
+- `validate(target, start_qpos)` — fast feasibility check without full planning
+
+**`AtomicActionEngine`** manages a named registry of actions and runs them in sequence via `execute_static()`, threading the end state of each action as the start state of the next.
+
+---
+
+## Built-in Actions
+
+(supported_atomic_actions)=
+
+The following actions are available out of the box:
+
+| Action | Config class | Target type | Motion phases |
+|---|---|---|---|
+| `MoveAction` | `MoveActionCfg` | `Tensor (4,4)` — EEF pose | Move arm to pose |
+| `PickUpAction` | `PickUpActionCfg` | `ObjectSemantics` or `Tensor (4,4)` | Approach → close gripper → lift |
+| `PlaceAction` | `PlaceActionCfg` | `Tensor (4,4)` — EEF release pose | Lower → open gripper → retract |
+
+### `MoveAction`
+
+Moves the end-effector to a target pose in free space.
+
+| Config field | Default | Description |
+|---|---|---|
+| `control_part` | `"arm"` | Robot control part to move |
+| `sample_interval` | `50` | Number of waypoints in the trajectory |
+
+**Target:** `torch.Tensor` of shape `(4, 4)` or `(n_envs, 4, 4)` — a homogeneous EEF pose.
+
+---
+
+### `PickUpAction`
+
+Three-phase grasp motion: *approach → close gripper → lift*.
+
+| Config field | Default | Description |
+|---|---|---|
+| `approach_direction` | `[0, 0, -1]` | Gripper approach direction in object frame |
+| `pre_grasp_distance` | `0.15` | Hover distance before descending (m) |
+| `lift_height` | `0.10` | Lift height after grasping (m) |
+| `hand_open_qpos` | `None` | **Required.** Gripper open joint positions |
+| `hand_close_qpos` | `None` | **Required.** Gripper closed joint positions |
+| `hand_control_part` | `"hand"` | Robot control part for the gripper |
+| `hand_interp_steps` | `5` | Waypoints for the gripper close phase |
+| `sample_interval` | `80` | Total waypoints across all three phases |
+
+**Target:** `ObjectSemantics` (grasp pose computed automatically) **or** a `torch.Tensor` EEF pose.
+
+---
+
+### `PlaceAction`
+
+Three-phase release motion: *lower → open gripper → retract*. Mirrors `PickUpAction`.
+
+Inherits all gripper config fields from `GraspActionCfg`. The `approach_direction` field is not used — the arm moves straight down to the target pose.
+
+**Target:** `torch.Tensor` of shape `(4, 4)` or `(n_envs, 4, 4)` — the EEF pose at release.
+
+---
+
+## Typical Workflow
+
+```python
+from embodichain.lab.sim.atomic_actions import (
+ AtomicActionEngine,
+ ObjectSemantics,
+ AntipodalAffordance,
+ PickUpActionCfg,
+ PlaceActionCfg,
+ MoveActionCfg,
+)
+
+# 1. Configure each action
+pickup_cfg = PickUpActionCfg(
+ control_part="arm",
+ hand_control_part="hand",
+ hand_open_qpos=torch.tensor([0.0, 0.0]),
+ hand_close_qpos=torch.tensor([0.025, 0.025]),
+)
+place_cfg = PlaceActionCfg(...)
+move_cfg = MoveActionCfg(control_part="arm")
+
+# 2. Build the engine — action order matches target_list order
+engine = AtomicActionEngine(
+ motion_generator=motion_gen,
+ actions_cfg_list=[pickup_cfg, place_cfg, move_cfg],
+)
+
+# 3. Describe the object to pick
+semantics = ObjectSemantics(
+ label="mug",
+ geometry={"mesh_vertices": ..., "mesh_triangles": ...},
+ affordance=AntipodalAffordance(object_label="mug", ...),
+ entity=mug,
+)
+
+# 4. Plan the full sequence and replay
+is_success, traj = engine.execute_static(
+ target_list=[semantics, place_pose, rest_pose]
+)
+# traj: (n_envs, n_waypoints, dof)
+```
+
+---
+
+## How to Extend: Adding a Custom Action
+
+You can add any motion primitive by subclassing `AtomicAction` and registering it with the engine.
+
+### Step 1 — Define the config
+
+```python
+from embodichain.utils import configclass
+from embodichain.lab.sim.atomic_actions import ActionCfg
+
+@configclass
+class PushActionCfg(ActionCfg):
+ name: str = "push"
+ push_distance: float = 0.05 # metres to push forward
+ push_speed: int = 30 # waypoints for the push phase
+```
+
+### Step 2 — Implement the action
+
+```python
+import torch
+from typing import Optional, Union
+from embodichain.lab.sim.atomic_actions import AtomicAction, ObjectSemantics
+from embodichain.lab.sim.planners import PlanState, MoveType
+
+class PushAction(AtomicAction):
+ def __init__(self, motion_generator, cfg: PushActionCfg | None = None):
+ super().__init__(motion_generator, cfg=cfg or PushActionCfg())
+ self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part)
+
+ def execute(
+ self,
+ target: Union[torch.Tensor, ObjectSemantics],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[bool, torch.Tensor, list]:
+ # Resolve target to a batched [n_envs, 4, 4] EEF pose
+ # ... your planning logic here ...
+ return is_success, trajectory, self.arm_joint_ids
+
+ def validate(self, target, start_qpos=None, **kwargs) -> bool:
+ return True # add IK check here if needed
+```
+
+### Step 3 — Register and use
+
+```python
+from embodichain.lab.sim.atomic_actions import register_action
+
+register_action("push", PushAction, PushActionCfg)
+
+engine = AtomicActionEngine(
+ motion_generator=motion_gen,
+ actions_cfg_list=[PushActionCfg(push_distance=0.08)],
+)
+is_success, traj = engine.execute_static(target_list=[target_pose])
+```
+
+> **Tip:** The `execute()` return signature is always `(is_success, trajectory, joint_ids)`.
+> `trajectory` has shape `(n_envs, n_waypoints, len(joint_ids))`.
+> `joint_ids` tells the engine which columns of the full robot DOF vector the trajectory covers.
+
+---
+
+## Target Resolution
+
+`AtomicActionEngine` accepts several target formats in `target_list`, giving you flexibility without boilerplate:
+
+| Input type | Resolved to |
+|---|---|
+| `torch.Tensor (4,4)` or `(n_envs,4,4)` | EEF pose, broadcast across envs |
+| `ObjectSemantics` | Passed directly to the action |
+| `str` (object label) | Looked up in `SemanticAnalyzer` cache |
+| `dict` with `"pose"` key | Unwrapped to tensor |
+| `dict` with `"label"` key | Analyzed via `SemanticAnalyzer` |
+
+---
+
+## Further Reading
+
+- {doc}`planners/motion_generator` — the trajectory planner used by every action
+- {doc}`sim_robot` — how control parts and IK solvers are configured
+- Tutorial: `scripts/tutorials/sim/atomic_actions.py`
diff --git a/docs/source/overview/sim/index.rst b/docs/source/overview/sim/index.rst
index 56f98ef2..60cdfd56 100644
--- a/docs/source/overview/sim/index.rst
+++ b/docs/source/overview/sim/index.rst
@@ -22,3 +22,4 @@ Overview of the Simulation Framework:
sim_sensor.md
solvers/index
planners/index
+ atomic_actions.md
diff --git a/docs/source/overview/sim/sim_manager.md b/docs/source/overview/sim/sim_manager.md
index b7d86691..5897dfd0 100644
--- a/docs/source/overview/sim/sim_manager.md
+++ b/docs/source/overview/sim/sim_manager.md
@@ -33,9 +33,7 @@ sim_config = SimulationManagerCfg(
| `width` | `int` | `1920` | The width of the simulation window. |
| `height` | `int` | `1080` | The height of the simulation window. |
| `headless` | `bool` | `False` | Whether to run the simulation in headless mode (no Window). |
-| `enable_rt` | `bool` | `False` | Whether to enable ray tracing rendering. |
-| `enable_denoiser` | `bool` | `True` | Whether to enable denoising for ray tracing rendering. |
-| `spp` | `int` | `64` | Samples per pixel for ray tracing rendering. Only valid when ray tracing is enabled and denoiser is False. |
+| `render_cfg` | `RenderCfg` | `RenderCfg()` | The rendering configuration parameters. |
| `gpu_id` | `int` | `0` | The gpu index that the simulation engine will be used. Affects gpu physics device. |
| `thread_mode` | `ThreadMode` | `RENDER_SHARE_ENGINE` | The threading mode for the simulation engine. |
| `cpu_num` | `int` | `1` | The number of CPU threads to use for the simulation engine. |
@@ -60,6 +58,29 @@ The {class}`~cfg.PhysicsCfg` class controls the global physics simulation parame
For more parameters and details, refer to the [PhysicsCfg](https://dexforce.github.io/EmbodiChain/api_reference/embodichain/embodichain.lab.sim.html#embodichain.lab.sim.cfg.PhysicsCfg) documentation.
+### Render Configuration
+
+The {class}`~cfg.RenderCfg` class controls the rendering backend and quality settings.
+
+| Parameter | Type | Default | Description |
+| :--- | :--- | :--- | :--- |
+| `renderer` | `str` | `"hybrid"` | Renderer backend to use. Options are `'hybrid'` (ray tracing for shadows/reflections + rasterization), `'fast-rt'` (full ray tracing), and `'rt'` (offline ray-traced renderer for maximum visual fidelity). |
+| `enable_denoiser` | `bool` | `True` | Whether to enable denoising. Only valid when `renderer` is `'hybrid'`, `'fast-rt'` or `'rt'`. |
+| `spp` | `int` | `64` | Samples per pixel for ray tracing rendering. Only valid when `renderer` is `'hybrid'`, `'fast-rt'` or `'rt'` and `enable_denoiser` is `False`. |
+
+```python
+from embodichain.lab.sim import SimulationManagerCfg
+from embodichain.lab.sim.cfg import RenderCfg
+
+sim_config = SimulationManagerCfg(
+ render_cfg=RenderCfg(
+ renderer="fast-rt", # Use full ray tracing
+ enable_denoiser=True, # Enable denoising
+ spp=64, # Samples per pixel (used when denoiser is off)
+ )
+)
+```
+
## Initialization
diff --git a/docs/source/overview/sim/sim_rigid_object.md b/docs/source/overview/sim/sim_rigid_object.md
index af636ab2..185a533d 100644
--- a/docs/source/overview/sim/sim_rigid_object.md
+++ b/docs/source/overview/sim/sim_rigid_object.md
@@ -110,9 +110,12 @@ Rigid objects are observed and controlled via single poses and linear/angular ve
| `get_local_pose(to_matrix=False)` | `(N, 7)` or `(N, 4, 4)` | Get object local pose as (x, y, z, qw, qx, qy, qz) or 4x4 matrix per environment. |
| `set_local_pose(pose, env_ids=None)` | `pose: (N, 7)` or `(N, 4, 4)` | Teleport object to given pose (requires calling `sim.update()` to apply). |
| `body_data.pose` | `(N, 7)` | Access object pose directly (for dynamic/kinematic bodies). |
-| `body_data.lin_vel` | `(N, 3)` | Access linear velocity of object root (for dynamic/kinematic bodies). |
-| `body_data.ang_vel` | `(N, 3)` | Access angular velocity of object root (for dynamic/kinematic bodies). |
+| `body_data.lin_vel` | `(N, 3)` | Access linear velocity of object root (for dynamic bodies). |
+| `body_data.ang_vel` | `(N, 3)` | Access angular velocity of object root (for dynamic bodies). |
| `body_data.vel` | `(N, 6)` | Concatenated linear and angular velocities. |
+| `body_data.lin_acc` | `(N, 3)` | Access linear acceleration of object root (for dynamic bodies). |
+| `body_data.ang_acc` | `(N, 3)` | Access angular acceleration of object root (for dynamic bodies). |
+| `body_data.acc` | `(N, 6)` | Concatenated linear and angular accelerations. |
| `body_data.com_pose` | `(N, 7)` | Get center of mass pose of rigid bodies. |
| `body_data.default_com_pose` | `(N, 7)` | Default center of mass pose. |
| `body_state` | `(N, 13)` | Get full body state: [x, y, z, qw, qx, qy, qz, lin_x, lin_y, lin_z, ang_x, ang_y, ang_z]. |
diff --git a/docs/source/overview/sim/solvers/srs_solver.md b/docs/source/overview/sim/solvers/srs_solver.md
index 3cabb57e..2b26ee6d 100644
--- a/docs/source/overview/sim/solvers/srs_solver.md
+++ b/docs/source/overview/sim/solvers/srs_solver.md
@@ -51,7 +51,7 @@ cfg = SRSSolverCfg(
end_link_name="left_ee",
root_link_name="left_arm_base",
dh_params=arm_params.dh_params,
- qpos_limits=arm_params.qpos_limits,
+ user_qpos_limit=arm_params.qpos_limits,
T_e_oe=arm_params.T_e_oe,
T_b_ob=arm_params.T_b_ob,
link_lengths=arm_params.link_lengths,
diff --git a/docs/source/quick_start/docs.md b/docs/source/quick_start/docs.md
index c62a3d71..1a8aef4d 100644
--- a/docs/source/quick_start/docs.md
+++ b/docs/source/quick_start/docs.md
@@ -10,9 +10,47 @@ pip install -r docs/requirements.txt
## 2. Build the HTML site
+### Local development (current version only)
+
```bash
cd docs
-make html
+make current-docs
```
Then you can preview the documentation in your browser at `docs/build/html/index.html`.
+
+### Multi-version docs (CI/production)
+
+The production docs site hosts multiple versions side by side. Each version is built independently into its own subdirectory under `docs/build/html/`:
+
+```
+docs/build/html/
+├── index.html # Redirect → latest stable
+├── versions.json # Version manifest for the sidebar selector
+├── main/ # Dev docs (latest main branch)
+├── v0.1.3/ # Release docs
+└── v0.1.2/ # Release docs
+```
+
+To build a specific version into this layout:
+
+```bash
+cd docs
+sphinx-build source build/html/
+```
+
+For example, to build the `main` branch docs:
+
+```bash
+sphinx-build source build/html/main
+```
+
+Then generate the version manifest and root redirect:
+
+```bash
+python3 scripts/generate_versions_json.py --build-dir build/html
+```
+
+This generates both `versions.json` (for the sidebar version selector) and `index.html` (redirects to the latest stable version, falling back to `main`).
+
+> Old release versions beyond `DOCS_MAX_VERSIONS` (default: 4) are automatically pruned during CI builds.
diff --git a/docs/source/quick_start/install.md b/docs/source/quick_start/install.md
index 4c655f2e..49aed084 100644
--- a/docs/source/quick_start/install.md
+++ b/docs/source/quick_start/install.md
@@ -2,81 +2,91 @@
## System Requirements
-The following minimum system requirements are recommended to run EmbodiChain reliably. These are the tested configurations during development — other Linux distributions and versions may work but are not officially supported.
+| Component | Requirement |
+|-----------|------------|
+| **OS** | Linux (x86_64): Ubuntu 20.04+ |
+| **GPU** | NVIDIA with compute capability 7.0+ |
+| **NVIDIA Driver** | 535 - 570 (580+ is untested and may be unstable) |
+| **Python** | 3.10 or 3.11 |
-- Operating System:
- - Linux (x86_64): Ubuntu 20.04+
-
-- NVIDIA GPU and drivers:
- - Hardware: NVIDIA GPU with compute capability 7.0 or higher
- - NVIDIA Driver: 535 or higher (recommended 570)
-
-
-- Python:
- - 3.10
- - 3.11
-
-Notes:
+> [!NOTE]
+> Ensure your NVIDIA driver is compatible with your chosen PyTorch wheel. We recommend installing PyTorch from the [official PyTorch instructions](https://pytorch.org/get-started/locally/) for your CUDA version.
-- Ensure your NVIDIA driver is compatible with your chosen PyTorch wheel.
-- We recommend installing PyTorch from the official PyTorch instructions for your CUDA version: https://pytorch.org/get-started/locally/
+## Installation
----
+### Docker (Recommended)
-### Recommended: Install with Docker
+We strongly recommend using our pre-configured Docker environment, which contains all necessary dependencies including CUDA, Vulkan, and GPU rendering support.
-We strongly recommend using our pre-configured Docker environment, which contains all necessary dependencies.
+**1. Pull the image:**
```bash
docker pull dexforce/embodichain:ubuntu22.04-cuda12.8
```
-After pulling the Docker image, you can run a container with the provided [scripts](../../../docker/docker_run.sh).
+**2. Start a container:**
+
+Use the provided run script ([`docker/docker_run.sh`](../../../docker/docker_run.sh)), which handles GPU driver and Vulkan mounting:
```bash
-./docker_run.sh [container_name] [data_path]
+./docker/docker_run.sh
```
----
+### uv (Recommended for local development)
+> [!TIP]
+> [uv](https://github.com/astral-sh/uv) is an extremely fast Python package manager and project manager. We recommend using `uv` for local development due to its significantly faster dependency resolution and installation times compared to pip.
-### Install EmbodiChain
+**Install uv:**
-> **We strongly recommend using a virtual environment to avoid dependency conflicts.**
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh
+```
-To install EmbodiChain from pypi, run:
+**Install from PyPI:**
```bash
-pip install embodichain --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
-
-# Or install with the lerobot extras:
-pip install embodichain[lerobot] --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
+uv pip install embodichain --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
```
-To install the Embodichain from source, clone the EmbodiChain repository:
+**Install from source (editable mode):**
+
```bash
git clone https://github.com/DexForce/EmbodiChain.git
+cd EmbodiChain
+uv pip install -e . --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
+```
+
+### pip (PyPI)
+
+> [!TIP]
+> We strongly recommend using a virtual environment to avoid dependency conflicts.
+
+```bash
+pip install embodichain --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
```
-Install the project in development mode:
+### From Source
+
+> [!TIP]
+> We strongly recommend using a virtual environment to avoid dependency conflicts.
```bash
+git clone https://github.com/DexForce/EmbodiChain.git
+cd EmbodiChain
pip install -e . --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
-
-# Or install with the lerobot extras:
-pip install -e .[lerobot] --extra-index-url http://pyp.open3dv.site:2345/simple/ --trusted-host pyp.open3dv.site
```
-> [!NOTE]
-> * [LeRobot](https://huggingface.co/docs/lerobot/installation) is an optional module for EmbodiChain that provides data saving and loading functionalities for robot learning tasks. Installing with the `lerobot` extras will include this module and its dependencies.
+## Verify Installation
-### Verify Installation
-To verify that EmbodiChain is installed correctly, run a simple demo script to create a simulation scene:
+Run the demo script to confirm everything is set up correctly:
```bash
python scripts/tutorials/sim/create_scene.py
+```
+
+If the installation is successful, you will see a simulation window with a rendered scene. To run without a display:
-# Or run in headless mode.
+```bash
python scripts/tutorials/sim/create_scene.py --headless
```
----
diff --git a/docs/source/resources/roadmap.md b/docs/source/resources/roadmap.md
index 899a9c56..c4870a9e 100644
--- a/docs/source/resources/roadmap.md
+++ b/docs/source/resources/roadmap.md
@@ -1,36 +1,92 @@
# Roadmap
-Currently, EmbodiChain is under active development. Our roadmap includes the following planned features and enhancements:
-
-- Simulation:
- - Rendering:
- - Improve ray-tracing backend performance and fix some konwn issues.
- - Add a high performance Hybrid rendering backend for better visual quality and speed trade-off.
- - Support a more efficient real-time denoiser.
- - Add a new rasterization backend for basic rendering tasks.
- - Physics:
- - Improve GPU physics throughput.
- - We are working on research and development of next-generation physics backend, supporting high-accuracy simulation, differentiable dynamics, and neural physical models for end-to-end AI integration.
- - Sensors:
- - Add more physical sensors (eg, force sensor) with examples.
- - Motion Generation:
- - Add more advanced motion generation methods with examples.
- - Useful Tools:
- - We are working on USD support for EmbodiChain to enable better asset management and interoperability.
- - Robots Integration:
- - Add support for more robot models (eg: LeRobot, Unitree H1/G1, etc).
-
-- Data Pipeline Coming Soon:
- - We will release a Real2Sim pipeline, which enables automatic data generation and scaling from real-world seeding priors.
- - We will release an agentic skill generation framework for automated expert trajectory generation.
- - Add assets and scenes generator and the integration with data pipeline.
-
-- Models & Training Infrastructure Coming Soon:
- - We will release a modular VLA framework for fast prototyping and training of embodied agents.
- - Add online data streaming pipeline for model training.
-
-- Embodied Tasks Coming Soon:
- - Add more benchmark tasks for EmbodiChain.
- - Add more tasks with reinforcement learning support.
- - Add a set of manipulation tasks for demonstration of data generation pipeline.
-
\ No newline at end of file
+EmbodiChain is in alpha and under active development. This roadmap summarizes
+the main areas we are improving and the capabilities planned for upcoming
+releases.
+
+The roadmap is organized by product area so new work can be added without
+changing the whole page. Each item should be short, user-facing, and grouped
+under the area it improves.
+
+## Status Legend
+
+| Marker | Status | Meaning |
+| --- | --- | --- |
+| 🚧 | In progress | Work is actively being designed, implemented, or validated. |
+| 📌 | Planned | Work is on the project roadmap but not yet released. |
+| 🔬 | Research | Work is exploratory and may change as the technical approach matures. |
+
+## Simulation
+
+### Rendering
+
+| Status | Planned capability |
+| --- | --- |
+| 🚧 | Support a more efficient real-time denoiser. |
+| 🔬 | Add 3DGS support for rendering and data generation. |
+
+### Physics
+
+| Status | Planned capability |
+| --- | --- |
+| 🔬 | Develop a next-generation physics backend with high-accuracy simulation, differentiable dynamics, and neural physical models for end-to-end AI integration. |
+
+### Sensors
+
+| Status | Planned capability |
+| --- | --- |
+| 📌 | Add more physical sensor models, such as force sensors, with runnable examples. |
+
+### Motion Generation
+
+| Status | Planned capability |
+| --- | --- |
+| 📌 | Add more advanced motion generation methods with examples. |
+
+### Robot Integration
+
+| Status | Planned capability |
+| --- | --- |
+| 📌 | Add support for more robot models, including LeRobot and Unitree H1/G1. |
+
+## Data Pipeline
+
+| Status | Planned capability |
+| --- | --- |
+| 📌 | Release a Real2Sim pipeline for automatic data generation and scaling from real-world seeding priors. |
+| 📌 | Release an agentic skill generation framework for automated expert trajectory generation. |
+| 📌 | Release a sim-ready asset and scene-layout generation framework for fast environment prototyping. |
+
+## Models and Training Infrastructure
+
+| Status | Planned capability |
+| --- | --- |
+| 📌 | Release a modular VLA framework for fast prototyping and training of embodied agents. |
+
+## Embodied Tasks
+
+| Status | Planned capability |
+| --- | --- |
+| 📌 | Add more benchmark tasks for EmbodiChain. |
+| 📌 | Add more tasks with reinforcement learning support. |
+| 📌 | Add manipulation tasks that demonstrate the data generation pipeline. |
+
+## Extending This Roadmap
+
+When adding roadmap items:
+
+- Add the item under the closest existing area before creating a new section.
+- Use one row per user-facing capability.
+- Keep status markers limited to the status legend above unless the legend is
+ updated at the same time.
+- Prefer concrete outcomes over implementation details.
+
+New sections should follow this template:
+
+```md
+## Area Name
+
+| Status | Planned capability |
+| --- | --- |
+| 📌 | Describe the capability and the user-facing outcome. |
+```
diff --git a/docs/source/resources/task/index.rst b/docs/source/resources/task/index.rst
index 998f6614..1c65e7e1 100644
--- a/docs/source/resources/task/index.rst
+++ b/docs/source/resources/task/index.rst
@@ -6,6 +6,5 @@ Supported Tasks
.. toctree::
:maxdepth: 1
- Push Cube
Pour Water
diff --git a/docs/source/tutorial/atomic_actions.rst b/docs/source/tutorial/atomic_actions.rst
new file mode 100644
index 00000000..10b8e97c
--- /dev/null
+++ b/docs/source/tutorial/atomic_actions.rst
@@ -0,0 +1,170 @@
+.. _tutorial_atomic_actions:
+
+Atomic Actions
+==============
+
+EmbodiChain's **atomic action** layer provides a high-level, composable interface for common
+manipulation primitives such as *move*, *pick up*, and *place*. Each action encapsulates the
+full planning pipeline — grasp-pose estimation, IK, trajectory generation, and gripper
+interpolation — behind a single ``execute()`` call, making it straightforward to chain
+multiple actions together into complex robot behaviours.
+
+Key Features
+------------
+
+- **Semantic-aware execution** — actions accept either a raw pose tensor or an
+ ``ObjectSemantics`` descriptor that bundles affordance data (grasp poses, interaction
+ points) with the simulation entity.
+- **Three built-in primitives** — ``MoveAction``, ``PickUpAction``, and ``PlaceAction``
+ cover the most common tabletop manipulation workflows out of the box.
+ See the :ref:`supported_atomic_actions` table for configs and target types.
+- **Extensible registry** — custom actions can be registered globally with
+ ``register_action`` and discovered by the engine at runtime.
+- **Engine orchestration** — ``AtomicActionEngine`` sequences multiple actions,
+ threads ``start_qpos`` from one action to the next, and returns a single concatenated
+ trajectory ready to replay in the simulator.
+
+For the full design overview, architecture diagram, and extension guide see
+:doc:`/overview/sim/atomic_actions`.
+
+The Code
+--------
+
+The tutorial corresponds to the ``atomic_actions.py`` script in the ``scripts/tutorials/sim``
+directory.
+
+.. dropdown:: Code for atomic_actions.py
+ :icon: code
+
+ .. literalinclude:: ../../../scripts/tutorials/sim/atomic_actions.py
+ :language: python
+ :linenos:
+
+Typical Usage
+-------------
+
+Setting up the engine
+~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ import torch
+ from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg
+ from embodichain.lab.sim.atomic_actions import (
+ AtomicActionEngine,
+ PickUpActionCfg,
+ PlaceActionCfg,
+ MoveActionCfg,
+ )
+
+ motion_gen = MotionGenerator(cfg=MotionGenCfg(...))
+
+ hand_open = torch.tensor([0.00, 0.00], dtype=torch.float32, device=device)
+ hand_close = torch.tensor([0.025, 0.025], dtype=torch.float32, device=device)
+
+ pickup_cfg = PickUpActionCfg(
+ hand_open_qpos=hand_open,
+ hand_close_qpos=hand_close,
+ control_part="arm",
+ hand_control_part="hand",
+ approach_direction=torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device=device),
+ pre_grasp_distance=0.15,
+ lift_height=0.15,
+ )
+ place_cfg = PlaceActionCfg(
+ hand_open_qpos=hand_open,
+ hand_close_qpos=hand_close,
+ control_part="arm",
+ hand_control_part="hand",
+ lift_height=0.15,
+ )
+ move_cfg = MoveActionCfg(control_part="arm")
+
+ engine = AtomicActionEngine(
+ motion_generator=motion_gen,
+ actions_cfg_list=[pickup_cfg, place_cfg, move_cfg],
+ )
+
+Defining object semantics
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ from embodichain.lab.sim.atomic_actions import (
+ ObjectSemantics,
+ AntipodalAffordance,
+ )
+ from embodichain.toolkits.graspkit.pg_grasp import GraspGeneratorCfg, AntipodalSamplerCfg
+ from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import GripperCollisionCfg
+
+ affordance = AntipodalAffordance(
+ object_label="mug",
+ force_reannotate=False,
+ custom_config={
+ "gripper_collision_cfg": GripperCollisionCfg(
+ max_open_length=0.088, finger_length=0.078, point_sample_dense=0.012
+ ),
+ "generator_cfg": GraspGeneratorCfg(
+ antipodal_sampler_cfg=AntipodalSamplerCfg(
+ n_sample=20000, max_length=0.088, min_length=0.003
+ )
+ ),
+ },
+ )
+
+ semantics = ObjectSemantics(
+ label="mug",
+ geometry={
+ "mesh_vertices": mug.get_vertices(env_ids=[0], scale=True)[0],
+ "mesh_triangles": mug.get_triangles(env_ids=[0])[0],
+ },
+ affordance=affordance,
+ entity=mug, # required so the action can query the live object pose
+ )
+
+Executing a pick-place-move sequence
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ place_xpos = ... # torch.Tensor [4, 4] — target placement pose
+ rest_xpos = ... # torch.Tensor [4, 4] — resting pose after placing
+
+ is_success, trajectory = engine.execute_static(
+ target_list=[semantics, place_xpos, rest_xpos]
+ )
+ # trajectory: [n_envs, n_waypoints, robot_dof]
+
+ for i in range(trajectory.shape[1]):
+ robot.set_qpos(trajectory[:, i])
+ sim.update(step=4)
+
+Registering custom actions
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ from embodichain.lab.sim.atomic_actions import AtomicAction, ActionCfg, register_action
+
+ class PushAction(AtomicAction):
+ def execute(self, target, start_qpos=None, **kwargs):
+ # ... your planning logic ...
+ return is_success, trajectory, joint_ids
+
+ def validate(self, target, start_qpos=None, **kwargs):
+ return True # quick feasibility check
+
+ register_action("push", PushAction)
+
+Notes & Best Practices
+----------------------
+
+- ``PickUpAction`` expects an ``AntipodalAffordance`` with valid mesh data
+ (``mesh_vertices`` / ``mesh_triangles``) so the grasp generator can annotate the object.
+ Set ``force_reannotate=False`` (the default) to reuse cached annotations across episodes.
+- ``ObjectSemantics.entity`` must be set when using semantic targets so the action can read
+ the object's current world pose at planning time.
+- For static (non-physics) playback, iterate over ``trajectory[:, i]`` and call
+ ``robot.set_qpos`` directly; for physics-enabled playback, feed waypoints through your
+ controller or gym wrapper instead.
+- To add a new action type, see :doc:`/overview/sim/atomic_actions`.
diff --git a/docs/source/tutorial/basic_env.rst b/docs/source/tutorial/basic_env.rst
index 6de0c48b..443fbe97 100644
--- a/docs/source/tutorial/basic_env.rst
+++ b/docs/source/tutorial/basic_env.rst
@@ -182,3 +182,14 @@ This tutorial showcases several important features of EmbodiChain environments:
4. **Custom Objects**: Adding and manipulating scene objects
5. **Flexible Actions**: Customizable action spaces and execution methods
6. **Extensible Observations**: Adding task-specific observation data
+
+.. tip::
+ **Using an AI coding agent?** Once you're ready to create your own task environment, use the **/add-task-env** skill to scaffold the file with the correct structure, ``@register_env`` decorator, base class methods, and test stub. Use **/add-test** to write tests and **/pre-commit-check** to verify everything passes CI before committing.
+
+Next Steps
+~~~~~~~~~~
+
+- :doc:`modular_env` — Build advanced config-driven environments with ``EmbodiedEnv``
+- :doc:`rl` — Train RL agents with PPO or GRPO
+- :doc:`/overview/gym/env` — Full environment architecture and manager reference
+- :doc:`/guides/custom_functors` — Write custom observation, reward, and event functors
diff --git a/docs/source/tutorial/create_scene.rst b/docs/source/tutorial/create_scene.rst
index 244bd932..da13d5ec 100644
--- a/docs/source/tutorial/create_scene.rst
+++ b/docs/source/tutorial/create_scene.rst
@@ -89,3 +89,12 @@ You can also pass arguments to customize the simulation. For example, to run in
python scripts/tutorials/sim/create_scene.py --headless --num_envs --device
Now that we have a basic understanding of how to create a scene, let's move on to more advanced topics.
+
+Next Steps
+~~~~~~~~~~
+
+- :doc:`create_softbody` — Add deformable bodies to your scene
+- :doc:`robot` — Load and control a robot
+- :doc:`sensor` — Add cameras and capture sensor data
+- :doc:`basic_env` — Create your first Gymnasium environment
+- :doc:`/overview/sim/sim_manager` — Full SimulationManager API reference
diff --git a/docs/source/tutorial/data_generation.rst b/docs/source/tutorial/data_generation.rst
new file mode 100644
index 00000000..ca994f3d
--- /dev/null
+++ b/docs/source/tutorial/data_generation.rst
@@ -0,0 +1,189 @@
+.. _tutorial_data_generation:
+
+Data Generation
+===============
+
+.. currentmodule:: embodichain.lab.gym
+
+This tutorial shows how to generate synthetic expert demonstration datasets using EmbodiChain's built-in environment rollout and dataset manager. You will learn how to configure LeRobot recording in ``gym_config.json``, how ``run_env.py`` builds an environment from configuration files, and how completed episodes are automatically saved to disk.
+
+Overview
+~~~~~~~~
+
+EmbodiChain provides a built-in data generation workflow for imitation-learning and manipulation tasks:
+
+- **Gym Configuration**: Describes the scene, robot, sensors, randomization events, observations, dataset recorder, and rollout settings.
+- **Action Configuration**: Describes the task-specific expert action graph for tasks that use the action bank.
+- **Environment Rollout**: Builds the environment directly from configuration files and executes offline generation.
+- **Expert Policy**: Each task provides ``create_demo_action_list()`` or another scripted policy entry to generate expert actions.
+- **Dataset Manager**: Records observation-action pairs during ``env.step()``.
+- **LeRobotRecorder**: Converts completed episodes into LeRobot-compatible datasets, with optional video export.
+
+What This Tutorial Records
+--------------------------
+
+This page documents the full path from task configuration to saved dataset:
+
+1. Prepare a task ``gym_config.json``.
+2. Prepare an ``action_config.json`` if the task uses the action bank.
+3. Launch the environment rollout with ``run-env``.
+4. Let the dataset manager automatically save completed episodes.
+
+Example Task
+------------
+
+As a concrete example, this tutorial uses a real action-bank task shipped in the repository:
+
+- ``configs/gym/pour_water/gym_config.json`` defines the simulation scene and dataset recording behavior.
+- ``configs/gym/pour_water/action_config.json`` defines the action-bank graph used to solve the task.
+
+The Code
+~~~~~~~~
+
+The tutorial corresponds to the ``run_env.py`` script in ``embodichain/lab/scripts``.
+
+.. dropdown:: Code for run_env.py
+ :icon: code
+
+ .. literalinclude:: ../../../embodichain/lab/scripts/run_env.py
+ :language: python
+ :linenos:
+
+
+The Code Explained
+~~~~~~~~~~~~~~~~~~
+
+The rollout script builds the environment from configuration, generates expert trajectories, executes them step by step, and relies on the dataset manager to auto-save valid episodes.
+
+Step 1: Prepare the Task Configuration
+--------------------------------------
+
+The first input to the pipeline is the task ``gym_config.json``. In the example below, the same file contains rollout settings, scene randomization, observations, dataset recording, and robot or sensor definitions.
+
+The rollout settings include the episode count:
+
+.. literalinclude:: ../../../configs/gym/pour_water/gym_config.json
+ :language: json
+ :lines: 2-4
+
+The dataset-related part looks like this:
+
+.. literalinclude:: ../../../configs/gym/pour_water/gym_config.json
+ :language: json
+ :lines: 261-281
+
+Important parameters are:
+
+- **max_episodes**: Number of rollout episodes generated by ``run_env.py``.
+- **max_episode_steps**: Maximum number of environment steps per episode.
+- **dataset.lerobot.params.robot_meta**: Robot metadata such as robot type and control frequency.
+- **dataset.lerobot.params.instruction**: Task language instruction stored together with the dataset.
+- **dataset.lerobot.params.extra**: Additional metadata such as scene type and task description.
+- **dataset.lerobot.params.use_videos**: Whether camera observations should be stored as videos.
+- **env.control_parts**: Controlled robot parts in the environment.
+
+
+In the current implementation, ``LeRobotRecorder`` stores robot state and action features such as ``observation.qpos``, ``observation.qvel``, ``observation.qf``, ``action``, and camera images when sensors are present.
+
+Step 2: Prepare the Action Configuration
+----------------------------------------
+
+For tasks that use the action bank, the second input is ``action_config.json``. This file defines the expert action graph consumed by ``create_demo_action_list()``. In the example below, the file is organized around ``scope``, ``node``, ``edge``, and ``sync``.
+
+.. dropdown:: Action bank structure in the example task Pour_Water
+ :icon: code
+
+ **Scope Configuration**
+
+ .. literalinclude:: ../../../configs/gym/pour_water/action_config.json
+ :language: json
+ :lines: 2-57
+
+ **Node Configuration**
+
+ .. literalinclude:: ../../../configs/gym/pour_water/action_config.json
+ :language: json
+ :lines: 96-177
+
+ **Edge Configuration**
+
+ .. literalinclude:: ../../../configs/gym/pour_water/action_config.json
+ :language: json
+ :lines: 763-790
+
+ **Synchronization**
+
+ .. literalinclude:: ../../../configs/gym/pour_water/action_config.json
+ :language: json
+ :lines: 906-932
+
+This structure defines the expert rollout as follows:
+
+- **Scope**: Defines controllable sub-graphs such as ``right_arm``, ``left_arm``, ``right_eef``, and ``left_eef``.
+- **Node**: Defines key poses, targets computed from object affordances, and IK-generated joint targets.
+- **Edge**: Defines executable transitions between nodes, including duration and execution function.
+- **Sync**: Defines execution order rules between independently configured sub-actions.
+
+Note: Action bank is not the only way to generate demonstrations. Depending on the task design, trajectories can also be produced by other scripted generation methods.
+
+Step 3: Launch the Environment Rollout
+--------------------------------------
+
+The rollout script parses command-line arguments, loads ``gym_config.json`` and ``action_config.json``, converts them into environment configuration objects, creates the environment instance, and then runs offline rollout for ``max_episodes`` episodes:
+
+.. literalinclude:: ../../../embodichain/lab/scripts/run_env.py
+ :language: python
+ :start-at: def cli():
+ :end-at: main(args, env, gym_config)
+
+Each rollout internally calls ``create_demo_action_list()``, validates the returned sequence, executes actions with ``env.step(action)``, and discards invalid rollouts by resetting with ``save_data=False``.
+
+The recommended CLI entrypoint is:
+
+.. code-block:: bash
+
+ python -m embodichain run-env \
+ --gym_config configs/gym/pour_water/gym_config.json \
+ --action_config configs/gym/pour_water/action_config.json \
+ --headless
+
+For interactive inspection, you can use preview mode: replace ``--headless`` with ``--preview``.
+When ``--preview`` is enabled, the script opens the environment in an interactive debugging mode. This mode is for inspection and does not save datasets.
+
+
+Useful CLI arguments:
+
+- **--gym_config**: Path to the task JSON configuration.
+- **--action_config**: Path to the action-bank configuration.
+- **--num_envs**: Number of environments to run in parallel.
+- **--device**: Simulation device, such as ``cpu`` or ``cuda``.
+- **--headless**: Run without GUI for faster generation.
+- **--enable_rt**: Enable ray tracing for higher-quality visual observations.
+- **--preview**: Launch the environment in interactive preview mode.
+- **--filter_dataset_saving**: Disable dataset saving for debugging.
+
+For the complete CLI argument list, see :doc:`CLI Reference `.
+
+Outputs
+~~~~~~~
+
+After successful execution, completed episodes are saved under the configured dataset root. A LeRobot dataset typically contains:
+
+If no explicit save path is provided and ``EMBODICHAIN_DATASET_ROOT`` is not set, ``LeRobotRecorder`` uses ``~/.cache/embodichain_datasets`` as the default dataset root.
+
+- **data/**: Recorded action and state data.
+- **videos/**: Camera observations saved as videos when ``use_videos=True``.
+- **meta/**: Dataset metadata such as task information and robot description.
+
+Dataset folders are automatically numbered, which makes it easy to run repeated generations without overwriting previous results.
+
+In a practical workflow, the output of this stage is the synthesized dataset itself. Later training scripts typically consume these saved LeRobot episodes instead of regenerating trajectories each time.
+
+Best Practices
+~~~~~~~~~~~~~~
+
+- **Keep the config pair together**: Version ``gym_config.json`` and ``action_config.json`` together for action-bank tasks.
+- **Use valid scripted policies**: Make sure ``create_demo_action_list()`` returns executable trajectories for the current scene.
+- **Use ``--headless`` for throughput**: Disable the GUI when generating large datasets.
+- **Use ``--preview`` and ``--filter_dataset_saving`` for debugging**: Inspect task logic without writing datasets.
+- **Discard invalid rollouts**: Keep the default validation logic so failed trajectories are not saved.
diff --git a/docs/source/tutorial/gizmo.rst b/docs/source/tutorial/gizmo.rst
index b0d39b2c..6f2a5b7a 100644
--- a/docs/source/tutorial/gizmo.rst
+++ b/docs/source/tutorial/gizmo.rst
@@ -213,7 +213,7 @@ Command-line options:
- ``--device cpu|cuda``: Choose simulation device
- ``--num_envs N``: Number of parallel environments
- ``--headless``: Run without GUI for automated testing
-- ``--enable_rt``: Enable ray tracing for better visuals
+- ``--renderer``: Enable ray tracing for better visuals
Once running:
diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst
index ef6efe79..33b95a8b 100644
--- a/docs/source/tutorial/index.rst
+++ b/docs/source/tutorial/index.rst
@@ -1,6 +1,36 @@
Tutorials
=========
+These tutorials walk you through EmbodiChain step by step, from creating your first simulation scene to training RL agents. Each tutorial includes a complete runnable script and a line-by-line explanation.
+
+Suggested Learning Path
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Follow the tutorials in this order for the best learning experience:
+
+**Phase 1: Simulation Basics**
+
+1. :doc:`create_scene` — Set up a simulation, add objects, and run the render loop. **Start here.**
+2. :doc:`create_softbody` and :doc:`create_cloth` — Add deformable bodies to your scenes.
+3. :doc:`rigid_object_group` — Manage collections of rigid objects efficiently.
+4. :doc:`robot` — Load and control a robot in simulation.
+5. :doc:`sensor` — Add cameras and capture RGB/depth/segmentation data.
+6. :doc:`solver` — Configure IK solvers for end-effector control.
+7. :doc:`motion_gen` — Generate smooth trajectories with motion planners.
+8. :doc:`atomic_actions` — Use built-in action primitives (pick, place, move).
+9. :doc:`gizmo` — Interactively control robots with on-screen gizmos.
+
+**Phase 2: Environments**
+
+10. :doc:`basic_env` — Create a simple Gymnasium environment with ``BaseEnv``. Prerequisite: Phase 1 basics.
+11. :doc:`modular_env` — Build a config-driven environment with ``EmbodiedEnv``, managers, and randomization. Prerequisite: :doc:`basic_env`.
+12. :doc:`data_generation` — Generate expert demonstration datasets for imitation learning. Prerequisite: :doc:`modular_env`.
+13. :doc:`rl` — Train RL agents with PPO or GRPO. Prerequisite: :doc:`basic_env`.
+
+**Phase 3: Extending the Framework**
+
+14. :doc:`add_robot` — Add a new robot model to EmbodiChain.
+
.. toctree::
:maxdepth: 1
:hidden:
@@ -14,8 +44,9 @@ Tutorials
solver
sensor
motion_gen
+ atomic_actions
gizmo
basic_env
modular_env
+ data_generation
rl
-
diff --git a/docs/source/tutorial/modular_env.rst b/docs/source/tutorial/modular_env.rst
index 356a7ac4..eef801c3 100644
--- a/docs/source/tutorial/modular_env.rst
+++ b/docs/source/tutorial/modular_env.rst
@@ -64,7 +64,7 @@ The ``randomize_table_mat`` event varies visual appearance:
- **Mode**: ``"interval"`` - triggers every 10 steps
- **Features**: Random textures from COCO dataset and base color variations
-for more randomization events, please refer
+For more randomization events, please refer to :doc:`/overview/gym/event_functors`.
Observation Configuration
-------------------------
@@ -235,3 +235,9 @@ This tutorial showcases the most advanced features of EmbodiChain environments:
This tutorial demonstrates the full power of EmbodiChain's modular environment system, providing the foundation for creating sophisticated robotic learning scenarios.
+
+.. tip::
+ **Using an AI coding agent?** These skills can help you build on this tutorial:
+
+ - **/add-task-env** — Scaffold a new task environment with the correct file structure, ``@register_env`` decorator, base class methods, ``__init__.py`` update, and test stub.
+ - **/add-functor** — Add observation, reward, event, or randomization functors with the correct signature and module placement.
\ No newline at end of file
diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst
index 28054648..db1c7ab1 100644
--- a/docs/source/tutorial/rl.rst
+++ b/docs/source/tutorial/rl.rst
@@ -420,3 +420,11 @@ Best Practices
- **Checkpoints**: Regular checkpoints are saved to ``outputs//checkpoints/``. Use these to resume training or evaluate policies.
+See Also
+--------
+
+- :doc:`/overview/rl/index` — RL module architecture and component reference
+- :doc:`/overview/gym/env` — EmbodiedEnv configuration and Action Manager
+- :doc:`basic_env` — Creating basic Gymnasium environments
+- :doc:`modular_env` — Advanced modular environments with managers
+- :doc:`/resources/task/index` — List of available RL task environments
diff --git a/docs/source/tutorial/robot.rst b/docs/source/tutorial/robot.rst
index 8312ad27..c3a54ab5 100644
--- a/docs/source/tutorial/robot.rst
+++ b/docs/source/tutorial/robot.rst
@@ -116,7 +116,7 @@ You can customize the simulation with various command-line options:
python scripts/tutorials/sim/create_robot.py --headless
# Enable ray tracing rendering
- python scripts/tutorials/sim/create_robot.py --enable_rt
+ python scripts/tutorials/sim/create_robot.py --renderer
The simulation will show the robot moving through different poses, demonstrating basic joint control capabilities.
diff --git a/docs/source/tutorial/sensor.rst b/docs/source/tutorial/sensor.rst
index 1d5c4dc9..9119d1ea 100644
--- a/docs/source/tutorial/sensor.rst
+++ b/docs/source/tutorial/sensor.rst
@@ -89,7 +89,7 @@ You can customize the simulation with the following command-line options:
python scripts/tutorials/sim/create_sensor.py --headless
# Enable ray tracing rendering
- python scripts/tutorials/sim/create_sensor.py --enable_rt
+ python scripts/tutorials/sim/create_sensor.py --renderer
# Attach the camera to the robot end-effector
python scripts/tutorials/sim/create_sensor.py --attach_sensor
diff --git a/docs/source/tutorial/solver.rst b/docs/source/tutorial/solver.rst
index b3c95807..61300096 100644
--- a/docs/source/tutorial/solver.rst
+++ b/docs/source/tutorial/solver.rst
@@ -95,7 +95,7 @@ API Reference
"""Compute the Jacobian matrix for the given joint positions."""
- **set_ik_nearst_weight**: Set weights for IK nearest neighbor search.
-- **set_position_limits / get_position_limits**: Set or get joint position limits.
+- **set_qpos_limits / get_qpos_limits**: Set or get joint position limits.
- **set_tcp / get_tcp**: Set or get the tool center point (TCP) transformation.
Configuration
diff --git a/docs/sync_readme.py b/docs/sync_readme.py
index a3198b6e..67620ef2 100644
--- a/docs/sync_readme.py
+++ b/docs/sync_readme.py
@@ -3,6 +3,7 @@
Idempotent copy. Exit code 0 on success.
"""
+
import shutil
from pathlib import Path
import sys
diff --git a/embodichain/agents/datasets/online_data.py b/embodichain/agents/datasets/online_data.py
index ac359020..2f2a117f 100644
--- a/embodichain/agents/datasets/online_data.py
+++ b/embodichain/agents/datasets/online_data.py
@@ -24,7 +24,6 @@
from embodichain.agents.engine.data import OnlineDataEngine
from embodichain.agents.datasets.sampler import ChunkSizeSampler
-
__all__ = [
"OnlineDataset",
]
diff --git a/embodichain/agents/datasets/sampler.py b/embodichain/agents/datasets/sampler.py
index 464af009..70385484 100644
--- a/embodichain/agents/datasets/sampler.py
+++ b/embodichain/agents/datasets/sampler.py
@@ -20,7 +20,6 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Union
-
__all__ = [
"ChunkSizeSampler",
"UniformChunkSampler",
diff --git a/embodichain/agents/engine/data.py b/embodichain/agents/engine/data.py
index f25987ab..a836a013 100644
--- a/embodichain/agents/engine/data.py
+++ b/embodichain/agents/engine/data.py
@@ -25,6 +25,7 @@
from tensordict import TensorDict
from tqdm import tqdm
+from embodichain.lab.sim.cfg import RenderCfg
from embodichain.utils.logger import log_info, log_error
from embodichain.utils import configclass
@@ -60,6 +61,31 @@ class OnlineDataEngineCfg:
amortising the cost of environment simulation over many training steps.
"""
+ language_cfg: Union[dict, None] = None
+ """Language configuration for VLA training.
+
+ If provided, the shared buffer will include hierarchical language data fields
+ and the simulation subprocess will collect language descriptions during rollouts.
+
+ The configuration should include:
+ - mode: Storage mode ('tokens', 'embeddings', 'hybrid')
+ - hierarchy_levels: List of hierarchy levels ('task', 'subtask', 'primitive')
+ - max_tokens: Maximum sequence length per instruction
+ - tokenizer: Tokenizer identifier
+ - language_source: Source of language ('env', 'file', 'llm', 'template')
+ - language_config_path: Path to language descriptions (if source='file')
+
+ Example:
+ language_cfg = {
+ "mode": "tokens",
+ "hierarchy_levels": ["task", "subtask", "primitive"],
+ "max_tokens": 512,
+ "tokenizer": "gpt2",
+ "language_source": "file",
+ "language_config_path": "config/language/tasks.yaml",
+ }
+ """
+
# ---------------------------------------------------------------------------
# Subprocess entry point (module-level so it can be pickled by multiprocessing)
@@ -109,10 +135,19 @@ def _sim_worker_fn(
env_cfg = config_to_cfg(gym_config, manager_modules=DEFAULT_MANAGER_MODULES)
env_cfg.filter_dataset_saving = True
env_cfg.init_rollout_buffer = False
+
+ # Add language configuration if provided
+ if cfg.language_cfg is not None:
+ env_cfg.language = cfg.language_cfg
+ log_info(
+ f"[Simulation Process] Language configuration added: {cfg.language_cfg.get('mode', 'tokens')}, "
+ f"hierarchy={cfg.language_cfg.get('hierarchy_levels', ['task', 'subtask', 'primitive'])}"
+ )
+
env_cfg.sim_cfg = SimulationManagerCfg(
headless=gym_config.get("headless", True),
sim_device=gym_config.get("device", "cpu"),
- enable_rt=gym_config.get("enable_rt", True),
+ render_cfg=RenderCfg(renderer=gym_config.get("renderer", "hybrid")),
gpu_id=gym_config.get("gpu_id", 0),
)
@@ -363,6 +398,9 @@ def _create_buffer(self) -> TensorDict:
placed in CPU shared memory so it can be safely accessed from both the
main process and the simulation subprocess.
+ If language configuration is provided, the buffer will also include
+ hierarchical language data fields for VLA training.
+
Returns:
TensorDict in shared memory.
"""
@@ -379,6 +417,7 @@ def _create_buffer(self) -> TensorDict:
batch_size=self.cfg.buffer_size,
max_episode_steps=max_episode_steps,
state_dim=self.cfg.state_dim,
+ language_cfg=self.cfg.language_cfg,
)
if shared_td.device.type == "cpu":
diff --git a/embodichain/agents/rl/models/mlp.py b/embodichain/agents/rl/models/mlp.py
index f788dfed..459e08e3 100644
--- a/embodichain/agents/rl/models/mlp.py
+++ b/embodichain/agents/rl/models/mlp.py
@@ -22,7 +22,6 @@
import torch
import torch.nn as nn
-
ActivationName = Union[str, None]
diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py
index fa1f5948..0c74843a 100644
--- a/embodichain/agents/rl/train.py
+++ b/embodichain/agents/rl/train.py
@@ -37,6 +37,7 @@
from embodichain.utils.utility import load_json
from embodichain.utils.module_utils import find_function_from_modules
from embodichain.lab.sim import SimulationManagerCfg
+from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.gym.envs.managers.cfg import EventCfg
@@ -113,7 +114,7 @@ def train_from_config(config_path: str, distributed: bool | None = None):
save_freq = int(trainer_cfg.get("save_freq", 50000))
num_eval_episodes = int(trainer_cfg.get("num_eval_episodes", 5))
headless = bool(trainer_cfg.get("headless", True))
- enable_rt = bool(trainer_cfg.get("enable_rt", False))
+ renderer = trainer_cfg.get("renderer", "hybrid")
gpu_id = int(trainer_cfg.get("gpu_id", 0))
num_envs = trainer_cfg.get("num_envs", None)
wandb_project_name = trainer_cfg.get("wandb_project_name", "embodichain-generic")
@@ -205,13 +206,12 @@ def train_from_config(config_path: str, distributed: bool | None = None):
else:
gym_env_cfg.sim_cfg.sim_device = torch.device("cpu")
gym_env_cfg.sim_cfg.headless = headless
- gym_env_cfg.sim_cfg.enable_rt = enable_rt
- gym_env_cfg.sim_cfg.gpu_id = local_rank if distributed else gpu_id
+ gym_env_cfg.sim_cfg.render_cfg = RenderCfg(renderer=renderer)
+ gym_env_cfg.sim_cfg.gpu_id = gpu_id
- if rank == 0:
- logger.log_info(
- f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, enable_rt={gym_env_cfg.sim_cfg.enable_rt}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
- )
+ logger.log_info(
+ f"Loaded gym_config from {gym_config_path} (env_id={gym_config_data['id']}, num_envs={gym_env_cfg.num_envs}, headless={gym_env_cfg.sim_cfg.headless}, renderer={gym_env_cfg.sim_cfg.render_cfg.renderer}, sim_device={gym_env_cfg.sim_cfg.sim_device})"
+ )
env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
sample_obs, _ = env.reset()
diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py
index 56ea0db2..93d01acf 100644
--- a/embodichain/agents/rl/utils/trainer.py
+++ b/embodichain/agents/rl/utils/trainer.py
@@ -16,6 +16,7 @@
from __future__ import annotations
+from typing import Any, Dict
import time
import numpy as np
import torch
@@ -85,6 +86,11 @@ def __init__(
self.start_time = time.time()
self.ret_window = deque(maxlen=100)
self.len_window = deque(maxlen=100)
+ self.train_history: list[dict[str, float]] = []
+ self.eval_history: list[dict[str, float]] = []
+ self.last_eval_metrics: dict[str, float] = {}
+ self.last_train_metrics: dict[str, float] = {}
+ self.latest_checkpoint_path: str | None = None
num_envs = getattr(self.env, "num_envs", None)
if num_envs is None:
raise RuntimeError("Env must expose num_envs for trainer statistics.")
@@ -146,9 +152,9 @@ def _pack_log_dict(self, prefix: str, data: dict) -> dict:
continue
return out
- def train(self, total_timesteps: int):
+ def train(self, total_timesteps: int) -> Dict[str, Any]:
if self.rank == 0:
- logger.log_info(f"Start training, total steps: {total_timesteps}")
+ print(f"Start training, total steps: {total_timesteps}")
while self.global_step < total_timesteps:
self._collect_rollout()
losses = self.algorithm.update(self.buffer.get(flatten=False))
@@ -161,6 +167,7 @@ def train(self, total_timesteps: int):
self._eval_once(num_episodes=self.num_eval_episodes)
if self.global_step % self.save_freq == 0:
self.save_checkpoint()
+ return self.get_summary()
@torch.no_grad()
def _collect_rollout(self):
@@ -197,9 +204,10 @@ def on_step(tensordict: TensorDict, info: dict):
if log_dict and self.use_wandb:
wandb.log(log_dict, step=self.global_step)
+ rollout = self.buffer.start_rollout()
rollout = self.collector.collect(
num_steps=self.buffer_size,
- rollout=self.buffer.start_rollout(),
+ rollout=rollout,
on_step_callback=on_step,
)
self.buffer.add(rollout)
@@ -278,13 +286,23 @@ def _sync_episode_stats(self) -> None:
self.len_window.extend(all_len[start:])
def _log_train(self, losses: Dict[str, float]):
- if self.rank != 0:
- return
+ elapsed = max(1e-6, time.time() - self.start_time)
+ sps = self.global_step / elapsed
+ avgR = np.mean(self.ret_window) if len(self.ret_window) > 0 else float("nan")
+ avgL = np.mean(self.len_window) if len(self.len_window) > 0 else float("nan")
+ history_entry = {
+ "global_step": float(self.global_step),
+ "charts/SPS": float(sps),
+ "charts/episode_reward_avg_100": float(avgR),
+ "charts/episode_length_avg_100": float(avgL),
+ }
+ history_entry.update({f"train/{k}": float(v) for k, v in losses.items()})
+ self.train_history.append(history_entry)
+ self.last_train_metrics = history_entry
+
if self.writer:
for k, v in losses.items():
self.writer.add_scalar(f"train/{k}", v, self.global_step)
- elapsed = max(1e-6, time.time() - self.start_time)
- sps = self.global_step / elapsed
self.writer.add_scalar("charts/SPS", sps, self.global_step)
if len(self.ret_window) > 0:
self.writer.add_scalar(
@@ -298,26 +316,24 @@ def _log_train(self, losses: Dict[str, float]):
float(np.mean(self.len_window)),
self.global_step,
)
- # console
- sps = self.global_step / max(1e-6, time.time() - self.start_time)
- avgR = np.mean(self.ret_window) if len(self.ret_window) > 0 else float("nan")
- avgL = np.mean(self.len_window) if len(self.len_window) > 0 else float("nan")
- print(
- f"[train] step={self.global_step} sps={sps:.0f} avgReward(100)={avgR:.3f} avgLength(100)={avgL:.1f}"
- )
+ # console and external logging are rank-0 only in distributed mode.
+ if self.rank == 0:
+ print(
+ f"[train] step={self.global_step} sps={sps:.0f} avgReward(100)={avgR:.3f} avgLength(100)={avgL:.1f}"
+ )
- # wandb (mirror TB logs)
- if self.use_wandb:
- log_dict = {f"train/{k}": v for k, v in losses.items()}
- log_dict["charts/SPS"] = sps
- if not np.isnan(avgR):
- log_dict["charts/episode_reward_avg_100"] = float(avgR)
- if not np.isnan(avgL):
- log_dict["charts/episode_length_avg_100"] = float(avgL)
- wandb.log(log_dict, step=self.global_step)
+ # wandb (mirror TB logs)
+ if self.use_wandb:
+ log_dict = {f"train/{k}": v for k, v in losses.items()}
+ log_dict["charts/SPS"] = sps
+ if not np.isnan(avgR):
+ log_dict["charts/episode_reward_avg_100"] = float(avgR)
+ if not np.isnan(avgL):
+ log_dict["charts/episode_length_avg_100"] = float(avgL)
+ wandb.log(log_dict, step=self.global_step)
@torch.no_grad()
- def _eval_once(self, num_episodes: int = 5):
+ def _eval_once(self, num_episodes: int = 5) -> Dict[str, float]:
"""Run evaluation for specified number of episodes.
Each episode runs all parallel environments until completion, allowing
@@ -329,8 +345,11 @@ def _eval_once(self, num_episodes: int = 5):
self.policy.eval()
episode_returns = []
episode_lengths = []
+ episode_successes = []
+ metric_values: dict[str, list[float]] = {}
- self.eval_env.set_rollout_buffer(self.buffer.buffer)
+ # Evaluation does not consume the training rollout buffer; binding it here can
+ # overflow the shared RL buffer when eval episodes are longer than buffer_size.
for _ in range(num_episodes):
# Reset and initialize episode tracking
obs, _ = self.eval_env.reset()
@@ -372,6 +391,17 @@ def _eval_once(self, num_episodes: int = 5):
still_running = ~done_mask
cumulative_reward[still_running] += reward[still_running].float()
step_count[still_running] += 1
+ newly_done = done & (~done_mask)
+ if newly_done.any():
+ if isinstance(info, dict) and "success" in info:
+ successes = info["success"][newly_done].detach().cpu().tolist()
+ episode_successes.extend([float(v) for v in successes])
+ if isinstance(info, dict) and "metrics" in info:
+ for key, value in info["metrics"].items():
+ values = value[newly_done].detach().cpu().tolist()
+ metric_values.setdefault(key, []).extend(
+ [float(v) for v in values]
+ )
done_mask |= done
# Trigger evaluation events (e.g., video recording)
@@ -404,11 +434,44 @@ def _eval_once(self, num_episodes: int = 5):
self.writer.add_scalar(
"eval/avg_length", float(np.mean(episode_lengths)), self.global_step
)
+ if episode_successes:
+ self.writer.add_scalar(
+ "eval/success_rate",
+ float(np.mean(episode_successes)),
+ self.global_step,
+ )
- def save_checkpoint(self):
- if self.rank != 0:
- return
+ summary = {
+ "global_step": float(self.global_step),
+ "eval/avg_reward": (
+ float(np.mean(episode_returns)) if episode_returns else float("nan")
+ ),
+ "eval/avg_length": (
+ float(np.mean(episode_lengths)) if episode_lengths else float("nan")
+ ),
+ "eval/success_rate": (
+ float(np.mean(episode_successes)) if episode_successes else float("nan")
+ ),
+ }
+ for key, values in metric_values.items():
+ if values:
+ summary[f"eval/metrics/{key}"] = float(np.mean(values))
+ self.eval_history.append(summary)
+ self.last_eval_metrics = summary
+ if self.rank == 0 and self.use_wandb:
+ log_dict = {
+ key: value
+ for key, value in summary.items()
+ if key != "global_step" and not np.isnan(value)
+ }
+ if log_dict:
+ wandb.log(log_dict, step=self.global_step)
+ return summary
+
+ def save_checkpoint(self) -> str | None:
# minimal model-only checkpoint; trainer/algorithm states can be added
+ if self.rank != 0:
+ return None
path = f"{self.checkpoint_dir}/{self.exp_name}_step_{self.global_step}.pt"
policy_state = (
self.policy.module.state_dict()
@@ -422,4 +485,19 @@ def save_checkpoint(self):
},
path,
)
+ self.latest_checkpoint_path = path
print(f"Checkpoint saved: {path}")
+ return path
+
+ def get_summary(self) -> Dict[str, Any]:
+ elapsed = max(1e-6, time.time() - self.start_time)
+ return {
+ "global_step": int(self.global_step),
+ "elapsed_time_sec": float(elapsed),
+ "training_fps": float(self.global_step / elapsed),
+ "last_train_metrics": dict(self.last_train_metrics),
+ "last_eval_metrics": dict(self.last_eval_metrics),
+ "train_history": list(self.train_history),
+ "eval_history": list(self.eval_history),
+ "latest_checkpoint_path": self.latest_checkpoint_path,
+ }
diff --git a/embodichain/data/assets/eef_assets.py b/embodichain/data/assets/eef_assets.py
index b2644712..75918c7e 100644
--- a/embodichain/data/assets/eef_assets.py
+++ b/embodichain/data/assets/eef_assets.py
@@ -23,7 +23,6 @@
EMBODICHAIN_DEFAULT_DATA_ROOT,
)
-
eef_assets = "eef_assets"
diff --git a/embodichain/data/assets/materials.py b/embodichain/data/assets/materials.py
index 8243cb8a..ced7f82a 100644
--- a/embodichain/data/assets/materials.py
+++ b/embodichain/data/assets/materials.py
@@ -27,7 +27,6 @@
EMBODICHAIN_DEFAULT_DATA_ROOT,
)
-
material_assets = "materials"
diff --git a/embodichain/data/assets/obj_assets.py b/embodichain/data/assets/obj_assets.py
index e81fd252..89f28d0d 100644
--- a/embodichain/data/assets/obj_assets.py
+++ b/embodichain/data/assets/obj_assets.py
@@ -23,7 +23,6 @@
EMBODICHAIN_DEFAULT_DATA_ROOT,
)
-
obj_assets = "obj_assets"
diff --git a/embodichain/data/assets/robot_assets.py b/embodichain/data/assets/robot_assets.py
index 55cd17a7..f37cfd3a 100644
--- a/embodichain/data/assets/robot_assets.py
+++ b/embodichain/data/assets/robot_assets.py
@@ -23,7 +23,6 @@
EMBODICHAIN_DEFAULT_DATA_ROOT,
)
-
robot_assets = "robot_assets"
@@ -54,9 +53,9 @@ class CobotMagicArm(EmbodiChainDataset):
def __init__(self, data_root: str = None):
data_descriptor = o3d.data.DataDescriptor(
os.path.join(
- EMBODICHAIN_DOWNLOAD_PREFIX, robot_assets, "CobotMagicArmV2.zip"
+ EMBODICHAIN_DOWNLOAD_PREFIX, robot_assets, "CobotMagicArmV3.zip"
),
- "14af3e84b74193680899a59fc74e8337",
+ "12a249e231bfc2faf0fd55f9e2646b8d",
)
prefix = type(self).__name__
path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root
diff --git a/embodichain/data/assets/scene_assets.py b/embodichain/data/assets/scene_assets.py
index 5b7b90bb..751dc01a 100644
--- a/embodichain/data/assets/scene_assets.py
+++ b/embodichain/data/assets/scene_assets.py
@@ -23,7 +23,6 @@
EMBODICHAIN_DEFAULT_DATA_ROOT,
)
-
scene_assets = "scene_assets"
diff --git a/embodichain/lab/gym/envs/action_bank/configurable_action.py b/embodichain/lab/gym/envs/action_bank/configurable_action.py
index c0e7130d..964c2b1c 100644
--- a/embodichain/lab/gym/envs/action_bank/configurable_action.py
+++ b/embodichain/lab/gym/envs/action_bank/configurable_action.py
@@ -997,7 +997,7 @@ def get_xpos_name(affordance_name: str) -> str:
def get_control_part(env, agent_uid):
- control_parts = env.metadata["dataset"]["robot_meta"].get("control_parts", [])
+ control_parts = env.cfg.control_parts
if agent_uid in control_parts:
return agent_uid
@@ -1324,7 +1324,8 @@ def plan_trajectory(
if len(filtered_keyposes) == 1 and len(ref_poses) == 0:
- ret = np.array([filtered_keyposes[0]] * duration)
+ return np.array([filtered_keyposes[0]] * duration).T
+
else:
mo_gen = MotionGenerator(
cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=env.robot.uid))
diff --git a/embodichain/lab/gym/envs/action_bank/utils.py b/embodichain/lab/gym/envs/action_bank/utils.py
index 8e7d149e..58cfb368 100644
--- a/embodichain/lab/gym/envs/action_bank/utils.py
+++ b/embodichain/lab/gym/envs/action_bank/utils.py
@@ -20,7 +20,6 @@
from embodichain.utils import logger
from embodichain.lab.gym.utils.misc import validation_with_process_from_name
-
"""Node Generation Utils"""
diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py
index fcd89c98..1a0fa89e 100644
--- a/embodichain/lab/gym/envs/base_env.py
+++ b/embodichain/lab/gym/envs/base_env.py
@@ -239,8 +239,7 @@ def add_camera_group_id(self, group_id: int) -> None:
"""
if not hasattr(self, "_camera_group_ids"):
self._camera_group_ids: List[int] = []
- if self.sim.is_rt_enabled:
- self._camera_group_ids.append(group_id)
+ self._camera_group_ids.append(group_id)
def _setup_scene(self, **kwargs):
# Init sim manager.
@@ -273,10 +272,9 @@ def _setup_scene(self, **kwargs):
# Setup camera groups for rendering.
self._camera_group_ids: List[int] = []
- if self.sim.is_rt_enabled:
- for sensor in self.sensors.values():
- if isinstance(sensor, Camera):
- self._camera_group_ids.append(sensor.group_id)
+ for sensor in self.sensors.values():
+ if isinstance(sensor, Camera):
+ self._camera_group_ids.append(sensor.group_id)
def _setup_robot(self, **kwargs) -> Robot:
"""Load the robot agent, setup the controller and action space.
@@ -367,10 +365,8 @@ def _get_sensor_obs(self, **kwargs) -> TensorDict[str, any]:
"""
obs = TensorDict({}, batch_size=[self.num_envs], device=self.device)
- fetch_only = False
- if self.sim.is_rt_enabled:
- fetch_only = True
- self.sim.render_camera_group(self._camera_group_ids)
+ fetch_only = True
+ self.sim.render_camera_group(self._camera_group_ids)
for sensor_name, sensor in self.sensors.items():
sensor.update(fetch_only=fetch_only)
diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py
index 3e699620..a1ede4f1 100644
--- a/embodichain/lab/gym/envs/embodied_env.py
+++ b/embodichain/lab/gym/envs/embodied_env.py
@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------
from math import log
+from functools import wraps
import os
import torch
import numpy as np
@@ -54,7 +55,6 @@
)
from embodichain.utils import configclass, logger
-
__all__ = ["EmbodiedEnvCfg", "EmbodiedEnv"]
@@ -205,6 +205,30 @@ class EnvLightCfg:
If filter_dataset_saving is False and a dataset manager is configured, the rollout buffer will be initialized by default
"""
+ language: Union[Dict[str, Any], None] = None
+ """Language settings for VLA training.
+
+ When configured, enables hierarchical language data collection for
+ Vision-Language-Action model training. Supports:
+
+ - mode: Storage mode ('tokens', 'embeddings', 'hybrid')
+ - hierarchy_levels: List of levels ('task', 'subtask', 'primitive')
+ - max_tokens: Maximum sequence length per instruction
+ - tokenizer: Tokenizer identifier
+ - language_source: Source of language ('env', 'file', 'llm', 'template')
+ - language_config_path: Path to language descriptions (if source='file')
+
+ Example:
+ language = {
+ "mode": "tokens",
+ "hierarchy_levels": ["task", "subtask", "primitive"],
+ "max_tokens": 512,
+ "tokenizer": "gpt2",
+ "language_source": "file",
+ "language_config_path": "config/language/tasks.yaml",
+ }
+ """
+
@register_env("EmbodiedEnv-v1")
class EmbodiedEnv(BaseEnv):
@@ -231,6 +255,27 @@ class EmbodiedEnv(BaseEnv):
- affordance_datas: The affordance data that can be used to store the intermediate results or information
"""
+ @classmethod
+ def __init_subclass__(cls, **kwargs):
+ """Automatically wrap subclass demo-action builders with shape checks.
+
+ Any subclass overriding ``create_demo_action_list`` will be wrapped so its
+ returned action sequence is validated and, when possible, converted to the
+ environment action dimension.
+ """
+ super().__init_subclass__(**kwargs)
+ method = cls.__dict__.get("create_demo_action_list")
+ if method is None or getattr(method, "_demo_action_shape_wrapped", False):
+ return
+
+ @wraps(method)
+ def wrapped_create_demo_action_list(self, *args, **kwargs):
+ action_list = method(self, *args, **kwargs)
+ return self._normalize_demo_action_list(action_list)
+
+ wrapped_create_demo_action_list._demo_action_shape_wrapped = True
+ setattr(cls, "create_demo_action_list", wrapped_create_demo_action_list)
+
def __init__(self, cfg: EmbodiedEnvCfg, **kwargs):
self.affordance_datas = {}
self.action_bank = None
@@ -246,6 +291,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs):
self.reward_manager: RewardManager | None = None
self.action_manager: ActionManager | None = None
self.dataset_manager: DatasetManager | None = None
+ self.language_manager = None
super().__init__(cfg, **kwargs)
@@ -253,12 +299,65 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs):
self.dataset_manager = DatasetManager(self.cfg.dataset, self)
self.cfg.init_rollout_buffer = True
+ # Initialize LanguageManager for VLA training
+ if self.cfg.language:
+ from embodichain.lab.gym.envs.managers import (
+ LanguageCfg,
+ LanguageManager,
+ LanguageProvider,
+ FileBasedLanguageProvider,
+ LLMBasedLanguageProvider,
+ EnvBasedLanguageProvider,
+ TemplateBasedLanguageProvider,
+ )
+
+ # Create language config
+ language_cfg = LanguageCfg(**self.cfg.language)
+
+ # Initialize language provider based on source
+ language_source = self.cfg.language.get("language_source", "env")
+ if language_source == "file":
+ language_config_path = self.cfg.language.get("language_config_path")
+ if language_config_path is None:
+ log_error(
+ "language_config_path must be provided when language_source='file'",
+ error_type=ValueError,
+ )
+ self.language_provider = FileBasedLanguageProvider(
+ language_cfg, language_config_path
+ )
+ elif language_source == "llm":
+ model = self.cfg.language.get("model", "gpt-4")
+ api_key = self.cfg.language.get("api_key")
+ self.language_provider = LLMBasedLanguageProvider(
+ language_cfg, model, api_key
+ )
+ elif language_source == "template":
+ templates = self.cfg.language.get("templates", {})
+ variables = self.cfg.language.get("variables", {})
+ self.language_provider = TemplateBasedLanguageProvider(
+ language_cfg, templates, variables
+ )
+ else: # env or default
+ self.language_provider = EnvBasedLanguageProvider(language_cfg, self)
+
+ # Initialize language manager
+ self.language_manager = LanguageManager(language_cfg, self)
+ log_info(
+ f"[EmbodiedEnv] LanguageManager initialized with source={language_source}, "
+ f"mode={language_cfg.mode}, hierarchy={language_cfg.hierarchy_levels}"
+ )
+ else:
+ self.language_manager = None
+ self.language_provider = None
+
# Rollout buffer for episode data collection.
# The shape of the buffer is (num_envs, max_episode_steps, *data_shape) for each key.
# The default key in the buffer are:
# - obs: the observation returned by the environment.
# - action: the action applied to the environment.
# - reward: the reward returned by the environment.
+ # - language: Hierarchical language data for VLA training (if language_manager is set)
# TODO: we may add more keys and make the buffer extensible in the future.
# This buffer should also be support initialized from outside of the environment.
# For example, a shared rollout buffer initialized in model training process and passed to the environment for data collection.
@@ -266,6 +365,8 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs):
self._max_rollout_steps = 0
self._rollout_buffer_mode: str | None = None
if self.cfg.init_rollout_buffer:
+ # Determine if we need to initialize language fields
+ language_cfg = self.cfg.language if self.cfg.language else None
self.rollout_buffer = init_rollout_buffer_from_gym_space(
obs_space=self.observation_space,
action_space=self.action_space,
@@ -530,6 +631,19 @@ def _initialize_episode(
self.episode_success_status[env_ids_to_process] = False
+ # Initialize language data for the new episode
+ if self.language_manager is not None:
+ # Get task ID for language lookup
+ task_id = getattr(self, "task_name", "default")
+
+ # Get language data from provider
+ if self.language_provider is not None:
+ language_data = self.language_provider.get_language(
+ task_id, context={"env_ids": env_ids}
+ )
+ # Write language data to rollout buffer
+ self._write_language_data(language_data, env_ids_to_process)
+
# apply events such as randomization for environments that need a reset
if self.cfg.events:
if "reset" in self.event_manager.available_modes:
@@ -591,6 +705,97 @@ def _write_episode_rollout_step(
rewards.to(buffer_device), non_blocking=True
)
+ def _write_language_data(
+ self,
+ language_data: "HierarchicalLanguageData",
+ env_ids: Optional[torch.Tensor] = None,
+ ) -> None:
+ """Write hierarchical language data to the rollout buffer.
+
+ This method writes language data at multiple hierarchy levels to the
+ rollout buffer. The data is broadcast across all timesteps of the
+ current episode.
+
+ Args:
+ language_data: HierarchicalLanguageData containing task descriptions.
+ env_ids: Optional tensor of environment IDs to write to.
+ If None, writes to all environments.
+ """
+ if self.rollout_buffer is None or "language" not in self.rollout_buffer:
+ return
+
+ if env_ids is None:
+ env_ids = torch.arange(self.num_envs, device=self.device)
+
+ buffer_device = self.rollout_buffer.device
+
+ # Get language config for max values
+ cfg = self.language_manager.cfg
+ max_instructions = cfg.max_instructions_per_level
+ max_tokens = cfg.max_tokens
+
+ # Convert language data to buffer format
+ buffer_format = language_data.to_buffer_format(cfg)
+
+ # Write data for each hierarchy level
+ for level in cfg.hierarchy_levels:
+ level_key = f"{level}_level"
+
+ # Get tokens and mask
+ tokens_key = f"{level_key}_tokens"
+ mask_key = f"{level_key}_attention_mask"
+ count_key = f"{level_key}_count"
+
+ if tokens_key not in buffer_format:
+ continue
+
+ tokens = buffer_format[tokens_key] # [max_instructions, max_tokens]
+ mask = buffer_format[mask_key]
+
+ # Create the full tensor for all environments and timesteps
+ # Shape: [num_envs, max_episode_steps, max_instructions, max_tokens]
+ full_tokens = (
+ tokens.unsqueeze(0)
+ .unsqueeze(0)
+ .expand(len(env_ids), self._max_rollout_steps, -1, -1)
+ )
+ full_mask = (
+ mask.unsqueeze(0)
+ .unsqueeze(0)
+ .expand(len(env_ids), self._max_rollout_steps, -1, -1)
+ )
+
+ # Write to buffer
+ self.rollout_buffer["language"][tokens_key][env_ids, ...] = full_tokens.to(
+ buffer_device, non_blocking=True
+ )
+ self.rollout_buffer["language"][mask_key][env_ids, ...] = full_mask.to(
+ buffer_device, non_blocking=True
+ )
+
+ # Write instruction count
+ count = buffer_format.get(f"{level_key}_count", torch.tensor([0]))
+ level_idx = {"task": 0, "subtask": 1, "primitive": 2}[level]
+ self.rollout_buffer["language"]["instruction_counts"][
+ env_ids, :, level_idx
+ ] = count.item()
+
+ # Write change points
+ if "change_points" in buffer_format:
+ change_points = buffer_format["change_points"]
+ full_change_points = (
+ change_points.unsqueeze(0)
+ .unsqueeze(0)
+ .expand(len(env_ids), self._max_rollout_steps, -1)
+ )
+ self.rollout_buffer["language"]["change_points"][env_ids, ...] = (
+ full_change_points.to(buffer_device, non_blocking=True)
+ )
+
+ # Write hierarchy depth
+ hierarchy_depth = language_data.hierarchy_depth
+ self.rollout_buffer["language"]["hierarchy_depth"][env_ids, :] = hierarchy_depth
+
def _write_rl_rollout_step(
self,
obs: EnvObs,
@@ -624,6 +829,112 @@ def _write_rl_rollout_step(
: self.num_envs, self.current_rollout_step
].copy_(truncateds.to(buffer_device), non_blocking=True)
+ def _normalize_demo_action_list(
+ self, action_list: Sequence[EnvAction] | torch.Tensor | None
+ ) -> Sequence[EnvAction] | torch.Tensor | None:
+ """Validate/convert demo action outputs to match single action-space dim."""
+ if action_list is None:
+ return None
+
+ expected_dim = int(np.prod(self.action_space.shape))
+
+ if isinstance(action_list, torch.Tensor):
+ return self._normalize_demo_action_tensor(action_list, expected_dim)
+
+ if not isinstance(action_list, Sequence):
+ raise TypeError(
+ "create_demo_action_list must return None, a torch.Tensor, or a sequence of actions. "
+ f"Got {type(action_list)}."
+ )
+
+ normalized_action_list = [
+ self._normalize_demo_action_tensor(action, expected_dim)
+ for action in action_list
+ ]
+ return type(action_list)(normalized_action_list)
+
+ def _normalize_demo_action_tensor(
+ self, action: EnvAction | torch.Tensor, expected_dim: int
+ ) -> EnvAction | torch.Tensor:
+ """Normalize one action tensor to the expected action dimension.
+
+ Conversion rule:
+ - If last-dim equals action-space dim, keep as-is.
+ - If last-dim is larger, slice with ``active_joint_ids``.
+ - If last-dim is smaller, raise ``ValueError``.
+ """
+ if isinstance(action, TensorDict):
+ return self._normalize_demo_action_tensordict(action, expected_dim)
+
+ if not isinstance(action, torch.Tensor):
+ raise TypeError(
+ "Each demo action must be a torch.Tensor or TensorDict. "
+ f"Got {type(action)}."
+ )
+
+ if action.ndim == 0:
+ raise ValueError(
+ "Demo action tensor must have at least one dimension with action features on the last axis."
+ )
+
+ action_dim = int(action.shape[-1])
+ if action_dim == expected_dim:
+ return action
+ if action_dim < expected_dim:
+ raise ValueError(
+ "Demo action dim is smaller than action space dim and cannot be auto-converted. "
+ f"Got action dim={action_dim}, expected={expected_dim}."
+ )
+ return self._slice_action_with_active_joint_ids(
+ action, action_dim, expected_dim
+ )
+
+ def _normalize_demo_action_tensordict(
+ self, action: TensorDict, expected_dim: int
+ ) -> TensorDict:
+ """Normalize tensor entries in a TensorDict action payload."""
+ converted_action = action.clone()
+ for key in ("qpos", "qvel", "qf"):
+ if key not in converted_action:
+ continue
+ value = converted_action[key]
+ if value.ndim == 0:
+ raise ValueError(
+ f"Demo action TensorDict['{key}'] must have at least one dimension."
+ )
+ action_dim = int(value.shape[-1])
+ if action_dim == expected_dim:
+ continue
+ if action_dim < expected_dim:
+ raise ValueError(
+ f"Demo action TensorDict['{key}'] dim={action_dim} is smaller than expected action dim={expected_dim}."
+ )
+ converted_action[key] = self._slice_action_with_active_joint_ids(
+ value, action_dim, expected_dim
+ )
+ return converted_action
+
+ def _slice_action_with_active_joint_ids(
+ self, action: torch.Tensor, action_dim: int, expected_dim: int
+ ) -> torch.Tensor:
+ """Slice a high-dimensional action to active joints.
+
+ This is used when demo actions are generated in full-DoF form while the
+ environment action-space only controls active joints.
+ """
+ if len(self.active_joint_ids) != expected_dim:
+ raise ValueError(
+ "Cannot convert demo action by active_joint_ids because their length does not match the action space dim. "
+ f"len(active_joint_ids)={len(self.active_joint_ids)}, expected={expected_dim}."
+ )
+
+ if len(self.active_joint_ids) == 0:
+ raise ValueError(
+ "Cannot convert demo action by active_joint_ids because active_joint_ids is empty."
+ )
+
+ return action[..., self.active_joint_ids]
+
def _step_action(self, action: EnvAction) -> EnvAction:
"""Set action control command into simulation.
@@ -907,6 +1218,12 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None
Returns:
Sequence[EnvAction] | None: A list of actions if a demonstration is available, otherwise None.
+
+ Note:
+ Subclass outputs are automatically post-processed by the base class:
+ action last-dimension must match ``single_action_space``. If larger,
+ actions are sliced by ``active_joint_ids``; if smaller, ``ValueError``
+ is raised.
"""
raise NotImplementedError(
"The method 'create_demo_action_list' must be implemented in subclasses."
diff --git a/embodichain/lab/gym/envs/managers/__init__.py b/embodichain/lab/gym/envs/managers/__init__.py
index 939f190c..22b1effe 100644
--- a/embodichain/lab/gym/envs/managers/__init__.py
+++ b/embodichain/lab/gym/envs/managers/__init__.py
@@ -30,3 +30,18 @@
from .action_manager import *
from .actions import *
from .dataset_manager import DatasetManager
+from .language import (
+ LanguageCfg,
+ LanguageCurriculumCfg,
+ LanguageAugmentationCfg,
+ LanguageManager,
+ LanguageData,
+ HierarchicalLanguageData,
+)
+from .language_provider import (
+ LanguageProvider,
+ FileBasedLanguageProvider,
+ LLMBasedLanguageProvider,
+ EnvBasedLanguageProvider,
+ TemplateBasedLanguageProvider,
+)
diff --git a/embodichain/lab/gym/envs/managers/language.py b/embodichain/lab/gym/envs/managers/language.py
new file mode 100644
index 00000000..a62bc097
--- /dev/null
+++ b/embodichain/lab/gym/envs/managers/language.py
@@ -0,0 +1,767 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Literal, Union
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import torch
+import numpy as np
+
+from embodichain.utils import configclass
+from embodichain.utils.logger import log_info, log_warning, log_error
+
+__all__ = [
+ "LanguageCfg",
+ "LanguageCurriculumCfg",
+ "LanguageAugmentationCfg",
+ "LanguageManager",
+ "LanguageData",
+ "HierarchicalLanguageData",
+]
+
+
+@configclass
+class LanguageCfg:
+ """Configuration for language data in rollout buffers.
+
+ Supports three storage modes:
+ - 'tokens': Store token IDs (default, most flexible)
+ - 'embeddings': Store pre-computed embeddings
+ - 'hybrid': Store both tokens and embeddings
+
+ Supports hierarchical language structure for VLA training:
+ - task_level: Overall goal/description
+ - subtask_level: Intermediate step descriptions
+ - primitive_level: Low-level action descriptions
+
+ Args:
+ mode: Storage mode ('tokens', 'embeddings', or 'hybrid').
+ hierarchy_levels: List of hierarchy levels to store. If None, uses
+ all levels. Valid levels: 'task', 'subtask', 'primitive'.
+ max_tokens: Maximum sequence length for tokenized text.
+ tokenizer: Tokenizer/model identifier (huggingface or OpenAI).
+ pad_token_id: Token ID used for padding.
+ max_instructions_per_level: Maximum number of instructions per hierarchy level.
+ embedding_dim: Dimension of text embeddings (when mode='embeddings' or 'hybrid').
+ embedding_type: How to compute embeddings from tokens.
+ tokenizer_backend: 'huggingface' or 'openai'.
+ trust_remote_code: Whether to trust remote code for huggingface tokenizers.
+ """
+
+ mode: Literal["tokens", "embeddings", "hybrid"] = "tokens"
+ """Storage mode for language data."""
+
+ hierarchy_levels: Optional[List[Literal["task", "subtask", "primitive"]]] = None
+ """Hierarchy levels to store. If None, uses all levels."""
+
+ max_tokens: int = 512
+ """Maximum sequence length for tokenized text per instruction."""
+
+ tokenizer: str = "gpt2"
+ """Tokenizer/model identifier."""
+
+ pad_token_id: int = 0
+ """Token ID used for padding."""
+
+ max_instructions_per_level: int = 3
+ """Maximum number of instructions per hierarchy level."""
+
+ embedding_dim: int = 768
+ """Dimension of text embeddings."""
+
+ embedding_type: Literal["mean_pool", "cls", "last"] = "mean_pool"
+ """How to compute embeddings from tokens."""
+
+ tokenizer_backend: Literal["huggingface", "openai"] = "huggingface"
+ """Tokenizer backend to use."""
+
+ trust_remote_code: bool = False
+ """Whether to trust remote code for huggingface tokenizers."""
+
+ def __post_init__(self) -> None:
+ if self.hierarchy_levels is None:
+ self.hierarchy_levels = ["task", "subtask", "primitive"]
+
+ # Validate hierarchy levels
+ valid_levels = {"task", "subtask", "primitive"}
+ for level in self.hierarchy_levels:
+ if level not in valid_levels:
+ log_error(
+ f"Invalid hierarchy level: {level}. Must be one of {valid_levels}.",
+ error_type=ValueError,
+ )
+
+
+@configclass
+class LanguageCurriculumCfg:
+ """Language complexity curriculum for progressive training.
+
+ Defines stages of increasing language complexity, allowing the model
+ to learn from simple descriptions before tackling complex ones.
+
+ Args:
+ stages: List of curriculum stages, each defining complexity constraints.
+ stage_duration: Number of training steps per curriculum stage.
+ enabled: Whether curriculum learning is enabled.
+ """
+
+ @dataclass
+ class CurriculumStage:
+ """Configuration for a single curriculum stage."""
+
+ max_words: int = 50
+ """Maximum number of words per instruction."""
+
+ max_sentences: int = 2
+ """Maximum number of sentences per instruction."""
+
+ max_hierarchy_depth: int = 1
+ """Maximum hierarchy depth (1=task only, 2=task+subtask, 3=all)."""
+
+ vocabulary_complexity: Literal["simple", "moderate", "complex"] = "simple"
+ """Vocabulary complexity level."""
+
+ instruction_types: List[str] = field(default_factory=lambda: ["imperative"])
+ """Allowed instruction types: 'imperative', 'declarative', 'conditional'."""
+
+ stages: List[CurriculumStage] = field(
+ default_factory=lambda: [
+ LanguageCurriculumCfg.CurriculumStage(
+ max_words=10,
+ max_sentences=1,
+ max_hierarchy_depth=1,
+ vocabulary_complexity="simple",
+ instruction_types=["imperative"],
+ ),
+ LanguageCurriculumCfg.CurriculumStage(
+ max_words=25,
+ max_sentences=2,
+ max_hierarchy_depth=2,
+ vocabulary_complexity="moderate",
+ instruction_types=["imperative", "declarative"],
+ ),
+ LanguageCurriculumCfg.CurriculumStage(
+ max_words=50,
+ max_sentences=3,
+ max_hierarchy_depth=3,
+ vocabulary_complexity="complex",
+ instruction_types=["imperative", "declarative", "conditional"],
+ ),
+ ]
+ )
+
+ stage_duration: int = 1000
+ """Number of training steps per curriculum stage."""
+
+ enabled: bool = False
+ """Whether curriculum learning is enabled."""
+
+
+@configclass
+class LanguageAugmentationCfg:
+ """Configuration for language data augmentation.
+
+ Augmentations are applied during sampling to increase data diversity
+ and improve model generalization.
+
+ Args:
+ back_translation: Use back-translation for paraphrasing.
+ synonym_replacement: Probability of replacing words with synonyms.
+ template_variation: Apply template-based rephrasing.
+ drop_word: Probability of randomly dropping a word.
+ swap_word: Probability of swapping two adjacent words.
+ insert_word: Probability of inserting a filler word.
+ """
+
+ back_translation: bool = False
+ """Use back-translation for paraphrasing."""
+
+ synonym_replacement: float = 0.0
+ """Probability of replacing words with synonyms [0.0, 1.0]."""
+
+ template_variation: bool = False
+ """Apply template-based rephrasing."""
+
+ drop_word: float = 0.0
+ """Probability of randomly dropping a word [0.0, 1.0]."""
+
+ swap_word: float = 0.0
+ """Probability of swapping two adjacent words [0.0, 1.0]."""
+
+ insert_word: float = 0.0
+ """Probability of inserting a filler word [0.0, 1.0]."""
+
+ augmentation_prob: float = 0.5
+ """Overall probability of applying any augmentation [0.0, 1.0]."""
+
+
+@dataclass
+class LanguageData:
+ """Single-level language data structure.
+
+ Contains tokenized text and metadata for a single instruction.
+
+ Args:
+ tokens: Token IDs tensor of shape [seq_len].
+ attention_mask: Attention mask tensor of shape [seq_len].
+ raw_text: Original raw text string (for debugging).
+ instruction_type: Type of instruction (imperative, declarative, etc.).
+ metadata: Additional metadata dictionary.
+ """
+
+ tokens: torch.Tensor
+ attention_mask: torch.Tensor
+ raw_text: str
+ instruction_type: str = "imperative"
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to dictionary format."""
+ return {
+ "tokens": self.tokens,
+ "attention_mask": self.attention_mask,
+ "raw_text": self.raw_text,
+ "instruction_type": self.instruction_type,
+ "metadata": self.metadata,
+ }
+
+
+@dataclass
+class HierarchicalLanguageData:
+ """Hierarchical language data structure for VLA training.
+
+ Organizes language instructions at multiple abstraction levels:
+ - task_level: High-level goal/description
+ - subtask_level: Intermediate step descriptions
+ - primitive_level: Low-level action descriptions
+
+ This structure enables VLA models to learn from multi-scale language
+ representations, similar to human task understanding.
+
+ Args:
+ task_level: List of task-level instructions.
+ subtask_level: List of subtask-level instructions.
+ primitive_level: List of primitive-level instructions.
+ hierarchy_depth: Current depth of the hierarchy (1-3).
+ change_points: Timesteps where language changes within the trajectory.
+ """
+
+ task_level: List[LanguageData] = field(default_factory=list)
+ subtask_level: List[LanguageData] = field(default_factory=list)
+ primitive_level: List[LanguageData] = field(default_factory=list)
+ hierarchy_depth: int = 3
+ change_points: Optional[List[int]] = None
+
+ def __post_init__(self) -> None:
+ if self.change_points is None:
+ self.change_points = [0]
+
+ def get_level(self, level: str) -> List[LanguageData]:
+ """Get language data for a specific hierarchy level.
+
+ Args:
+ level: Hierarchy level ('task', 'subtask', 'primitive').
+
+ Returns:
+ List of LanguageData for the requested level.
+ """
+ level_map = {
+ "task": self.task_level,
+ "subtask": self.subtask_level,
+ "primitive": self.primitive_level,
+ }
+ if level not in level_map:
+ log_error(f"Invalid hierarchy level: {level}", error_type=ValueError)
+ return level_map[level]
+
+ def set_level(self, level: str, data: List[LanguageData]) -> None:
+ """Set language data for a specific hierarchy level.
+
+ Args:
+ level: Hierarchy level ('task', 'subtask', 'primitive').
+ data: List of LanguageData to set.
+ """
+ level_map = {
+ "task": "task_level",
+ "subtask": "subtask_level",
+ "primitive": "primitive_level",
+ }
+ if level not in level_map:
+ log_error(f"Invalid hierarchy level: {level}", error_type=ValueError)
+ setattr(self, level_map[level], data)
+
+ def flatten(self) -> Dict[str, List[LanguageData]]:
+ """Flatten hierarchical structure into a dictionary.
+
+ Returns:
+ Dictionary mapping level names to their language data.
+ """
+ return {
+ "task": self.task_level,
+ "subtask": self.subtask_level,
+ "primitive": self.primitive_level,
+ }
+
+ def to_buffer_format(self, cfg: LanguageCfg) -> Dict[str, torch.Tensor]:
+ """Convert hierarchical language data to buffer tensor format.
+
+ Args:
+ cfg: Language configuration for buffer layout.
+
+ Returns:
+ Dictionary with tensor fields ready for rollout buffer.
+ """
+ result = {}
+
+ # Process each hierarchy level
+ for level in cfg.hierarchy_levels:
+ level_data = self.get_level(level)
+ level_key = f"{level}_level"
+
+ # Pad to max_instructions_per_level
+ padded_tokens = []
+ padded_masks = []
+
+ for i in range(cfg.max_instructions_per_level):
+ if i < len(level_data):
+ # Pad sequence to max_tokens
+ tokens = level_data[i].tokens
+ mask = level_data[i].attention_mask
+
+ seq_len = tokens.shape[0]
+ if seq_len < cfg.max_tokens:
+ pad_len = cfg.max_tokens - seq_len
+ tokens = torch.cat(
+ [
+ tokens,
+ torch.full(
+ (pad_len,),
+ cfg.pad_token_id,
+ dtype=tokens.dtype,
+ device=tokens.device,
+ ),
+ ]
+ )
+ mask = torch.cat(
+ [
+ mask,
+ torch.zeros(
+ (pad_len,), dtype=mask.dtype, device=mask.device
+ ),
+ ]
+ )
+ elif seq_len > cfg.max_tokens:
+ tokens = tokens[: cfg.max_tokens]
+ mask = mask[: cfg.max_tokens]
+ else:
+ # Empty instruction
+ tokens = torch.full(
+ (cfg.max_tokens,),
+ cfg.pad_token_id,
+ dtype=torch.int64,
+ device="cpu",
+ )
+ mask = torch.zeros(
+ (cfg.max_tokens,),
+ dtype=torch.int64,
+ device="cpu",
+ )
+
+ padded_tokens.append(tokens)
+ padded_masks.append(mask)
+
+ # Stack instructions
+ result[f"{level_key}_tokens"] = torch.stack(padded_tokens)
+ result[f"{level_key}_attention_mask"] = torch.stack(padded_masks)
+
+ # Add instruction counts
+ result["instruction_counts"] = torch.tensor(
+ [
+ len(self.task_level),
+ len(self.subtask_level),
+ len(self.primitive_level),
+ ],
+ dtype=torch.int64,
+ )
+
+ # Add change points (padded to max_instructions_per_level)
+ change_points = torch.full(
+ (cfg.max_instructions_per_level,),
+ -1,
+ dtype=torch.int64,
+ device="cpu",
+ )
+ for i, cp in enumerate(self.change_points[: cfg.max_instructions_per_level]):
+ change_points[i] = cp
+ result["change_points"] = change_points
+
+ return result
+
+
+class LanguageManager:
+ """Manages language data generation, tokenization, and storage.
+
+ The LanguageManager handles:
+ - Loading and configuring tokenizers
+ - Generating or retrieving hierarchical language descriptions
+ - Tokenizing text into model-ready format
+ - Managing language curriculum and augmentation
+
+ Args:
+ cfg: Language configuration.
+ env: Reference to the environment for context.
+ """
+
+ def __init__(self, cfg: LanguageCfg, env) -> None:
+ self.cfg = cfg
+ self.env = env
+ self._tokenizer = None
+ self._load_tokenizer()
+
+ # Curriculum state
+ self._curriculum_step = 0
+ self._current_stage = 0
+
+ # Cache for tokenized language
+ self._language_cache: Dict[str, HierarchicalLanguageData] = {}
+
+ log_info(
+ f"[LanguageManager] Initialized with mode={cfg.mode}, "
+ f"hierarchy={cfg.hierarchy_levels}, tokenizer={cfg.tokenizer}"
+ )
+
+ def _load_tokenizer(self) -> None:
+ """Load the tokenizer based on configuration."""
+ if self.cfg.tokenizer_backend == "huggingface":
+ try:
+ from transformers import AutoTokenizer
+
+ self._tokenizer = AutoTokenizer.from_pretrained(
+ self.cfg.tokenizer,
+ trust_remote_code=self.cfg.trust_remote_code,
+ )
+
+ # Update pad_token_id from tokenizer if not specified
+ if (
+ self.cfg.pad_token_id == 0
+ and self._tokenizer.pad_token_id is not None
+ ):
+ self.cfg.pad_token_id = self._tokenizer.pad_token_id
+
+ log_info(
+ f"[LanguageManager] Loaded huggingface tokenizer: {self.cfg.tokenizer}"
+ )
+ except ImportError:
+ log_error(
+ "transformers library not installed. "
+ "Install with: pip install transformers",
+ error_type=ImportError,
+ )
+ except Exception as e:
+ log_error(
+ f"Failed to load huggingface tokenizer: {e}",
+ error_type=RuntimeError,
+ )
+ elif self.cfg.tokenizer_backend == "openai":
+ try:
+ import tiktoken
+
+ self._tokenizer = tiktoken.encoding_for_model(self.cfg.tokenizer)
+ log_info(
+ f"[LanguageManager] Loaded OpenAI tokenizer: {self.cfg.tokenizer}"
+ )
+ except ImportError:
+ log_error(
+ "tiktoken library not installed. "
+ "Install with: pip install tiktoken",
+ error_type=ImportError,
+ )
+ else:
+ log_error(
+ f"Unknown tokenizer backend: {self.cfg.tokenizer_backend}",
+ error_type=ValueError,
+ )
+
+ def tokenize(
+ self, text: str, return_tensors: str = "pt"
+ ) -> Dict[str, torch.Tensor]:
+ """Tokenize a single text string.
+
+ Args:
+ text: Text to tokenize.
+ return_tensors: Return tensor format ('pt' for PyTorch).
+
+ Returns:
+ Dictionary with 'input_ids' and 'attention_mask'.
+ """
+ if self._tokenizer is None:
+ log_error("Tokenizer not initialized", error_type=RuntimeError)
+
+ if self.cfg.tokenizer_backend == "huggingface":
+ result = self._tokenizer(
+ text,
+ max_length=self.cfg.max_tokens,
+ padding="max_length",
+ truncation=True,
+ return_tensors=return_tensors,
+ )
+ # Ensure dtype is int64
+ result["input_ids"] = result["input_ids"].to(torch.int64)
+ result["attention_mask"] = result["attention_mask"].to(torch.int64)
+ return result
+ else: # openai/tiktoken
+ tokens = self._tokenizer.encode(
+ text,
+ max_length=self.cfg.max_tokens,
+ truncation=True,
+ )
+ # Pad to max_tokens
+ if len(tokens) < self.cfg.max_tokens:
+ tokens = tokens + [self.cfg.pad_token_id] * (
+ self.cfg.max_tokens - len(tokens)
+ )
+ else:
+ tokens = tokens[: self.cfg.max_tokens]
+
+ input_ids = torch.tensor(tokens, dtype=torch.int64)
+ attention_mask = (input_ids != self.cfg.pad_token_id).to(torch.int64)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+ def tokenize_batch(
+ self, texts: List[str], return_tensors: str = "pt"
+ ) -> Dict[str, torch.Tensor]:
+ """Tokenize a batch of text strings.
+
+ Args:
+ texts: List of texts to tokenize.
+ return_tensors: Return tensor format ('pt' for PyTorch).
+
+ Returns:
+ Dictionary with 'input_ids' and 'attention_mask' tensors.
+ """
+ if self._tokenizer is None:
+ log_error("Tokenizer not initialized", error_type=RuntimeError)
+
+ if self.cfg.tokenizer_backend == "huggingface":
+ result = self._tokenizer(
+ texts,
+ max_length=self.cfg.max_tokens,
+ padding="max_length",
+ truncation=True,
+ return_tensors=return_tensors,
+ )
+ result["input_ids"] = result["input_ids"].to(torch.int64)
+ result["attention_mask"] = result["attention_mask"].to(torch.int64)
+ return result
+ else: # openai/tiktoken
+ batch_tokens = []
+ for text in texts:
+ tokens = self._tokenizer.encode(
+ text,
+ max_length=self.cfg.max_tokens,
+ truncation=True,
+ )
+ if len(tokens) < self.cfg.max_tokens:
+ tokens = tokens + [self.cfg.pad_token_id] * (
+ self.cfg.max_tokens - len(tokens)
+ )
+ else:
+ tokens = tokens[: self.cfg.max_tokens]
+ batch_tokens.append(tokens)
+
+ input_ids = torch.tensor(batch_tokens, dtype=torch.int64)
+ attention_mask = (input_ids != self.cfg.pad_token_id).to(torch.int64)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+ def decode(self, token_ids: torch.Tensor) -> str:
+ """Decode token IDs back to text.
+
+ Args:
+ token_ids: Token IDs to decode.
+
+ Returns:
+ Decoded text string.
+ """
+ if self._tokenizer is None:
+ log_error("Tokenizer not initialized", error_type=RuntimeError)
+
+ # Remove padding
+ mask = token_ids != self.cfg.pad_token_id
+ token_ids = token_ids[mask]
+
+ if self.cfg.tokenizer_backend == "huggingface":
+ return self._tokenizer.decode(token_ids, skip_special_tokens=True)
+ else: # openai/tiktoken
+ return self._tokenizer.decode(token_ids)
+
+ def create_language_data(
+ self, text: str, instruction_type: str = "imperative", **metadata
+ ) -> LanguageData:
+ """Create a LanguageData object from raw text.
+
+ Args:
+ text: Raw text string.
+ instruction_type: Type of instruction.
+ **metadata: Additional metadata.
+
+ Returns:
+ LanguageData object with tokenized text.
+ """
+ tokenized = self.tokenize(text)
+ return LanguageData(
+ tokens=tokenized["input_ids"].squeeze(0),
+ attention_mask=tokenized["attention_mask"].squeeze(0),
+ raw_text=text,
+ instruction_type=instruction_type,
+ metadata=metadata,
+ )
+
+ def create_hierarchical_language_data(
+ self,
+ task_texts: List[str] | str,
+ subtask_texts: Optional[List[str] | str] = None,
+ primitive_texts: Optional[List[str] | str] = None,
+ change_points: Optional[List[int]] = None,
+ ) -> HierarchicalLanguageData:
+ """Create hierarchical language data from text at multiple levels.
+
+ Args:
+ task_texts: Task-level descriptions (string or list).
+ subtask_texts: Subtask-level descriptions (optional).
+ primitive_texts: Primitive-level descriptions (optional).
+ change_points: Timesteps where language changes (optional).
+
+ Returns:
+ HierarchicalLanguageData object.
+ """
+ # Normalize to lists
+ if isinstance(task_texts, str):
+ task_texts = [task_texts]
+ if subtask_texts is not None and isinstance(subtask_texts, str):
+ subtask_texts = [subtask_texts]
+ if primitive_texts is not None and isinstance(primitive_texts, str):
+ primitive_texts = [primitive_texts]
+
+ # Create language data for each level
+ task_level = [self.create_language_data(text) for text in task_texts]
+ subtask_level = (
+ [self.create_language_data(text) for text in subtask_texts]
+ if subtask_texts is not None
+ else []
+ )
+ primitive_level = (
+ [self.create_language_data(text) for text in primitive_texts]
+ if primitive_texts is not None
+ else []
+ )
+
+ return HierarchicalLanguageData(
+ task_level=task_level,
+ subtask_level=subtask_level,
+ primitive_level=primitive_level,
+ change_points=change_points,
+ )
+
+ def get_task_language(
+ self, task_id: Optional[str] = None
+ ) -> HierarchicalLanguageData:
+ """Generate or retrieve language description for the current task.
+
+ This method should be overridden in subclasses or configured via
+ language providers to implement custom language generation logic.
+
+ Args:
+ task_id: Optional task identifier for cache lookup.
+
+ Returns:
+ HierarchicalLanguageData for the current task.
+ """
+ cache_key = task_id or "default"
+
+ if cache_key in self._language_cache:
+ return self._language_cache[cache_key]
+
+ # Default implementation: generate generic task description
+ task_name = getattr(self.env, "task_name", "unknown_task")
+ task_description = getattr(
+ self.env,
+ "task_description",
+ f"Complete the {task_name} task.",
+ )
+
+ language_data = self.create_hierarchical_language_data(
+ task_texts=task_description,
+ subtask_texts=None, # Can be generated by subclasses
+ primitive_texts=None, # Can be generated by subclasses
+ )
+
+ self._language_cache[cache_key] = language_data
+ return language_data
+
+ def set_curriculum_step(
+ self, step: int, curriculum_cfg: Optional[LanguageCurriculumCfg] = None
+ ) -> None:
+ """Update curriculum learning step.
+
+ Args:
+ step: Current curriculum step.
+ curriculum_cfg: Optional curriculum configuration.
+ """
+ self._curriculum_step = step
+
+ if curriculum_cfg and curriculum_cfg.enabled:
+ self._current_stage = min(
+ step // curriculum_cfg.stage_duration,
+ len(curriculum_cfg.stages) - 1,
+ )
+ log_info(
+ f"[LanguageManager] Curriculum: stage {self._current_stage}/{len(curriculum_cfg.stages)-1} "
+ f"(step {step})"
+ )
+
+ def get_current_stage_constraints(
+ self, curriculum_cfg: Optional[LanguageCurriculumCfg] = None
+ ) -> Optional[Dict[str, Any]]:
+ """Get constraints for the current curriculum stage.
+
+ Args:
+ curriculum_cfg: Optional curriculum configuration.
+
+ Returns:
+ Dictionary of constraints or None if curriculum is disabled.
+ """
+ if not curriculum_cfg or not curriculum_cfg.enabled:
+ return None
+
+ stage = curriculum_cfg.stages[self._current_stage]
+ return {
+ "max_words": stage.max_words,
+ "max_sentences": stage.max_sentences,
+ "max_hierarchy_depth": stage.max_hierarchy_depth,
+ "vocabulary_complexity": stage.vocabulary_complexity,
+ "instruction_types": stage.instruction_types,
+ }
+
+ def clear_cache(self) -> None:
+ """Clear the language cache."""
+ self._language_cache.clear()
+ log_info("[LanguageManager] Language cache cleared")
diff --git a/embodichain/lab/gym/envs/managers/language_provider.py b/embodichain/lab/gym/envs/managers/language_provider.py
new file mode 100644
index 00000000..6ef3a131
--- /dev/null
+++ b/embodichain/lab/gym/envs/managers/language_provider.py
@@ -0,0 +1,647 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Literal
+from pathlib import Path
+
+import yaml
+import json
+
+from embodichain.utils.logger import log_info, log_warning, log_error
+from .language import (
+ LanguageCfg,
+ HierarchicalLanguageData,
+ LanguageData,
+)
+
+__all__ = [
+ "LanguageProvider",
+ "FileBasedLanguageProvider",
+ "LLMBasedLanguageProvider",
+ "EnvBasedLanguageProvider",
+ "TemplateBasedLanguageProvider",
+]
+
+
+class LanguageProvider(ABC):
+ """Abstract base class for language data sources.
+
+ Language providers are responsible for generating or retrieving
+ hierarchical language descriptions for tasks. Different providers
+ can be used depending on the data source (files, LLMs, environment, etc.).
+
+ Args:
+ cfg: Language configuration.
+ """
+
+ def __init__(self, cfg: LanguageCfg) -> None:
+ self.cfg = cfg
+
+ @abstractmethod
+ def get_language(
+ self, task_id: str, context: Optional[Dict[str, Any]] = None
+ ) -> HierarchicalLanguageData:
+ """Get hierarchical language data for a specific task.
+
+ Args:
+ task_id: Unique identifier for the task.
+ context: Optional context dictionary with environment state.
+
+ Returns:
+ HierarchicalLanguageData with task descriptions at multiple levels.
+ """
+ ...
+
+ @abstractmethod
+ def get_available_tasks(self) -> List[str]:
+ """Get list of available task IDs.
+
+ Returns:
+ List of task identifiers.
+ """
+ ...
+
+ def validate_hierarchy_data(self, data: HierarchicalLanguageData) -> bool:
+ """Validate that hierarchical language data meets configuration constraints.
+
+ Args:
+ data: HierarchicalLanguageData to validate.
+
+ Returns:
+ True if data is valid, False otherwise.
+ """
+ # Check each level doesn't exceed max instructions
+ if len(data.task_level) > self.cfg.max_instructions_per_level:
+ log_warning(
+ f"Task level has {len(data.task_level)} instructions, "
+ f"exceeding max {self.cfg.max_instructions_per_level}"
+ )
+ return False
+
+ if len(data.subtask_level) > self.cfg.max_instructions_per_level:
+ log_warning(
+ f"Subtask level has {len(data.subtask_level)} instructions, "
+ f"exceeding max {self.cfg.max_instructions_per_level}"
+ )
+ return False
+
+ if len(data.primitive_level) > self.cfg.max_instructions_per_level:
+ log_warning(
+ f"Primitive level has {len(data.primitive_level)} instructions, "
+ f"exceeding max {self.cfg.max_instructions_per_level}"
+ )
+ return False
+
+ return True
+
+
+class FileBasedLanguageProvider(LanguageProvider):
+ """Language provider that loads task descriptions from files.
+
+ Supports YAML and JSON file formats. The file structure should contain
+ task IDs mapped to their hierarchical descriptions.
+
+ Example YAML structure:
+ ```yaml
+ pick_and_place:
+ task:
+ - "Pick up the red block and place it in the blue basket."
+ subtask:
+ - "Move the gripper to the red block."
+ - "Grasp the red block."
+ - "Lift the block and move to the blue basket."
+ - "Release the block into the basket."
+ primitive:
+ - "Close gripper."
+ - "Move up."
+ - "Move right."
+ - "Open gripper."
+ ```
+
+ Args:
+ cfg: Language configuration.
+ config_path: Path to the configuration file (YAML or JSON).
+ reload_on_access: Whether to reload the file on each access (for dynamic updates).
+ """
+
+ def __init__(
+ self,
+ cfg: LanguageCfg,
+ config_path: str,
+ reload_on_access: bool = False,
+ ) -> None:
+ super().__init__(cfg)
+ self.config_path = Path(config_path)
+ self.reload_on_access = reload_on_access
+ self._data: Dict[str, Any] = {}
+ self._load_data()
+
+ def _load_data(self) -> None:
+ """Load language data from the configuration file."""
+ if not self.config_path.exists():
+ log_error(
+ f"Language config file not found: {self.config_path}",
+ error_type=FileNotFoundError,
+ )
+
+ suffix = self.config_path.suffix.lower()
+
+ try:
+ with open(self.config_path, "r", encoding="utf-8") as f:
+ if suffix in [".yaml", ".yml"]:
+ self._data = yaml.safe_load(f)
+ elif suffix == ".json":
+ self._data = json.load(f)
+ else:
+ log_error(
+ f"Unsupported file format: {suffix}. Use .yaml, .yml, or .json",
+ error_type=ValueError,
+ )
+
+ log_info(
+ f"[FileBasedLanguageProvider] Loaded {len(self._data)} task descriptions "
+ f"from {self.config_path}"
+ )
+ except Exception as e:
+ log_error(
+ f"Failed to load language config from {self.config_path}: {e}",
+ error_type=RuntimeError,
+ )
+
+ def get_language(
+ self, task_id: str, context: Optional[Dict[str, Any]] = None
+ ) -> HierarchicalLanguageData:
+ """Get language data from file for a specific task.
+
+ Args:
+ task_id: Unique identifier for the task.
+ context: Optional context (not used in file-based provider).
+
+ Returns:
+ HierarchicalLanguageData loaded from file.
+ """
+ if self.reload_on_access:
+ self._load_data()
+
+ if task_id not in self._data:
+ log_error(
+ f"Task ID '{task_id}' not found in language config. "
+ f"Available tasks: {list(self._data.keys())}",
+ error_type=KeyError,
+ )
+
+ task_data = self._data[task_id]
+
+ # Extract hierarchical descriptions
+ task_texts = task_data.get("task", [])
+ subtask_texts = task_data.get("subtask", [])
+ primitive_texts = task_data.get("primitive", [])
+ change_points = task_data.get("change_points", None)
+
+ # Import LanguageManager to create data (we need tokenizer access)
+ from .language import LanguageManager
+
+ # Create a temporary manager for tokenization
+ # In practice, the environment should provide the manager
+ class _TempManager:
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self._tokenizer = None
+ self._load_tokenizer()
+
+ def _load_tokenizer(self):
+ if self.cfg.tokenizer_backend == "huggingface":
+ from transformers import AutoTokenizer
+
+ self._tokenizer = AutoTokenizer.from_pretrained(
+ self.cfg.tokenizer,
+ trust_remote_code=self.cfg.trust_remote_code,
+ )
+ if (
+ self.cfg.pad_token_id == 0
+ and self._tokenizer.pad_token_id is not None
+ ):
+ self.cfg.pad_token_id = self._tokenizer.pad_token_id
+ else:
+ import tiktoken
+
+ self._tokenizer = tiktoken.encoding_for_model(self.cfg.tokenizer)
+
+ def tokenize(self, text):
+ if self.cfg.tokenizer_backend == "huggingface":
+ result = self._tokenizer(
+ text,
+ max_length=self.cfg.max_tokens,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ )
+ return result["input_ids"].squeeze(0).to(torch.int64), result[
+ "attention_mask"
+ ].squeeze(0).to(torch.int64)
+ else:
+ import torch
+
+ tokens = self._tokenizer.encode(
+ text, max_length=self.cfg.max_tokens, truncation=True
+ )
+ if len(tokens) < self.cfg.max_tokens:
+ tokens = tokens + [self.cfg.pad_token_id] * (
+ self.cfg.max_tokens - len(tokens)
+ )
+ else:
+ tokens = tokens[: self.cfg.max_tokens]
+ input_ids = torch.tensor(tokens, dtype=torch.int64)
+ attention_mask = (input_ids != self.cfg.pad_token_id).to(
+ torch.int64
+ )
+ return input_ids, attention_mask
+
+ def create_language_data(self, text):
+ tokens, mask = self.tokenize(text)
+ return LanguageData(tokens=tokens, attention_mask=mask, raw_text=text)
+
+ temp_mgr = _TempManager(self.cfg)
+
+ # Build hierarchical language data
+ task_level = [
+ temp_mgr.create_language_data(t) if isinstance(t, str) else t
+ for t in (task_texts if isinstance(task_texts, list) else [task_texts])
+ ]
+ subtask_level = (
+ [
+ temp_mgr.create_language_data(t) if isinstance(t, str) else t
+ for t in (subtask_texts if isinstance(subtask_texts, list) else [])
+ ]
+ if subtask_texts
+ else []
+ )
+ primitive_level = (
+ [
+ temp_mgr.create_language_data(t) if isinstance(t, str) else t
+ for t in (primitive_texts if isinstance(primitive_texts, list) else [])
+ ]
+ if primitive_texts
+ else []
+ )
+
+ return HierarchicalLanguageData(
+ task_level=task_level,
+ subtask_level=subtask_level,
+ primitive_level=primitive_level,
+ change_points=change_points,
+ )
+
+ def get_available_tasks(self) -> List[str]:
+ """Get list of available task IDs from the file.
+
+ Returns:
+ List of task identifiers.
+ """
+ return list(self._data.keys())
+
+
+class LLMBasedLanguageProvider(LanguageProvider):
+ """Language provider that generates descriptions using an LLM.
+
+ This provider uses a language model to generate task descriptions
+ on-the-fly based on task context and templates.
+
+ Args:
+ cfg: Language configuration.
+ model: Model identifier (e.g., "gpt-4", "claude-3-opus").
+ api_key: API key for the LLM service.
+ templates: Optional dictionary of templates for different task types.
+ """
+
+ def __init__(
+ self,
+ cfg: LanguageCfg,
+ model: str = "gpt-4",
+ api_key: Optional[str] = None,
+ templates: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__(cfg)
+ self.model = model
+ self.api_key = api_key
+ self.templates = templates or self._default_templates()
+ self._client = None
+ self._init_client()
+
+ def _default_templates(self) -> Dict[str, str]:
+ """Default prompt templates for language generation."""
+ return {
+ "task": "Generate a clear, concise task description for: {task_name}.",
+ "subtask": "Break down the task '{task_name}' into {num_steps} step-by-step instructions.",
+ "primitive": "For each subtask, provide low-level action descriptions in: {task_name}.",
+ }
+
+ def _init_client(self) -> None:
+ """Initialize the LLM client based on model type."""
+ if self.model.startswith("gpt"):
+ try:
+ import openai
+
+ self._client = openai.OpenAI(api_key=self.api_key)
+ except ImportError:
+ log_warning(
+ "openai library not available. LLM provider will use fallback."
+ )
+ elif self.model.startswith("claude"):
+ try:
+ import anthropic
+
+ self._client = anthropic.Anthropic(api_key=self.api_key)
+ except ImportError:
+ log_warning(
+ "anthropic library not available. LLM provider will use fallback."
+ )
+ else:
+ log_warning(
+ f"Unknown model type: {self.model}. LLM provider will use fallback."
+ )
+
+ def _generate_with_llm(self, prompt: str) -> str:
+ """Generate text using the configured LLM.
+
+ Args:
+ prompt: The prompt to send to the LLM.
+
+ Returns:
+ Generated text string.
+ """
+ if self._client is None:
+ # Fallback: return a generic response
+ log_warning("LLM client not available, using fallback response.")
+ return "Complete the task as described in the environment."
+
+ try:
+ if self.model.startswith("gpt"):
+ response = self._client.chat.completions.create(
+ model=self.model,
+ messages=[{"role": "user", "content": prompt}],
+ max_tokens=500,
+ temperature=0.7,
+ )
+ return response.choices[0].message.content
+ elif self.model.startswith("claude"):
+ response = self._client.messages.create(
+ model=self.model,
+ max_tokens=500,
+ messages=[{"role": "user", "content": prompt}],
+ )
+ return response.content[0].text
+ except Exception as e:
+ log_warning(f"LLM generation failed: {e}. Using fallback.")
+ return "Complete the task as described in the environment."
+
+ def get_language(
+ self, task_id: str, context: Optional[Dict[str, Any]] = None
+ ) -> HierarchicalLanguageData:
+ """Generate language data using LLM for a specific task.
+
+ Args:
+ task_id: Unique identifier for the task.
+ context: Optional context with task details.
+
+ Returns:
+ HierarchicalLanguageData generated by LLM.
+ """
+ task_name = context.get("task_name", task_id) if context else task_id
+
+ # Generate task-level description
+ task_prompt = self.templates["task"].format(task_name=task_name)
+ task_text = self._generate_with_llm(task_prompt)
+
+ # Generate subtask-level descriptions
+ num_subtasks = context.get("num_subtasks", 3) if context else 3
+ subtask_prompt = self.templates["subtask"].format(
+ task_name=task_name, num_steps=num_subtasks
+ )
+ subtask_text = self._generate_with_llm(subtask_prompt)
+ subtask_texts = [
+ line.strip() for line in subtask_text.split("\n") if line.strip()
+ ]
+
+ # Generate primitive-level descriptions (optional)
+ primitive_texts = []
+ if context and context.get("include_primitive", False):
+ primitive_prompt = self.templates["primitive"].format(task_name=task_name)
+ primitive_text = self._generate_with_llm(primitive_prompt)
+ primitive_texts = [
+ line.strip() for line in primitive_text.split("\n") if line.strip()
+ ]
+
+ # Create LanguageData objects (would need LanguageManager in practice)
+ # This is a simplified version - in production, use LanguageManager
+ return HierarchicalLanguageData(
+ task_level=[], # Would be populated with LanguageData objects
+ subtask_level=[],
+ primitive_level=[],
+ )
+
+ def get_available_tasks(self) -> List[str]:
+ """Get list of available task IDs.
+
+ For LLM provider, this returns an empty list as tasks are
+ generated on-the-fly.
+
+ Returns:
+ Empty list (tasks are generated dynamically).
+ """
+ return []
+
+
+class EnvBasedLanguageProvider(LanguageProvider):
+ """Language provider that extracts descriptions from the environment.
+
+ This provider delegates language generation to the environment itself,
+ allowing task-specific implementations to provide custom logic.
+
+ Args:
+ cfg: Language configuration.
+ env: The environment instance.
+ """
+
+ def __init__(self, cfg: LanguageCfg, env) -> None:
+ super().__init__(cfg)
+ self.env = env
+
+ def get_language(
+ self, task_id: str, context: Optional[Dict[str, Any]] = None
+ ) -> HierarchicalLanguageData:
+ """Get language data from the environment.
+
+ The environment should implement one of:
+ - get_task_language(task_id, context) -> HierarchicalLanguageData
+ - task_description attribute (simple string)
+ - generate_task_description() method
+
+ Args:
+ task_id: Unique identifier for the task.
+ context: Optional context dictionary.
+
+ Returns:
+ HierarchicalLanguageData from the environment.
+ """
+ # Check for dedicated method
+ if hasattr(self.env, "get_task_language"):
+ return self.env.get_task_language(task_id, context)
+
+ # Check for attribute
+ if hasattr(self.env, "task_description"):
+ task_desc = self.env.task_description
+ # Would need LanguageManager to tokenize
+ return HierarchicalLanguageData(
+ task_level=[], # Would be populated
+ subtask_level=[],
+ primitive_level=[],
+ )
+
+ # Check for method
+ if hasattr(self.env, "generate_task_description"):
+ task_desc = self.env.generate_task_description(context)
+ return HierarchicalLanguageData(
+ task_level=[],
+ subtask_level=[],
+ primitive_level=[],
+ )
+
+ log_error(
+ "Environment does not provide language data. "
+ "Implement get_task_language, set task_description attribute, or generate_task_description method.",
+ error_type=NotImplementedError,
+ )
+
+ def get_available_tasks(self) -> List[str]:
+ """Get list of available task IDs from the environment.
+
+ The environment can optionally provide:
+ - available_tasks attribute
+ - get_available_tasks() method
+
+ Returns:
+ List of task identifiers or empty list.
+ """
+ if hasattr(self.env, "available_tasks"):
+ return self.env.available_tasks
+
+ if hasattr(self.env, "get_available_tasks"):
+ return self.env.get_available_tasks()
+
+ return []
+
+
+class TemplateBasedLanguageProvider(LanguageProvider):
+ """Language provider that uses templates with variable substitution.
+
+ This provider fills in templates with task-specific variables to generate
+ hierarchical descriptions. Useful for structured tasks with predictable patterns.
+
+ Example templates:
+ ```python
+ templates = {
+ "pick_and_place": {
+ "task": "Pick up the {color} {object} and place it {location}.",
+ "subtasks": [
+ "Move to the {color} {object}.",
+ "Grasp the {color} {object}.",
+ "Move {location}.",
+ "Release the {object}.",
+ ],
+ }
+ }
+ ```
+
+ Args:
+ cfg: Language configuration.
+ templates: Dictionary of templates keyed by task ID.
+ variables: Optional default variable values.
+ """
+
+ def __init__(
+ self,
+ cfg: LanguageCfg,
+ templates: Dict[str, Dict[str, Any]],
+ variables: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__(cfg)
+ self.templates = templates
+ self.variables = variables or {}
+
+ def get_language(
+ self, task_id: str, context: Optional[Dict[str, Any]] = None
+ ) -> HierarchicalLanguageData:
+ """Generate language data from templates for a specific task.
+
+ Args:
+ task_id: Unique identifier for the task.
+ context: Optional context with variable values.
+
+ Returns:
+ HierarchicalLanguageData generated from templates.
+ """
+ if task_id not in self.templates:
+ log_error(
+ f"Task ID '{task_id}' not found in templates. "
+ f"Available tasks: {list(self.templates.keys())}",
+ error_type=KeyError,
+ )
+
+ template = self.templates[task_id]
+
+ # Merge default variables with context
+ vars_to_use = {**self.variables, **(context or {})}
+
+ # Fill in task-level template
+ task_template = template.get("task", "Complete the task.")
+ task_text = task_template.format(**vars_to_use)
+
+ # Fill in subtask templates
+ subtask_templates = template.get("subtasks", [])
+ subtask_texts = [
+ st.format(**vars_to_use) for st in subtask_templates if isinstance(st, str)
+ ]
+
+ # Fill in primitive templates
+ primitive_templates = template.get("primitives", [])
+ primitive_texts = [
+ pt.format(**vars_to_use)
+ for pt in primitive_templates
+ if isinstance(pt, str)
+ ]
+
+ # Get change points if specified
+ change_points = template.get("change_points", None)
+
+ # Would need LanguageManager to tokenize - return placeholder
+ return HierarchicalLanguageData(
+ task_level=[], # Would be populated with LanguageData objects
+ subtask_level=[],
+ primitive_level=[],
+ change_points=change_points,
+ )
+
+ def get_available_tasks(self) -> List[str]:
+ """Get list of available task IDs from templates.
+
+ Returns:
+ List of task identifiers.
+ """
+ return list(self.templates.keys())
diff --git a/embodichain/lab/gym/envs/managers/randomization/physics.py b/embodichain/lab/gym/envs/managers/randomization/physics.py
index 7088c25a..1eea74e0 100644
--- a/embodichain/lab/gym/envs/managers/randomization/physics.py
+++ b/embodichain/lab/gym/envs/managers/randomization/physics.py
@@ -25,7 +25,6 @@
from embodichain.utils.string import resolve_matching_names
from embodichain.utils import logger
-
if TYPE_CHECKING:
from embodichain.lab.gym.envs import EmbodiedEnv
diff --git a/embodichain/lab/gym/envs/managers/randomization/spatial.py b/embodichain/lab/gym/envs/managers/randomization/spatial.py
index 0b732f5c..1af1c09f 100644
--- a/embodichain/lab/gym/envs/managers/randomization/spatial.py
+++ b/embodichain/lab/gym/envs/managers/randomization/spatial.py
@@ -25,7 +25,6 @@
from embodichain.utils.math import sample_uniform, matrix_from_euler, matrix_from_quat
from embodichain.utils import logger
-
if TYPE_CHECKING:
from embodichain.lab.gym.envs import EmbodiedEnv
diff --git a/embodichain/lab/gym/envs/managers/randomization/visual.py b/embodichain/lab/gym/envs/managers/randomization/visual.py
index 66d3d6fb..17daa5d4 100644
--- a/embodichain/lab/gym/envs/managers/randomization/visual.py
+++ b/embodichain/lab/gym/envs/managers/randomization/visual.py
@@ -658,8 +658,6 @@ def __call__(
roughness_range: tuple[float, float] | None = None,
ior_range: tuple[float, float] | None = None,
):
- from embodichain.lab.sim.utility import is_rt_enabled
-
if self.entity_cfg.uid != "default_plane" and self.entity is None:
return
@@ -700,7 +698,7 @@ def __call__(
)
randomize_plan["roughness"] = roughness
- if ior_range and is_rt_enabled():
+ if ior_range:
ior = sample_uniform(
lower=torch.tensor(ior_range[0], dtype=torch.float32),
upper=torch.tensor(ior_range[1], dtype=torch.float32),
@@ -741,3 +739,6 @@ def __call__(
random_texture_prob=random_texture_prob,
idx=i,
)
+
+ env = self._env.sim.get_env()
+ env.clean_materials()
diff --git a/embodichain/lab/gym/envs/managers/record.py b/embodichain/lab/gym/envs/managers/record.py
index 7c07ecfd..370645a3 100644
--- a/embodichain/lab/gym/envs/managers/record.py
+++ b/embodichain/lab/gym/envs/managers/record.py
@@ -80,8 +80,7 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv):
)
# Add this camera's group ID to the environment for batch rendering when RT is enabled.
- if getattr(env.sim, "is_rt_enabled", False):
- env.add_camera_group_id(self.camera.group_id)
+ env.add_camera_group_id(self.camera.group_id)
self._save_path = cfg.params.get("save_path", "./outputs/videos")
self._current_episode = 0
@@ -158,7 +157,7 @@ def __call__(
max_env_num: int = 16,
save_path: str = "./outputs/videos",
):
- self.camera.update(fetch_only=self.camera.is_rt_enabled)
+ self.camera.update(fetch_only=True)
data = self.camera.get_data()
rgb = data["color"]
@@ -199,7 +198,7 @@ def __call__(
max_env_num: int = 16,
save_path: str = "./outputs/videos",
):
- self.camera.update(fetch_only=self.camera.is_rt_enabled)
+ self.camera.update(fetch_only=True)
data = self.camera.get_data()
rgb = data["color"] # shape: (num_envs, H, W, 4)
if isinstance(rgb, torch.Tensor):
diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/action_bank.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/action_bank.py
index 20c8a2d7..1a467133 100644
--- a/embodichain/lab/gym/envs/tasks/tableware/pour_water/action_bank.py
+++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water/action_bank.py
@@ -42,7 +42,6 @@
)
from embodichain.utils import logger
-
__all__ = ["PourWaterActionBank"]
diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py
index 0a1e2033..6ed0bd56 100644
--- a/embodichain/lab/gym/utils/gym_utils.py
+++ b/embodichain/lab/gym/utils/gym_utils.py
@@ -20,7 +20,7 @@
import argparse
import gymnasium
-from typing import Dict, Any, List, Tuple, Union, Sequence
+from typing import Dict, Any, List, Tuple, Union, Sequence, Optional
from gymnasium import spaces
from copy import deepcopy
from tensordict import TensorDict
@@ -737,7 +737,7 @@ def add_env_launcher_args_to_parser(parser: argparse.ArgumentParser) -> None:
--num_envs: Number of environments to run in parallel (default: 1)
--device: Device to run the environment on (default: 'cpu')
--headless: Whether to perform the simulation in headless mode (default: False)
- --enable_rt: Whether to use RTX rendering backend for the simulation (default: False)
+ --renderer: Renderer backend to use for the simulation. Options are 'hybrid', 'fast-rt', and 'rt'. (default: 'hybrid')
--gpu_id: The GPU ID to use for the simulation (default: 0)
--gym_config: Path to gym config file (default: '')
--action_config: Path to action config file (default: None)
@@ -769,18 +769,19 @@ def add_env_launcher_args_to_parser(parser: argparse.ArgumentParser) -> None:
default=False,
action="store_true",
)
+ parser.add_argument(
+ "--renderer",
+ type=str,
+ choices=["hybrid", "fast-rt", "rt"],
+ default="hybrid",
+ help="Renderer backend to use for the simulation.",
+ )
parser.add_argument(
"--arena_space",
help="The size of the arena space.",
default=5.0,
type=float,
)
- parser.add_argument(
- "--enable_rt",
- help="Whether to use RTX rendering backend for the simulation.",
- default=False,
- action="store_true",
- )
parser.add_argument(
"--gpu_id",
help="The GPU ID to use for the simulation.",
@@ -792,7 +793,7 @@ def add_env_launcher_args_to_parser(parser: argparse.ArgumentParser) -> None:
type=str,
help="Path to gym config file.",
default="",
- required=True,
+ required=False,
)
parser.add_argument(
"--action_config", type=str, help="Path to action config file.", default=None
@@ -833,7 +834,7 @@ def merge_args_with_gym_config(args: argparse.Namespace, gym_config: dict) -> di
merged_config["num_envs"] = args.num_envs
merged_config["device"] = args.device
merged_config["headless"] = args.headless
- merged_config["enable_rt"] = args.enable_rt
+ merged_config["renderer"] = args.renderer
merged_config["gpu_id"] = args.gpu_id
merged_config["arena_space"] = args.arena_space
return merged_config
@@ -854,6 +855,7 @@ def build_env_cfg_from_args(
from embodichain.utils.utility import load_json
from embodichain.lab.gym.envs import EmbodiedEnvCfg
from embodichain.lab.sim import SimulationManagerCfg
+ from embodichain.lab.sim.cfg import RenderCfg
gym_config = load_json(args.gym_config)
gym_config = merge_args_with_gym_config(args, gym_config)
@@ -876,7 +878,7 @@ def build_env_cfg_from_args(
cfg.sim_cfg = SimulationManagerCfg(
headless=gym_config["headless"],
sim_device=gym_config["device"],
- enable_rt=gym_config["enable_rt"],
+ render_cfg=RenderCfg(renderer=gym_config["renderer"]),
gpu_id=gym_config["gpu_id"],
arena_space=gym_config["arena_space"],
)
@@ -956,12 +958,117 @@ def _init_buffer_from_space(
return rollout_buffer
+def _init_language_buffer(
+ language_cfg: dict,
+ batch_size: int,
+ max_episode_steps: int,
+ device: Union[str, torch.device] = "cpu",
+) -> Dict[str, torch.Tensor]:
+ """Initialize language buffer fields for VLA training.
+
+ Creates tensor fields for hierarchical language data storage.
+
+ Args:
+ language_cfg (dict): Language configuration dictionary.
+ batch_size (int): Number of parallel environments.
+ max_episode_steps (int): Maximum episode length.
+ device (Union[str, torch.device]): Device for tensor allocation.
+
+ Returns:
+ Dict[str, torch.Tensor]: Dictionary of language tensors.
+ """
+ # Get configuration parameters with defaults
+ hierarchy_levels = language_cfg.get(
+ "hierarchy_levels", ["task", "subtask", "primitive"]
+ )
+ max_tokens = language_cfg.get("max_tokens", 512)
+ max_instructions = language_cfg.get("max_instructions_per_level", 3)
+ pad_token_id = language_cfg.get("pad_token_id", 0)
+ mode = language_cfg.get("mode", "tokens")
+
+ language_desc = {}
+
+ # Create tensor fields for each hierarchy level
+ for level in hierarchy_levels:
+ level_key = f"{level}_level"
+
+ # Token IDs: [batch_size, max_episode_steps, max_instructions, max_tokens]
+ language_desc[f"{level_key}_tokens"] = torch.zeros(
+ (batch_size, max_episode_steps, max_instructions, max_tokens),
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # Attention mask: [batch_size, max_episode_steps, max_instructions, max_tokens]
+ language_desc[f"{level_key}_attention_mask"] = torch.zeros(
+ (batch_size, max_episode_steps, max_instructions, max_tokens),
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # Instruction count per level: [batch_size, max_episode_steps]
+ language_desc[f"{level_key}_count"] = torch.zeros(
+ (batch_size, max_episode_steps),
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # Instruction count by hierarchy level: [batch_size, max_episode_steps, 3]
+ # 3 corresponds to [task, subtask, primitive] levels
+ language_desc["instruction_counts"] = torch.zeros(
+ (batch_size, max_episode_steps, 3),
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # Change points: [batch_size, max_episode_steps, max_instructions]
+ # Timesteps where language changes within the trajectory
+ language_desc["change_points"] = torch.full(
+ (batch_size, max_episode_steps, max_instructions),
+ -1,
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # Hierarchy depth: [batch_size, max_episode_steps]
+ # Current depth of hierarchy used (1=task only, 2=task+subtask, 3=all)
+ language_desc["hierarchy_depth"] = torch.full(
+ (batch_size, max_episode_steps),
+ len(hierarchy_levels),
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # Instruction type IDs: [batch_size, max_episode_steps, max_instructions]
+ # Encoding of instruction types (e.g., 0=imperative, 1=declarative, 2=conditional)
+ language_desc["instruction_types"] = torch.zeros(
+ (batch_size, max_episode_steps, max_instructions),
+ dtype=torch.int64,
+ device=device,
+ )
+
+ # Optional: Embedding storage for mode='embeddings' or mode='hybrid'
+ if mode in ("embeddings", "hybrid"):
+ embedding_dim = language_cfg.get("embedding_dim", 768)
+ for level in hierarchy_levels:
+ level_key = f"{level}_level"
+ # Embeddings: [batch_size, max_episode_steps, max_instructions, embedding_dim]
+ language_desc[f"{level_key}_embeddings"] = torch.zeros(
+ (batch_size, max_episode_steps, max_instructions, embedding_dim),
+ dtype=torch.float32,
+ device=device,
+ )
+
+ return language_desc
+
+
def init_rollout_buffer_from_config(
config: dict,
max_episode_steps: int,
batch_size: int,
state_dim: int,
device: Union[str, torch.device] = "cpu",
+ language_cfg: Optional[dict] = None,
) -> TensorDict:
"""Initialize a rollout buffer based on the environment configuration.
@@ -970,15 +1077,19 @@ def init_rollout_buffer_from_config(
- Sensor observations: ``sensor/`` for each sensor in config
- Extra observations: Custom observations from observation functors in ``add`` mode
that have a ``shape`` specified in their ``extra`` parameter
+ - Language data: Hierarchical language descriptions for VLA training (if language_cfg is provided)
Args:
config (dict): The environment configuration dictionary.
max_episode_steps (int): The number of steps in an episode.
batch_size (int): The batch size for the rollout buffer.
state_dim (int): The dimension of the flattened state vector.
+ language_cfg (Optional[dict]): Language configuration for VLA training.
+ If provided, language fields will be added to the buffer.
Returns:
- TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards'.
+ TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards',
+ and optionally 'language' if language_cfg is provided.
"""
# TODO: Currently we use this method to pre-allocate a rollout buffer with fixed size for simplicity.
@@ -1132,4 +1243,18 @@ def init_rollout_buffer_from_config(
for obs_name, obs_tensor in extra_obs_desc.items():
assign_data_to_dict(rollout_buffer["obs"], obs_name, obs_tensor)
+ # Add language data for VLA training if language config is provided
+ if language_cfg is not None:
+ language_desc = _init_language_buffer(
+ language_cfg, batch_size, max_episode_steps, device
+ )
+ rollout_buffer["language"] = TensorDict(
+ language_desc,
+ batch_size=[batch_size, max_episode_steps],
+ device=device,
+ )
+ log_info(
+ f"[init_rollout_buffer_from_config] Language buffer added with hierarchy levels: {language_cfg.get('hierarchy_levels', ['task', 'subtask', 'primitive'])}"
+ )
+
return rollout_buffer
diff --git a/embodichain/lab/scripts/preview_asset.py b/embodichain/lab/scripts/preview_asset.py
index 472dca87..bef02faa 100644
--- a/embodichain/lab/scripts/preview_asset.py
+++ b/embodichain/lab/scripts/preview_asset.py
@@ -58,12 +58,13 @@ def build_sim_cfg(args: argparse.Namespace):
Returns:
SimulationManagerCfg: Simulation configuration.
"""
+ from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.sim.sim_manager import SimulationManagerCfg
return SimulationManagerCfg(
headless=args.headless,
- enable_rt=args.enable_rt,
sim_device=args.sim_device,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
@@ -88,9 +89,6 @@ def load_assets(sim: SimulationManager, args: argparse.Namespace):
)
from embodichain.lab.sim.shapes import MeshCfg
- # --- light -----------------------------------------------------------
- sim.set_emission_light(intensity=150)
-
asset_paths = args.asset_path
init_pos = tuple(args.init_pos)
init_rot = tuple(args.init_rot)
@@ -286,7 +284,7 @@ def cli():
"--body_type",
type=str,
choices=["dynamic", "kinematic", "static"],
- default="kinematic",
+ default="dynamic",
help="Body type for rigid objects (default: kinematic).",
)
parser.add_argument(
@@ -314,10 +312,11 @@ def cli():
help="Run without rendering window.",
)
parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing.",
+ "--renderer",
+ type=str,
+ choices=["hybrid", "fast-rt", "rt"],
+ default="hybrid",
+ help="Renderer backend (default: hybrid).",
)
parser.add_argument(
"--preview",
diff --git a/embodichain/lab/scripts/run_agent.py b/embodichain/lab/scripts/run_agent.py
index 912100ef..73c1eacd 100644
--- a/embodichain/lab/scripts/run_agent.py
+++ b/embodichain/lab/scripts/run_agent.py
@@ -27,7 +27,6 @@
from embodichain.utils.logger import log_error
from .run_env import main
-
if __name__ == "__main__":
np.set_printoptions(5, suppress=True)
torch.set_printoptions(precision=5, sci_mode=False)
diff --git a/embodichain/lab/sim/atom_actions.py b/embodichain/lab/sim/atom_actions.py
index a60a6dbc..2abefea9 100644
--- a/embodichain/lab/sim/atom_actions.py
+++ b/embodichain/lab/sim/atom_actions.py
@@ -39,7 +39,6 @@
extract_drive_calls,
)
-
"""
--------------------------------------------Atom action functions----------------------------------------------------
--------------------------------------------Atom action functions----------------------------------------------------
diff --git a/embodichain/lab/sim/atomic_actions/__init__.py b/embodichain/lab/sim/atomic_actions/__init__.py
new file mode 100644
index 00000000..cf1e60ce
--- /dev/null
+++ b/embodichain/lab/sim/atomic_actions/__init__.py
@@ -0,0 +1,67 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Atomic action abstraction layer for embodied AI motion generation.
+
+This module provides a unified interface for atomic actions like reach, grasp,
+move, etc., with support for semantic object understanding and extensible
+custom action registration.
+"""
+
+from .core import (
+ Affordance,
+ AntipodalAffordance,
+ InteractionPoints,
+ ObjectSemantics,
+ ActionCfg,
+ AtomicAction,
+)
+from .actions import (
+ MoveAction,
+ PickUpAction,
+ PlaceAction,
+ MoveActionCfg,
+ PickUpActionCfg,
+ PlaceActionCfg,
+)
+from .engine import (
+ AtomicActionEngine,
+ register_action,
+ unregister_action,
+ get_registered_actions,
+)
+
+__all__ = [
+ # Core classes
+ "Affordance",
+ "GraspPose",
+ "InteractionPoints",
+ "ObjectSemantics",
+ "ActionCfg",
+ "AtomicAction",
+ # Action implementations
+ "MoveAction",
+ "PickUpAction",
+ "PlaceAction",
+ "MoveActionCfg",
+ "PickUpActionCfg",
+ "PlaceActionCfg",
+ # Engine
+ "AtomicActionEngine",
+ "register_action",
+ "unregister_action",
+ "get_registered_actions",
+]
diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py
new file mode 100644
index 00000000..4f2698de
--- /dev/null
+++ b/embodichain/lab/sim/atomic_actions/actions.py
@@ -0,0 +1,634 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import torch
+from typing import Optional, Union, TYPE_CHECKING, Any
+
+from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType
+from embodichain.lab.sim.planners.motion_generator import MotionGenOptions
+from embodichain.lab.sim.planners.toppra_planner import ToppraPlanOptions
+from .core import AtomicAction, ObjectSemantics, AntipodalAffordance, ActionCfg
+from embodichain.utils import logger
+from embodichain.utils import configclass
+from embodichain.lab.sim.utility.action_utils import interpolate_with_distance
+import numpy as np
+
+if TYPE_CHECKING:
+ from embodichain.lab.sim.planners import MotionGenerator
+ from embodichain.lab.sim.objects import Robot
+
+
+@configclass
+class MoveActionCfg(ActionCfg):
+ name: str = "move"
+ """Name of the action, used for identification and logging."""
+
+ sample_interval: int = 50
+ """Number of waypoints to sample for the motion trajectory. Should be large enough to ensure smooth motion, but not too large to cause unnecessary computation overhead."""
+
+
+@configclass
+class GraspActionCfg(MoveActionCfg):
+ """Shared configuration for actions that involve gripper open/close motions."""
+
+ hand_open_qpos: torch.Tensor | None = None
+ """[hand_dof,] of float. Joint positions for open hand state."""
+
+ hand_close_qpos: torch.Tensor | None = None
+ """[hand_dof,] of float. Joint positions for closed hand state."""
+
+ hand_control_part: str = "hand"
+ """Name of the robot part that controls the hand joints."""
+
+ lift_height: float = 0.1
+ """Height (m) to lift the end-effector after the gripper phase."""
+
+ sample_interval: int = 80
+ """Number of waypoints for the full trajectory (approach + hand + lift/back)."""
+
+ hand_interp_steps: int = 5
+ """Number of waypoints for the gripper open/close interpolation phase."""
+
+
+class MoveAction(AtomicAction):
+ def __init__(
+ self,
+ motion_generator: MotionGenerator,
+ cfg: MoveActionCfg | None = None,
+ ):
+ """
+ Initialize the atomic action.
+ Args:
+ motion_generator: The motion generator instance to use for planning.
+ cfg: Configuration for the action.
+ """
+ super().__init__(
+ motion_generator, cfg=cfg if cfg is not None else MoveActionCfg()
+ )
+
+ self.n_envs = self.robot.get_qpos().shape[0]
+ self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part)
+ self.dof = len(self.arm_joint_ids)
+
+ def _resolve_pose_target(
+ self,
+ target: Union[ObjectSemantics, torch.Tensor],
+ *,
+ action_name: str,
+ ) -> tuple[bool, torch.Tensor]:
+ """Resolve a pose target into a batched homogeneous transform tensor."""
+ if isinstance(target, ObjectSemantics):
+ logger.log_error(
+ f"{action_name} currently does not support ObjectSemantics target. "
+ f"Please provide target pose as torch.Tensor of shape (4, 4) or "
+ f"(n_envs, 4, 4)",
+ NotImplementedError,
+ )
+ if not isinstance(target, torch.Tensor):
+ logger.log_error(
+ "Target must be either ObjectSemantics or torch.Tensor of shape "
+ f"(4, 4) or ({self.n_envs}, 4, 4)",
+ TypeError,
+ )
+
+ if target.shape == (4, 4):
+ target = target.unsqueeze(0).repeat(self.n_envs, 1, 1)
+ if target.shape != (self.n_envs, 4, 4):
+ logger.log_error(
+ f"Target tensor must have shape (4, 4) or ({self.n_envs}, 4, 4), but got {target.shape}",
+ ValueError,
+ )
+ return True, target
+
+ def _resolve_start_qpos(
+ self,
+ start_qpos: Optional[torch.Tensor],
+ arm_dof: Optional[int] = None,
+ ) -> torch.Tensor:
+ """Resolve planning start joint positions into batched arm joint positions."""
+ arm_dof = self.dof if arm_dof is None else arm_dof
+ if start_qpos is None:
+ start_qpos = self.robot.get_qpos(name=self.cfg.control_part)
+ if start_qpos.shape == (arm_dof,):
+ start_qpos = start_qpos.unsqueeze(0).repeat(self.n_envs, 1)
+ if start_qpos.shape != (self.n_envs, arm_dof):
+ logger.log_error(
+ f"start_qpos must have shape ({self.n_envs}, {arm_dof}), but got {start_qpos.shape}",
+ ValueError,
+ )
+ return start_qpos
+
+ def _compute_three_phase_waypoints(
+ self,
+ hand_interp_steps: int,
+ *,
+ first_phase_name: str,
+ third_phase_name: str,
+ first_phase_ratio: float = 0.6,
+ ) -> tuple[int, int, int]:
+ """Split total sample interval into motion, hand interpolation, and motion phases."""
+ first_phase_waypoint = int(
+ np.round(self.cfg.sample_interval - hand_interp_steps) * first_phase_ratio
+ )
+ if first_phase_waypoint < 2:
+ logger.log_error(
+ f"Not enough waypoints for {first_phase_name} trajectory. "
+ "Please increase sample_interval or decrease hand_interp_steps.",
+ ValueError,
+ )
+ second_phase_waypoint = hand_interp_steps
+ third_phase_waypoint = (
+ self.cfg.sample_interval - first_phase_waypoint - second_phase_waypoint
+ )
+ if third_phase_waypoint < 2:
+ logger.log_error(
+ f"Not enough waypoints for {third_phase_name} trajectory. "
+ "Please increase sample_interval or decrease hand_interp_steps.",
+ ValueError,
+ )
+ return first_phase_waypoint, second_phase_waypoint, third_phase_waypoint
+
+ def _build_motion_gen_options(
+ self,
+ start_qpos: torch.Tensor,
+ sample_interval: int,
+ ) -> MotionGenOptions:
+ """Build default motion generation options for an atomic action."""
+ return MotionGenOptions(
+ start_qpos=start_qpos[0],
+ control_part=self.cfg.control_part,
+ is_interpolate=True,
+ is_linear=False,
+ interpolate_position_step=0.001,
+ plan_opts=ToppraPlanOptions(
+ sample_interval=sample_interval,
+ ),
+ )
+
+ def _plan_arm_trajectory(
+ self,
+ target_states_list: list[list[PlanState]],
+ start_qpos: torch.Tensor,
+ n_waypoints: int,
+ arm_dof: Optional[int] = None,
+ ) -> tuple[bool, torch.Tensor]:
+ """Plan batched arm trajectories for all environments."""
+ arm_dof = self.dof if arm_dof is None else arm_dof
+
+ n_state = len(target_states_list[0])
+ xpos_traj = torch.zeros(
+ size=(self.n_envs, n_state, 4, 4), dtype=torch.float32, device=self.device
+ )
+ for i, target_states in enumerate(target_states_list):
+ for j, target_state in enumerate(target_states):
+ # [env_i, state_j, 4, 4]
+ xpos_traj[i, j] = target_state.xpos
+
+ trajectory = torch.zeros(
+ size=(self.n_envs, n_state, arm_dof),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ qpos_seed = start_qpos
+ for j in range(n_state):
+ is_success, qpos = self.robot.compute_ik(
+ pose=xpos_traj[:, j], name=self.cfg.control_part, joint_seed=qpos_seed
+ )
+ if not is_success:
+ logger.log_warning(
+ f"Failed to compute IK for target state {j} in some environments. "
+ "The resulting trajectory may be invalid."
+ )
+ return False, trajectory
+ else:
+ trajectory[:, j] = qpos
+ qpos_seed = qpos
+ trajectory = torch.concatenate([start_qpos.unsqueeze(1), trajectory], dim=1)
+ interp_traj = interpolate_with_distance(
+ trajectory=trajectory, interp_num=n_waypoints, device=self.device
+ )
+ return True, interp_traj
+
+ def _interpolate_hand_qpos(
+ self,
+ start_hand_qpos: torch.Tensor,
+ end_hand_qpos: torch.Tensor,
+ n_waypoints: int,
+ ) -> torch.Tensor:
+ """Interpolate hand joint positions between two gripper states."""
+ weights = torch.linspace(0, 1, steps=n_waypoints, device=self.device)
+ hand_qpos_list = [
+ torch.lerp(start_hand_qpos, end_hand_qpos, weight) for weight in weights
+ ]
+ return torch.stack(hand_qpos_list, dim=0)
+
+ def execute(
+ self,
+ target: Union[ObjectSemantics, torch.Tensor],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[bool, torch.Tensor, list[float]]:
+ """execute pick up action
+
+ Args:
+ target (ObjectSemantics): object semantics containing grasp affordance and entity information
+ start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
+
+ Returns:
+ tuple[bool, torch.Tensor, list[float]]:
+ is_success,
+ trajectory of shape (n_envs, n_waypoints, dof),
+ joint_ids corresponding to trajectory
+ """
+ is_success, move_xpos = self._resolve_pose_target(
+ target, action_name=self.__class__.__name__
+ )
+ start_qpos = self._resolve_start_qpos(start_qpos)
+
+ # TODO: warning and fallback if no valid grasp pose found
+ if not is_success:
+ logger.log_warning(
+ "Failed to resolve grasp pose, using default approach pose"
+ )
+ return False, torch.empty(0), self.arm_joint_ids
+
+ target_states_list = [
+ [
+ PlanState(xpos=move_xpos[i], move_type=MoveType.EEF_MOVE),
+ ]
+ for i in range(self.n_envs)
+ ]
+ is_plan_success, trajectory = self._plan_arm_trajectory(
+ target_states_list, start_qpos, self.cfg.sample_interval
+ )
+ return is_plan_success, trajectory, self.arm_joint_ids
+
+ def validate(self, target, start_qpos=None, **kwargs):
+ # TODO: implement proper validation logic for pick up action
+ return True
+
+
+@configclass
+class PickUpActionCfg(GraspActionCfg):
+ name: str = "pick_up"
+ """Name of the action, used for identification and logging."""
+
+ pre_grasp_distance: float = 0.15
+ """Distance to offset back from the grasp pose along the approach direction to get
+ the pre-grasp pose. Should be large enough to avoid collision during approach."""
+
+ approach_direction: torch.Tensor = torch.tensor([0, 0, -1], dtype=torch.float32)
+ """Direction from which the gripper approaches the object for grasping, expressed
+ in the object local frame. Default [0, 0, -1] means approaching from above."""
+
+
+class PickUpAction(MoveAction):
+ def __init__(
+ self,
+ motion_generator: MotionGenerator,
+ cfg: PickUpActionCfg | None = None,
+ ):
+ """
+ Initialize the atomic action.
+ Args:
+ motion_generator: The motion generator instance to use for planning.
+ cfg: Configuration for the action.
+ """
+ super().__init__(
+ motion_generator, cfg=cfg if cfg is not None else PickUpActionCfg()
+ )
+ self.cfg = cfg
+ self.approach_direction = self.cfg.approach_direction.to(self.device)
+ if self.cfg.hand_open_qpos is None:
+ logger.log_error("hand_open_qpos must be specified in PickUpActionCfg")
+ if self.cfg.hand_close_qpos is None:
+ logger.log_error("hand_close_qpos must be specified in PickUpActionCfg")
+ self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
+ self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)
+
+ self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
+ self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
+ self.arm_dof = len(self.arm_joint_ids)
+ self.dof = len(self.joint_ids)
+
+ def execute(
+ self,
+ target: Union[ObjectSemantics, torch.Tensor],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[bool, torch.Tensor, list[float]]:
+ """execute pick up action
+
+ Args:
+ target (Union[ObjectSemantics, torch.Tensor]): target object semantics or target pose for grasping
+ start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
+
+ Returns:
+ tuple[bool, torch.Tensor, list[float]]:
+ is_success,
+ trajectory of shape (n_envs, n_waypoints, dof),
+ joint_ids corresponding to trajectory
+ """
+
+ # Resolve grasp pose
+ if isinstance(target, ObjectSemantics):
+ is_success, grasp_xpos, open_length = self._resolve_grasp_pose(target)
+ else:
+ is_success, grasp_xpos = self._resolve_pose_target(
+ target, action_name=self.__class__.__name__
+ )
+
+ # TODO: warning and fallback if no valid grasp pose found
+ if not is_success:
+ logger.log_warning(
+ "Failed to resolve grasp pose, using default approach pose"
+ )
+ return False, torch.empty(0), self.joint_ids
+
+ # Compute pre-grasp pose
+ # TODO: only for parallel gripper, approach in negative grasp z direction
+ grasp_z = grasp_xpos[:, :3, 2]
+ pre_grasp_xpos = self._apply_offset(
+ pose=grasp_xpos,
+ offset=-grasp_z * self.cfg.pre_grasp_distance,
+ )
+ # Compute lift pose
+ start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof)
+
+ # compute waypoint number for each phase
+ n_approach_waypoint, n_close_waypoint, n_lift_waypoint = (
+ self._compute_three_phase_waypoints(
+ self.cfg.hand_interp_steps,
+ first_phase_name="approach",
+ third_phase_name="lift",
+ )
+ )
+
+ # get pick trajectory
+ target_states_list = [
+ [
+ PlanState(xpos=pre_grasp_xpos[i], move_type=MoveType.EEF_MOVE),
+ PlanState(xpos=grasp_xpos[i], move_type=MoveType.EEF_MOVE),
+ ]
+ for i in range(self.n_envs)
+ ]
+ pick_trajectory = torch.zeros(
+ size=(self.n_envs, n_approach_waypoint, self.dof),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ is_success, plan_traj = self._plan_arm_trajectory(
+ target_states_list,
+ start_qpos,
+ n_approach_waypoint,
+ self.arm_dof,
+ )
+ if not is_success:
+ logger.log_warning("Failed to plan approach trajectory.")
+ return False, pick_trajectory, self.joint_ids
+ pick_trajectory[:, :, : self.arm_dof] = plan_traj
+ # Padding hand open qpos to pick trajectory
+ pick_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos
+
+ # get hand closing trajectory
+ grasp_qpos = pick_trajectory[
+ :, -1, : self.arm_dof
+ ] # Assuming the last point of pick trajectory is the grasp pose
+ hand_close_path = self._interpolate_hand_qpos(
+ self.hand_open_qpos,
+ self.hand_close_qpos,
+ n_close_waypoint,
+ )
+ hand_close_trajectory = torch.zeros(
+ size=(self.n_envs, n_close_waypoint, self.dof),
+ device=self.device,
+ )
+ hand_close_trajectory[:, :, : self.arm_dof] = grasp_qpos
+ hand_close_trajectory[:, :, self.arm_dof :] = hand_close_path
+
+ # get lift trajectory
+ lift_trajectory = torch.zeros(
+ size=(self.n_envs, n_lift_waypoint, self.dof),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ # lift_xpos = self._compute_lift_xpos(grasp_xpos)
+ lift_xpos = self._apply_offset(
+ pose=grasp_xpos,
+ offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height,
+ )
+ target_states_list = [
+ [
+ PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
+ ]
+ for i in range(self.n_envs)
+ ]
+ is_success, plan_traj = self._plan_arm_trajectory(
+ target_states_list,
+ grasp_qpos,
+ n_lift_waypoint,
+ self.arm_dof,
+ )
+ if not is_success:
+ logger.log_warning("Failed to plan lift trajectory.")
+ return False, lift_trajectory, self.joint_ids
+ lift_trajectory[:, :, : self.arm_dof] = plan_traj
+ # padding hand close qpos to lift trajectory
+ lift_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos
+
+ # concatenate trajectories
+ trajectory = torch.cat(
+ [pick_trajectory, hand_close_trajectory, lift_trajectory], dim=1
+ )
+ return True, trajectory, self.joint_ids
+
+ def _resolve_grasp_pose(
+ self, semantics: ObjectSemantics
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if not isinstance(semantics.affordance, AntipodalAffordance):
+ logger.log_error(
+ "Grasp pose affordance must be of type AntipodalAffordance"
+ )
+ if semantics.entity is None:
+ logger.log_error(
+ "ObjectSemantics must be associated with an entity to get object pose"
+ )
+ obj_poses = semantics.entity.get_local_pose(to_matrix=True)
+
+ is_success, grasp_xpos, open_length = semantics.affordance.get_best_grasp_poses(
+ obj_poses=obj_poses, approach_direction=self.approach_direction
+ )
+ return is_success, grasp_xpos, open_length
+
+ def validate(self, target, start_qpos=None, **kwargs):
+ # TODO: implement proper validation logic for pick up action
+ return True
+
+
+@configclass
+class PlaceActionCfg(GraspActionCfg):
+ name: str = "place"
+ """Name of the action, used for identification and logging."""
+
+
+class PlaceAction(MoveAction):
+ def __init__(
+ self,
+ motion_generator: MotionGenerator,
+ cfg: PlaceActionCfg | None = None,
+ ):
+ """
+ Initialize the atomic action.
+ Args:
+ motion_generator: The motion generator instance to use for planning.
+ cfg: Configuration for the action.
+ """
+ super().__init__(
+ motion_generator, cfg=cfg if cfg is not None else PlaceActionCfg()
+ )
+ self.cfg = cfg
+ if self.cfg.hand_open_qpos is None:
+ logger.log_error("hand_open_qpos must be specified in PlaceActionCfg")
+ if self.cfg.hand_close_qpos is None:
+ logger.log_error("hand_close_qpos must be specified in PlaceActionCfg")
+ self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
+ self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)
+
+ self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
+ self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
+ self.arm_dof = len(self.arm_joint_ids)
+ self.dof = len(self.joint_ids)
+
+ def execute(
+ self,
+ target: Union[ObjectSemantics, torch.Tensor],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[bool, torch.Tensor, list[float]]:
+ """execute pick up action
+
+ Args:
+ target (ObjectSemantics): object semantics containing grasp affordance and entity information
+ start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
+
+ Returns:
+ tuple[bool, torch.Tensor, list[float]]:
+ is_success,
+ trajectory of shape (n_envs, n_waypoints, dof),
+ joint_ids corresponding to trajectory
+ """
+ is_success, place_xpos = self._resolve_pose_target(
+ target, action_name=self.__class__.__name__
+ )
+ start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof)
+
+ # TODO: warning and fallback if no valid grasp pose found
+ if not is_success:
+ logger.log_warning(
+ "Failed to resolve grasp pose, using default approach pose"
+ )
+ return False, torch.empty(0), self.joint_ids
+
+ # compute waypoint number for each phase
+ n_down_waypoint, n_open_waypoint, n_lift_waypoint = (
+ self._compute_three_phase_waypoints(
+ self.cfg.hand_interp_steps,
+ first_phase_name="approach",
+ third_phase_name="lift",
+ )
+ )
+
+ down_trajectory = torch.zeros(
+ size=(self.n_envs, n_down_waypoint, self.dof),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ lift_xpos = self._apply_offset(
+ pose=place_xpos,
+ offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height,
+ )
+ target_states_list = [
+ [
+ PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
+ PlanState(xpos=place_xpos[i], move_type=MoveType.EEF_MOVE),
+ ]
+ for i in range(self.n_envs)
+ ]
+ is_success, plan_traj = self._plan_arm_trajectory(
+ target_states_list,
+ start_qpos,
+ n_down_waypoint,
+ self.arm_dof,
+ )
+ if not is_success:
+ logger.log_warning("Failed to plan down trajectory.")
+ return False, down_trajectory, self.joint_ids
+ down_trajectory[:, :, : self.arm_dof] = plan_traj
+ # Padding hand open qpos to pick trajectory
+ down_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos
+
+ # get hand closing trajectory
+ reach_qpos = down_trajectory[
+ :, -1, : self.arm_dof
+ ] # Assuming the last point of pick trajectory is the grasp pose
+ hand_open_path = self._interpolate_hand_qpos(
+ self.hand_close_qpos,
+ self.hand_open_qpos,
+ n_open_waypoint,
+ )
+ hand_open_trajectory = torch.zeros(
+ size=(self.n_envs, n_open_waypoint, self.dof),
+ device=self.device,
+ )
+ hand_open_trajectory[:, :, : self.arm_dof] = reach_qpos
+ hand_open_trajectory[:, :, self.arm_dof :] = hand_open_path
+
+ # get lift trajectory
+ back_trajectory = torch.zeros(
+ size=(self.n_envs, n_lift_waypoint, self.dof),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ target_states_list = [
+ [
+ PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
+ ]
+ for i in range(self.n_envs)
+ ]
+ is_success, plan_traj = self._plan_arm_trajectory(
+ target_states_list,
+ reach_qpos,
+ n_lift_waypoint,
+ self.arm_dof,
+ )
+ if not is_success:
+ logger.log_warning("Failed to plan back trajectory.")
+ return False, back_trajectory, self.joint_ids
+ back_trajectory[:, :, : self.arm_dof] = plan_traj
+ # padding hand open qpos to back trajectory
+ back_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos
+
+ # concatenate trajectories
+ trajectory = torch.cat(
+ [down_trajectory, hand_open_trajectory, back_trajectory], dim=1
+ )
+ return True, trajectory, self.joint_ids
+
+ def validate(self, target, start_qpos=None, **kwargs):
+ # TODO: implement proper validation logic for pick up action
+ return True
diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py
new file mode 100644
index 00000000..08a22fc5
--- /dev/null
+++ b/embodichain/lab/sim/atomic_actions/core.py
@@ -0,0 +1,468 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import torch
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+
+from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType
+from embodichain.utils import configclass
+
+from embodichain.toolkits.graspkit.pg_grasp import (
+ GraspGenerator,
+ GraspGeneratorCfg,
+)
+from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import (
+ GripperCollisionCfg,
+)
+from embodichain.lab.sim.common import BatchEntity
+from embodichain.utils import logger
+
+if TYPE_CHECKING:
+ from embodichain.lab.sim.planners import MotionGenerator, MotionGenOptions
+ from embodichain.lab.sim.objects import Robot
+
+
+# =============================================================================
+# Affordance Classes
+# =============================================================================
+
+
+@dataclass
+class Affordance:
+ """Base class for affordance data.
+
+ Affordance represents interaction possibilities for an object.
+ This is the base class for specific affordance types.
+ """
+
+ object_label: str = ""
+ """Label of the object this affordance belongs to."""
+
+ geometry: Dict[str, Any] = field(default_factory=dict)
+ """Geometry dictionary shared with ObjectSemantics.
+
+ The mesh payload is expected to be stored in:
+ - ``mesh_vertices``: torch.Tensor with shape [N, 3]
+ - ``mesh_triangles``: torch.Tensor with shape [M, 3]
+ """
+
+ custom_config: Dict[str, Any] = field(default_factory=dict)
+ """User-defined configuration payload for affordance creation and usage."""
+
+ @property
+ def mesh_vertices(self) -> torch.Tensor | None:
+ """Get mesh vertices from geometry.
+
+ Returns:
+ Mesh vertices tensor [N, 3], or None if unavailable.
+
+ Raises:
+ TypeError: If ``mesh_vertices`` exists but is not a torch tensor.
+ """
+ vertices = self.geometry.get("mesh_vertices")
+ if vertices is None:
+ return None
+ if not isinstance(vertices, torch.Tensor):
+ raise TypeError("geometry['mesh_vertices'] must be a torch.Tensor")
+ return vertices
+
+ @property
+ def mesh_triangles(self) -> torch.Tensor | None:
+ """Get mesh triangles from geometry.
+
+ Returns:
+ Mesh triangle index tensor [M, 3], or None if unavailable.
+
+ Raises:
+ TypeError: If ``mesh_triangles`` exists but is not a torch tensor.
+ """
+ triangles = self.geometry.get("mesh_triangles")
+ if triangles is None:
+ return None
+ if not isinstance(triangles, torch.Tensor):
+ raise TypeError("geometry['mesh_triangles'] must be a torch.Tensor")
+ return triangles
+
+ def set_custom_config(self, key: str, value: Any) -> None:
+ """Set a custom affordance configuration value."""
+ self.custom_config[key] = value
+
+ def get_custom_config(self, key: str, default: Any = None) -> Any:
+ """Get a custom affordance configuration value."""
+ return self.custom_config.get(key, default)
+
+ def get_batch_size(self) -> int:
+ """Return the batch size of this affordance data."""
+ return 1
+
+
+@dataclass
+class AntipodalAffordance(Affordance):
+ generator: GraspGenerator | None = None
+ """Grasp generator instance, initialized lazily when needed."""
+
+ force_reannotate: bool = False
+ """Whether to force re-annotation of grasp generator on each access."""
+
+ is_draw_grasp_xpos: bool = False
+ """Whether to visualize grasp poses in the simulator."""
+
+ def _init_generator(self):
+ if (
+ self.geometry.get("mesh_vertices", None) is None
+ or self.geometry.get("mesh_triangles", None) is None
+ ):
+ logger.log_error(
+ "Mesh vertices and triangles must be provided in geometry to initialize AntipodalAffordance."
+ )
+ self.generator = GraspGenerator(
+ vertices=self.geometry.get("mesh_vertices"),
+ triangles=self.geometry.get("mesh_triangles"),
+ cfg=self.custom_config.get("generator_cfg", None),
+ gripper_collision_cfg=self.custom_config.get("gripper_collision_cfg", None),
+ )
+ if self.force_reannotate:
+ self.generator.annotate()
+ else:
+ if self.generator._hit_point_pairs is None:
+ self.generator.annotate()
+
+ def get_best_grasp_poses(
+ self,
+ obj_poses: torch.Tensor,
+ approach_direction: torch.Tensor = torch.tensor(
+ [0, 0, -1], dtype=torch.float32
+ ),
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if self.generator is None:
+ self._init_generator()
+
+ grasp_xpos_list = []
+ is_success_list = []
+ open_length_list = []
+ for i, obj_pose in enumerate(obj_poses):
+ is_success, grasp_xpos, open_length = self.generator.get_grasp_poses(
+ obj_pose, approach_direction
+ )
+ if is_success:
+ grasp_xpos_list.append(grasp_xpos.unsqueeze(0))
+ else:
+ logger.log_warning(f"No valid grasp pose found for {i}-th object.")
+ grasp_xpos_list.append(
+ torch.eye(
+ 4, dtype=torch.float32, device=self.generator.device
+ ).unsqueeze(0)
+ ) # Default to identity pose if no grasp found
+ is_success_list.append(is_success)
+ open_length_list.append(open_length)
+ is_success = torch.tensor(
+ is_success_list, dtype=torch.bool, device=self.generator.device
+ )
+ grasp_xpos = torch.concatenate(grasp_xpos_list, dim=0) # [B, 4, 4]
+ open_length = torch.tensor(
+ open_length_list, dtype=torch.float32, device=self.generator.device
+ )
+ if self.is_draw_grasp_xpos:
+ self._draw_grasp_xpos(grasp_xpos, open_length)
+ return is_success, grasp_xpos, open_length
+
+ def _draw_grasp_xpos(self, grasp_xpos: torch.Tensor, open_length: torch.Tensor):
+ sim = SimulationManager.get_instance()
+ axis_xpos = []
+ for i in range(grasp_xpos.shape[0]):
+ axis_xpos.append(grasp_xpos[i].to("cpu").numpy())
+ sim.draw_marker(
+ cfg=MarkerCfg(
+ name="grasp_xpos",
+ axis_xpos=axis_xpos,
+ axis_len=0.05,
+ )
+ )
+
+
+@dataclass
+class InteractionPoints(Affordance):
+ """Interaction points affordance containing a batch of 3D positions.
+
+ Interaction points define specific locations on an object surface
+ that can be used for contact-based interactions (pushing, poking,
+ touching) rather than full grasping.
+ """
+
+ points: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 3))
+ """Batch of 3D interaction points with shape [B, 3].
+
+ Each point is a 3D coordinate in the object's local coordinate frame.
+ """
+
+ normals: torch.Tensor | None = None
+ """Optional surface normals at each interaction point with shape [B, 3].
+
+ Normals indicate the surface orientation at each point,
+ useful for determining approach directions.
+ """
+
+ point_types: List[str] = field(default_factory=list)
+ """Optional labels for each point's interaction type.
+
+ Examples: "push", "poke", "touch", "pinch"
+ """
+
+ def get_points_by_type(self, point_type: str) -> torch.Tensor | None:
+ """Get points by their interaction type.
+
+ Args:
+ point_type: Type of interaction (e.g., "push", "poke")
+
+ Returns:
+ Tensor of points if found, None otherwise
+ """
+ if point_type in self.point_types:
+ indices = [i for i, t in enumerate(self.point_types) if t == point_type]
+ return self.points[indices]
+ return None
+
+ def get_batch_size(self) -> int:
+ """Return the number of interaction points in this affordance."""
+ return self.points.shape[0]
+
+ def get_approach_direction(self, point_idx: int) -> torch.Tensor:
+ """Get recommended approach direction for a given point.
+
+ Args:
+ point_idx: Index of the point
+
+ Returns:
+ 3D approach direction vector (normalized)
+ """
+ if self.normals is not None:
+ # Approach from the opposite direction of the surface normal
+ return -self.normals[point_idx]
+ # Default: approach from positive z
+ return torch.tensor(
+ [0, 0, 1], dtype=self.points.dtype, device=self.points.device
+ )
+
+
+# =============================================================================
+# ObjectSemantics
+# =============================================================================
+
+
+@dataclass
+class ObjectSemantics:
+ """Semantic information about interaction target.
+
+ This class encapsulates all semantic and geometric information about
+ an object needed for intelligent interaction planning.
+ """
+
+ affordance: Affordance
+ """Affordance data (GraspPose, InteractionPoints, etc.)."""
+
+ geometry: Dict[str, Any]
+ """Geometric information including bounding box, mesh data."""
+
+ properties: Dict[str, Any] = field(default_factory=dict)
+ """Physical properties: mass, friction, etc."""
+
+ label: str = "none"
+ """Object category label (e.g., 'apple', 'bottle')."""
+
+ entity: BatchEntity | None = None
+ """Optional reference to the underlying simulation entity representing this object."""
+
+ def __post_init__(self) -> None:
+ """Bind affordance metadata to this semantic object.
+
+ The affordance shares the same geometry dict instance as
+ ``ObjectSemantics.geometry`` so mesh tensors are authored in one place.
+ """
+ self.affordance.object_label = self.label
+ self.affordance.geometry = self.geometry
+
+
+# =============================================================================
+# ActionCfg and AtomicAction
+# =============================================================================
+
+
+@configclass
+class ActionCfg:
+ """Configuration for atomic actions."""
+
+ name: str = "default"
+ """Name of the action, used for identification and logging."""
+
+ control_part: str = "arm"
+ """Control part name for the action."""
+
+ interpolation_type: str = "linear"
+ """Interpolation type: 'linear', 'cubic'."""
+
+ velocity_limit: Optional[float] = None
+ """Optional velocity limit for the motion."""
+
+ acceleration_limit: Optional[float] = None
+ """Optional acceleration limit for the motion."""
+
+
+class AtomicAction(ABC):
+ """Abstract base class for atomic actions.
+
+ All atomic actions use PlanResult from embodichain.lab.sim.planners
+ as the return type for execute() method, ensuring consistency with
+ the existing motion planning infrastructure.
+ """
+
+ def __init__(
+ self,
+ motion_generator: MotionGenerator,
+ cfg: ActionCfg = ActionCfg(),
+ ):
+ """
+ Initialize the atomic action.
+ Args:
+ motion_generator: The motion generator instance to use for planning.
+ cfg: Configuration for the action.
+ """
+ self.motion_generator = motion_generator
+ self.cfg = cfg
+ self.robot = motion_generator.robot
+ self.control_part = cfg.control_part
+ self.device = self.robot.device
+
+ @abstractmethod
+ def execute(
+ self,
+ target: Union[torch.Tensor, ObjectSemantics],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[bool, torch.Tensor, list[float]]:
+ """execute pick up action
+
+ Args:
+ target (ObjectSemantics): object semantics containing grasp affordance and entity information
+ start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
+
+ Returns:
+ tuple[bool, torch.Tensor, list[float]]:
+ is_success,
+ trajectory of shape (n_envs, n_waypoints, dof),
+ joint_ids corresponding to trajectory
+ """
+
+ @abstractmethod
+ def validate(
+ self,
+ target: Union[torch.Tensor, ObjectSemantics],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> bool:
+ """Validate if the action is feasible without executing.
+
+ This method performs a quick feasibility check (e.g., IK solvability)
+ without generating a full trajectory.
+
+ Returns:
+ True if action appears feasible, False otherwise
+ """
+ pass
+
+ def _ik_solve(
+ self, target_pose: torch.Tensor, qpos_seed: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Solve IK for target pose.
+
+ Args:
+ target_pose: Target pose [4, 4]
+ qpos_seed: Seed configuration [DOF]
+
+ Returns:
+ Joint configuration [DOF]
+
+ Raises:
+ RuntimeError: If IK fails to find a solution
+ """
+ if qpos_seed is None:
+ qpos_seed = self.robot.get_qpos()
+
+ success, qpos = self.robot.compute_ik(
+ pose=target_pose.unsqueeze(0),
+ qpos_seed=qpos_seed.unsqueeze(0),
+ name=self.control_part,
+ )
+
+ if not success.all():
+ raise RuntimeError(f"IK failed for target pose: {target_pose}")
+
+ return qpos.squeeze(0)
+
+ def _fk_compute(self, qpos: torch.Tensor) -> torch.Tensor:
+ """Compute forward kinematics.
+
+ Args:
+ qpos: Joint configuration [DOF] or [B, DOF]
+
+ Returns:
+ End-effector pose [4, 4] or [B, 4, 4]
+ """
+ if qpos.dim() == 1:
+ qpos = qpos.unsqueeze(0)
+
+ xpos = self.robot.compute_fk(
+ qpos=qpos,
+ name=self.control_part,
+ to_matrix=True,
+ )
+
+ return xpos.squeeze(0) if xpos.shape[0] == 1 else xpos
+
+ def _apply_offset(self, pose: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
+ """Apply offset to pose in local frame.
+
+ Args:
+ pose: Base pose [N, 4, 4]
+ offset: Offset in local frame [N, 3] or [3]
+
+ Returns:
+ Pose with offset applied [N, 4, 4]
+ """
+ if not len(pose.shape) == 3 or pose.shape[1:] != (4, 4):
+ logger.log_error("pose must have shape [N, 4, 4]")
+ if len(offset.shape) == 1:
+ offset = offset.unsqueeze(0)
+ if not len(offset.shape) == 2 or offset.shape[1] != 3:
+ logger.log_error("offset must have shape [N, 3] or [3]")
+ result = pose.clone()
+ result[:, :3, 3] += offset
+ return result
+
+ def plan_trajectory(
+ self,
+ target_states: List[PlanState],
+ options: Optional["MotionGenOptions"] = None,
+ ) -> "PlanResult":
+ """Plan trajectory using motion generator."""
+ from embodichain.lab.sim.planners import MotionGenOptions
+
+ if options is None:
+ options = MotionGenOptions(control_part=self.control_part)
+ return self.motion_generator.generate(target_states, options)
diff --git a/embodichain/lab/sim/atomic_actions/engine.py b/embodichain/lab/sim/atomic_actions/engine.py
new file mode 100644
index 00000000..15b868a8
--- /dev/null
+++ b/embodichain/lab/sim/atomic_actions/engine.py
@@ -0,0 +1,340 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import torch
+from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING
+
+from embodichain.lab.sim.planners import PlanResult
+from embodichain.utils import logger
+from .core import AtomicAction, ObjectSemantics, ActionCfg
+
+if TYPE_CHECKING:
+ from embodichain.lab.sim.planners import MotionGenerator
+ from embodichain.lab.sim.objects import Robot
+
+
+# =============================================================================
+# Global Action Registry
+# =============================================================================
+
+_global_action_registry: Dict[str, Type[AtomicAction]] = {}
+_global_action_configs: Dict[str, Type[ActionCfg]] = {}
+
+
+def register_action(
+ name: str,
+ action_class: Type[AtomicAction],
+ config_class: Optional[Type[ActionCfg]] = None,
+) -> None:
+ """Register a custom atomic action class globally.
+
+ This function allows registration of custom action types that can then
+ be instantiated by the AtomicActionEngine.
+
+ Args:
+ name: Unique identifier for the action type
+ action_class: The AtomicAction subclass to register
+ config_class: Optional configuration class for the action
+
+ Example:
+ >>> class MyCustomAction(AtomicAction):
+ ... def execute(self, target, **kwargs):
+ ... # Implementation
+ ... pass
+ ... def validate(self, target, **kwargs):
+ ... return True
+ >>> register_action("my_custom", MyCustomAction)
+ """
+ _global_action_registry[name] = action_class
+ if config_class is not None:
+ _global_action_configs[name] = config_class
+
+
+def unregister_action(name: str) -> None:
+ """Unregister an action type.
+
+ Args:
+ name: The action type identifier to remove
+ """
+ _global_action_registry.pop(name, None)
+ _global_action_configs.pop(name, None)
+
+
+def get_registered_actions() -> Dict[str, Type[AtomicAction]]:
+ """Get all registered action types.
+
+ Returns:
+ Dictionary mapping action names to their classes
+ """
+ return _global_action_registry.copy()
+
+
+# =============================================================================
+# Semantic Analyzer
+# =============================================================================
+
+
+class SemanticAnalyzer:
+ """Analyzes objects and provides ObjectSemantics for atomic actions."""
+
+ def __init__(self):
+ self._object_cache: Dict[str, ObjectSemantics] = {}
+
+ def analyze(
+ self,
+ label: str,
+ geometry: Optional[Dict[str, Any]] = None,
+ custom_config: Optional[Dict[str, Any]] = None,
+ use_cache: bool = True,
+ ) -> ObjectSemantics:
+ """Analyze object by label and return ObjectSemantics.
+
+ This is a placeholder implementation that should be extended
+ with actual object detection and affordance computation.
+
+ Args:
+ label: Object category label (e.g., "apple", "bottle")
+ geometry: Optional geometry payload. Can include mesh tensors:
+ ``mesh_vertices`` [N, 3] and ``mesh_triangles`` [M, 3].
+ custom_config: Optional user-defined affordance configuration.
+ use_cache: Whether to use cached semantics when available.
+
+ Returns:
+ ObjectSemantics containing affordance data
+ """
+ # Only use cache for default analyze path
+ if (
+ use_cache
+ and geometry is None
+ and custom_config is None
+ and label in self._object_cache
+ ):
+ return self._object_cache[label]
+
+ # Create default semantics (placeholder implementation)
+ from .core import AntipodalAffordance
+
+ # Generate default grasp poses based on object type
+ default_poses = torch.eye(4).unsqueeze(0)
+ default_poses[0, 2, 3] = 0.1 # Default offset
+
+ default_geometry: Dict[str, Any] = {"bounding_box": [0.1, 0.1, 0.1]}
+ if geometry is not None:
+ default_geometry.update(geometry)
+
+ grasp_affordance = AntipodalAffordance(
+ object_label=label,
+ custom_config=custom_config or {},
+ )
+
+ semantics = ObjectSemantics(
+ label=label,
+ affordance=grasp_affordance,
+ geometry=default_geometry,
+ properties={"mass": 1.0, "friction": 0.5},
+ )
+
+ # Cache only default path
+ if use_cache and geometry is None and custom_config is None:
+ self._object_cache[label] = semantics
+ return semantics
+
+ def clear_cache(self) -> None:
+ """Clear the object semantics cache."""
+ self._object_cache.clear()
+
+
+# =============================================================================
+# Atomic Action Engine
+# =============================================================================
+
+
+class AtomicActionEngine:
+ """Central engine for managing and executing atomic actions."""
+
+ def __init__(
+ self,
+ motion_generator: "MotionGenerator",
+ actions_cfg_list: Optional[List[ActionCfg]] = None,
+ ):
+ self.motion_generator = motion_generator
+ self.robot = self.motion_generator.robot
+ self.device = self.motion_generator.device
+
+ # Semantic analyzer for object understanding
+ self._semantic_analyzer = SemanticAnalyzer()
+
+ # Initialize default actions
+ self._actions: Dict[str, AtomicAction] = self._init_actions(actions_cfg_list)
+
+ def _init_actions(
+ self, actions_cfg_list: Optional[List[ActionCfg]] = None
+ ) -> Dict[str, "AtomicAction"]:
+ actions: Dict[str, AtomicAction] = {}
+ from .actions import MoveAction, PickUpAction, PlaceAction
+
+ builtin_action_map: Dict[str, Type[AtomicAction]] = {
+ "move": MoveAction,
+ "pick_up": PickUpAction,
+ "place": PlaceAction,
+ }
+ if actions_cfg_list is not None:
+ for cfg in actions_cfg_list:
+ action_class = builtin_action_map.get(
+ cfg.name
+ ) or _global_action_registry.get(cfg.name)
+ if action_class is None:
+ logger.log_error(f"Unknown action name in config: {cfg.name}")
+ continue
+ instance = action_class(motion_generator=self.motion_generator, cfg=cfg)
+ actions[cfg.name] = instance
+ return actions
+
+ def execute_static(
+ self,
+ target_list: List[Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]]],
+ ) -> tuple[bool, torch.Tensor]:
+ """Execute a sequence of actions to target poses.
+
+ Each element in ``target_list`` corresponds to an action in the order they
+ were registered via ``actions_cfg_list``.
+ """
+ action_names = list(self._actions.keys())
+ if len(target_list) != len(action_names):
+ logger.log_error(
+ f"Length of target_list ({len(target_list)}) must match number of actions ({len(action_names)})."
+ )
+ start_qpos = self.motion_generator.robot.get_qpos()
+ n_envs = start_qpos.shape[0]
+ all_dof = self.motion_generator.robot.dof
+ all_trajectory = torch.empty(
+ size=(n_envs, 0, all_dof), dtype=torch.float32, device=self.device
+ )
+
+ for action_name, target in zip(action_names, target_list):
+ atom_action = self._actions[action_name]
+ target = self._resolve_target(target)
+ control_part = atom_action.control_part
+ arm_joint_ids = self.motion_generator.robot.get_joint_ids(name=control_part)
+ start_qpos_part = start_qpos[:, arm_joint_ids]
+ is_success, traj, joint_ids = atom_action.execute(
+ target=target, start_qpos=start_qpos_part
+ )
+ if not is_success:
+ return False, all_trajectory
+ n_waypoints = traj.shape[1]
+
+ traj_full = torch.zeros(
+ size=(n_envs, n_waypoints, all_dof),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ traj_full[:, :] = start_qpos
+ traj_full[:, :, joint_ids] = traj
+ all_trajectory = torch.cat((all_trajectory, traj_full), dim=1)
+ # update start qpos for the next action
+ start_qpos[:, joint_ids] = traj[:, -1, :]
+ return True, all_trajectory
+
+ def validate(
+ self,
+ action_name: str,
+ target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
+ **kwargs,
+ ) -> bool:
+ """Validate if a named action is feasible without executing."""
+ if action_name not in self._actions:
+ logger.log_warning(f"Action '{action_name}' is not registered.")
+ return False
+
+ action = self._actions[action_name]
+ target = self._resolve_target(target)
+ return action.validate(target, **kwargs)
+
+ def _resolve_target(
+ self,
+ target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
+ ) -> Union[torch.Tensor, ObjectSemantics]:
+ """Resolve user target input into tensor pose or ObjectSemantics.
+
+ Supports the convenience dict format in ``execute`` and ``validate``.
+ """
+ if isinstance(target, torch.Tensor):
+ return target
+
+ if isinstance(target, ObjectSemantics):
+ return target
+
+ if isinstance(target, str):
+ return self._semantic_analyzer.analyze(target)
+
+ if isinstance(target, dict):
+ if "pose" in target:
+ pose = target["pose"]
+ if not isinstance(pose, torch.Tensor):
+ raise TypeError("target['pose'] must be a torch.Tensor")
+ return pose
+
+ if "semantics" in target:
+ semantics = target["semantics"]
+ if not isinstance(semantics, ObjectSemantics):
+ raise TypeError(
+ "target['semantics'] must be an ObjectSemantics instance"
+ )
+ return semantics
+
+ label = target.get("label")
+ if label is None:
+ raise ValueError(
+ "Dict target must provide 'label', or use 'pose'/'semantics'."
+ )
+ if not isinstance(label, str):
+ raise TypeError("target['label'] must be a string")
+
+ geometry = target.get("geometry")
+ custom_config = target.get("custom_config")
+ use_cache = target.get("use_cache", True)
+
+ semantics = self._semantic_analyzer.analyze(
+ label=label,
+ geometry=geometry,
+ custom_config=custom_config,
+ use_cache=use_cache,
+ )
+
+ properties = target.get("properties")
+ if properties is not None:
+ semantics.properties.update(properties)
+
+ uid = target.get("uid")
+ if uid is not None:
+ semantics.uid = uid
+
+ return semantics
+
+ raise TypeError(
+ "target must be torch.Tensor, str, ObjectSemantics, or Dict[str, Any]"
+ )
+
+ def get_semantic_analyzer(self) -> SemanticAnalyzer:
+ """Get the semantic analyzer for object understanding."""
+ return self._semantic_analyzer
+
+ def set_semantic_analyzer(self, analyzer: SemanticAnalyzer) -> None:
+ """Set a custom semantic analyzer."""
+ self._semantic_analyzer = analyzer
diff --git a/embodichain/lab/sim/cfg.py b/embodichain/lab/sim/cfg.py
index 72a755f2..0b10a725 100644
--- a/embodichain/lab/sim/cfg.py
+++ b/embodichain/lab/sim/cfg.py
@@ -23,6 +23,7 @@
from dataclasses import field, MISSING
from dexsim.types import (
+ Renderer,
PhysicalAttr,
ActorType,
AxisArrowType,
@@ -40,6 +41,40 @@
from .shapes import ShapeCfg, MeshCfg
+# Global default renderer settings for simulation
+DEFAULT_RENDERER: Literal["hybrid", "fast-rt", "rt"] = "hybrid"
+
+
+@configclass
+class RenderCfg:
+ renderer: Literal["hybrid", "fast-rt", "rt"] = "hybrid"
+ """Renderer backend to use for the simulation. Options are 'hybrid', 'fast-rt', and 'rt'.
+
+ Note:
+ - 'hybrid' uses ray tracing for shadows and reflections while keeping rasterization for primary rendering,
+ providing a balance between performance and visual quality.
+ - 'fast-rt' is a fully ray-traced renderer for maximum visual fidelity, but may have higher computational cost.
+ - 'rt' is an offline ray-traced renderer for maximum visual fidelity, suitable for high-quality rendering tasks.
+ """
+
+ enable_denoiser: bool = True
+ """Whether to enable denoising. Only valid when renderer is 'hybrid' or 'fast-rt'."""
+
+ spp: int = 64
+ """Samples per pixel for ray tracing rendering. This parameter is only valid when renderer is 'hybrid' or 'fast-rt' and enable_denoiser is False."""
+
+ def to_dexsim_flags(self):
+ if self.renderer == "hybrid":
+ return Renderer.HYBRID
+ elif self.renderer == "fast-rt":
+ return Renderer.FASTRT
+ elif self.renderer == "rt":
+ return Renderer.OFFLINERT
+ else:
+ logger.log_error(
+ f"Invalid renderer type '{self.renderer}' specified. Must be one of 'hybrid', 'fast-rt', or 'rt'."
+ )
+
@configclass
class PhysicsCfg:
@@ -126,6 +161,26 @@ class MarkerCfg:
"""Index of the arena where the marker should be placed. -1 means all arenas."""
+@configclass
+class WindowRecordCfg:
+ """Configuration for interactive viewer window recording."""
+
+ enable_hotkey: bool = True
+ """Whether to register the ``r`` hotkey for viewer recording when the window opens."""
+
+ save_path: str | None = None
+ """Optional output path for viewer recordings. If None, use the default outputs directory."""
+
+ fps: int = 20
+ """Frames per second for viewer recording."""
+
+ max_memory: int = 1024
+ """Maximum buffered recording memory in MB before auto-stopping capture."""
+
+ video_prefix: str = "viewer_record"
+ """Video file prefix used when no explicit save path is provided."""
+
+
@configclass
class GPUMemoryCfg:
"""A gpu memory configuration dataclass that neatly holds all parameters that configure physics GPU memory for simulation"""
@@ -200,7 +255,7 @@ class RigidBodyAttributesCfg:
contact_offset: float = 0.002
"""Contact offset for collision detection."""
- rest_offset: float = 0.001
+ rest_offset: float = 0.0
"""Rest offset for collision detection."""
enable_collision: bool = True
@@ -846,6 +901,34 @@ class URDFCfg:
fpath_prefix: str = EMBODICHAIN_DEFAULT_DATA_ROOT + "/assembled"
"""Output directory prefix for the assembled URDF file."""
+ component_prefix: List[tuple[str, Union[str, None]]] = field(
+ default_factory=lambda: [
+ ("chassis", None),
+ ("legs", None),
+ ("torso", None),
+ ("head", None),
+ ("left_arm", "left_"),
+ ("right_arm", "right_"),
+ ("left_hand", "left_"),
+ ("right_hand", "right_"),
+ ("arm", None),
+ ("hand", None),
+ ]
+ )
+ """Component name prefixes used during URDF assembly.
+
+ Preferred form is a list of ``(component_name, prefix)`` tuples. For
+ convenience, a mapping ``{component_name: prefix}`` is also accepted when
+ constructing :class:`URDFCfg` and will be normalized internally.
+ """
+
+ name_case: dict[str, str] = field(
+ default_factory=lambda: {
+ "joint": "upper",
+ "link": "lower",
+ }
+ )
+
def __init__(
self,
components: list[dict[str, str | np.ndarray]] | None = None,
@@ -855,6 +938,8 @@ def __init__(
fpath_prefix: str = EMBODICHAIN_DEFAULT_DATA_ROOT + "/assembled",
use_signature_check: bool = True,
base_link_name: str = "base_link",
+ component_prefix: list[tuple[str, str | None]] | None = None,
+ name_case: dict[str, str] | None = None,
):
"""
Initialize URDFCfg with optional list of components and output path settings.
@@ -871,6 +956,9 @@ def __init__(
fpath_prefix (str): Output directory prefix for the assembled URDF file.
use_signature_check (bool): Whether to use signature check when merging URDFs.
base_link_name (str): Name of the base link in the assembled robot.
+ component_prefix (list[tuple[str, str | None]] | None): Optional
+ list of (component_type, prefix) pairs to override default
+ component name prefixes.
"""
self.components = {}
self.sensors = sensors or {}
@@ -880,6 +968,36 @@ def __init__(
self.fname = fname
self.fpath_prefix = fpath_prefix
+ # Initialize component prefixes (patch-style mapping per component type)
+ if component_prefix is None:
+ # Use the same default as the dataclass field
+ self.component_prefix = [
+ ("chassis", None),
+ ("legs", None),
+ ("torso", None),
+ ("head", None),
+ ("left_arm", "left_"),
+ ("right_arm", "right_"),
+ ("left_hand", "left_"),
+ ("right_hand", "right_"),
+ ("arm", None),
+ ("hand", None),
+ ]
+ elif isinstance(component_prefix, dict):
+ # Allow dict-style config: {"left_hand": "l_", ...}
+ self.component_prefix = list(component_prefix.items())
+ else:
+ # Assume caller provided a list of (component_name, prefix) tuples
+ self.component_prefix = component_prefix
+
+ if name_case is None:
+ self.name_case = {
+ "joint": "upper",
+ "link": "lower",
+ }
+ else:
+ self.name_case = name_case
+
# Auto-add components if provided
if components:
for comp_config in components:
@@ -1041,6 +1159,22 @@ def assemble_urdf(self) -> str:
# If there are multiple components, merge them into a single URDF file.
manager = URDFAssemblyManager()
manager.base_link_name = self.base_link_name
+
+ if self.component_prefix is None:
+ self.component_prefix = [
+ ("left_arm", "left_"),
+ ("right_arm", "right_"),
+ ("left_hand", "left_"),
+ ("right_hand", "right_"),
+ ]
+ if isinstance(self.component_prefix, dict):
+ self.component_prefix = list(self.component_prefix.items())
+ # Forward configured component prefixes to the assembly manager
+ manager.component_prefix = self.component_prefix
+
+ if self.name_case is not None:
+ manager.name_case = self.name_case
+
for comp_type, comp_config in components:
params = comp_config.get("params", {})
success = manager.add_component(
@@ -1094,12 +1228,16 @@ def from_dict(cls, init_dict: Dict) -> "URDFCfg":
fpath = init_dict.get("fpath", None)
use_signature_check = init_dict.get("use_signature_check", True)
base_link_name = init_dict.get("base_link_name", "base_link")
+ component_prefix = init_dict.get("component_prefix", None)
+ name_case = init_dict.get("name_case", None)
return cls(
components=components,
sensors=sensors,
fpath=fpath,
use_signature_check=use_signature_check,
base_link_name=base_link_name,
+ component_prefix=component_prefix,
+ name_case=name_case,
)
diff --git a/embodichain/lab/sim/material.py b/embodichain/lab/sim/material.py
index 08c8cb93..7daddb8f 100644
--- a/embodichain/lab/sim/material.py
+++ b/embodichain/lab/sim/material.py
@@ -25,7 +25,6 @@
from functools import cached_property
from dexsim.engine import MaterialInst, Material
-from embodichain.lab.sim.utility import is_rt_enabled
from embodichain.utils import configclass, logger
@@ -42,7 +41,7 @@ class VisualMaterialCfg:
metallic: float = 0.0
"""Metallic factor (0.0 = dielectric, 1.0 = metallic)"""
- roughness: float = 0.5
+ roughness: float = 0.7
"""Surface roughness (0.0 = smooth, 1.0 = rough)"""
# Additional PBR properties
@@ -120,10 +119,6 @@ def __init__(self, cfg: VisualMaterialCfg, mat: Material):
self._default_mat_inst = self.create_instance(self.uid)
- @cached_property
- def is_rt_enabled(self) -> bool:
- return is_rt_enabled()
-
@property
def mat(self) -> Material:
return self._mat
@@ -147,11 +142,8 @@ def set_default_properties(
mat_inst.set_normal_texture(cfg.normal_texture)
mat_inst.set_ao_texture(cfg.ao_texture)
- if self.is_rt_enabled:
- mat_inst.set_ior(cfg.ior)
- mat_inst.mat.update_pbr_material_type(
- self.MAT_TYPE_MAPPING[cfg.material_type]
- )
+ mat_inst.set_ior(cfg.ior)
+ mat_inst.mat.update_pbr_material_type(self.MAT_TYPE_MAPPING[cfg.material_type])
def create_instance(self, uid: str) -> VisualMaterialInst:
"""Create a new material instance from this material template.
@@ -400,9 +392,7 @@ def set_ao_texture(
def set_ior(self, ior: float) -> None:
"""Set index of refraction."""
- if is_rt_enabled() is False:
- logger.log_debug("Ray Tracing rendering not enabled, ignoring IOR setting.")
- return
+
self.ior = ior
inst = self._mat.get_inst(self.uid)
- inst.set_rt_param("ior", ior)
+ inst.set_pbr_param("ior", ior)
diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py
index 6488fd59..b763bcc4 100644
--- a/embodichain/lab/sim/objects/articulation.py
+++ b/embodichain/lab/sim/objects/articulation.py
@@ -42,7 +42,6 @@
from embodichain.lab.sim.utility.sim_utils import (
get_dexsim_drive_type,
set_dexsim_articulation_cfg,
- is_rt_enabled,
)
from embodichain.lab.sim.utility.solver_utils import (
create_pk_chain,
@@ -907,7 +906,6 @@ def set_local_pose(
logger.log_error(
f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)."
)
-
# TODO: in manual physics mode, the update should be explicitly called after
# setting the pose to synchronize the state to renderer.
self._world.update(0.001)
@@ -935,15 +933,6 @@ def set_local_pose(
)
self._ps.gpu_compute_articulation_kinematic(gpu_indices=indices)
- # TODO: To be removed when gpu articulation data sync is supported.
- if is_rt_enabled() is False:
- self.body_data.body_link_pose
- link_pose = self.body_data._body_link_pose[local_env_ids]
- self._world.sync_poses_gpu_to_cpu(
- link_pose=CudaArray(link_pose),
- articulation_gpu_indices=CudaArray(indices),
- )
-
def get_local_pose(self, to_matrix=False) -> torch.Tensor:
"""Get local pose (root link pose) of the articulation.
@@ -1056,6 +1045,8 @@ def set_qpos(
# (e.g., support specifying which methods should be decorated for auto-conversion.)
if not isinstance(qpos, torch.Tensor):
qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device)
+ else:
+ qpos = qpos.to(device=self.device, dtype=torch.float32)
if joint_ids is None:
local_joint_ids = torch.arange(
@@ -1066,7 +1057,7 @@ def set_qpos(
joint_ids, dtype=torch.int32, device=self.device
)
else:
- local_joint_ids = joint_ids
+ local_joint_ids = joint_ids.to(device=self.device, dtype=torch.int32)
local_env_ids = self._all_indices if env_ids is None else env_ids
@@ -1564,16 +1555,6 @@ def reset(self, env_ids: Sequence[int] | None = None) -> None:
self._ps.gpu_compute_articulation_kinematic(
gpu_indices=self.body_data.gpu_indices[local_env_ids]
)
-
- # TODO: To be removed when gpu articulation data sync is supported.
- if is_rt_enabled() is False:
- self.body_data.body_link_pose
- link_pose = self.body_data._body_link_pose[local_env_ids]
- indices = self.body_data.gpu_indices[local_env_ids]
- self._world.sync_poses_gpu_to_cpu(
- link_pose=CudaArray(link_pose),
- articulation_gpu_indices=CudaArray(indices),
- )
else:
self._world.update(0.001)
@@ -1680,6 +1661,7 @@ def compute_fk(
chain=self.pk_chain,
root_link_name=root_link_name,
end_link_name=end_link_name,
+ device=self.device,
)
result = pk_serial_chain.forward_kinematics(th=qpos, end_only=True)
@@ -1780,9 +1762,10 @@ def compute_jacobian(
# Create pk_serial_chain
pk_serial_chain = create_pk_serial_chain(
- chain=self.pk_chain,
+ urdf_path=self.cfg.fpath,
root_link_name=root_link_name,
end_link_name=end_link_name,
+ device=self.device,
)
# Compute the Jacobian using the kinematics chain
diff --git a/embodichain/lab/sim/objects/gizmo.py b/embodichain/lab/sim/objects/gizmo.py
index 15067772..dc7fea00 100644
--- a/embodichain/lab/sim/objects/gizmo.py
+++ b/embodichain/lab/sim/objects/gizmo.py
@@ -17,7 +17,6 @@
Gizmo: A reusable controller for interactive manipulation of simulation elements (object, robot, camera, etc.)
"""
-
import numpy as np
import torch
import dexsim
@@ -213,10 +212,7 @@ def _setup_camera_gizmo(self):
camera_pos, camera_rot_matrix, "Camera"
)
# New API uses set_flush_localpose_callback
- try:
- self._gizmo.set_flush_localpose_callback(self._proxy_gizmo_callback)
- except Exception as e:
- logger.log_warning(f"Failed to set gizmo callback for camera: {e}")
+ self._gizmo.set_flush_localpose_callback(self._proxy_gizmo_callback)
def _proxy_gizmo_callback(self, *args):
"""Generic callback for proxy-based gizmo.
diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py
index 24de293b..2202bbec 100644
--- a/embodichain/lab/sim/objects/rigid_object.py
+++ b/embodichain/lab/sim/objects/rigid_object.py
@@ -31,7 +31,6 @@
VisualMaterialInst,
BatchEntity,
)
-from embodichain.lab.sim.utility import is_rt_enabled
from embodichain.utils.math import convert_quat
from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler
from embodichain.utils import logger
@@ -81,6 +80,12 @@ def __init__(
self._ang_vel = torch.zeros(
(self.num_instances, 3), dtype=torch.float32, device=self.device
)
+ self._lin_acc = torch.zeros(
+ (self.num_instances, 3), dtype=torch.float32, device=self.device
+ )
+ self._ang_acc = torch.zeros(
+ (self.num_instances, 3), dtype=torch.float32, device=self.device
+ )
# center of mass pose in format (x, y, z, qw, qx, qy, qz)
self.default_com_pose = torch.zeros(
(self.num_instances, 7), dtype=torch.float32, device=self.device
@@ -162,6 +167,51 @@ def vel(self) -> torch.Tensor:
"""
return torch.cat((self.lin_vel, self.ang_vel), dim=-1)
+ @property
+ def lin_acc(self) -> torch.Tensor:
+ if self.device.type == "cpu":
+ self._lin_acc = torch.as_tensor(
+ np.array(
+ [entity.get_linear_acceleration() for entity in self.entities],
+ ),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ else:
+ self.ps.gpu_fetch_rigid_body_data(
+ data=self._lin_acc,
+ gpu_indices=self.gpu_indices,
+ data_type=RigidBodyGPUAPIReadType.LINEAR_ACCELERATION,
+ )
+ return self._lin_acc
+
+ @property
+ def ang_acc(self) -> torch.Tensor:
+ if self.device.type == "cpu":
+ self._ang_acc = torch.as_tensor(
+ np.array(
+ [entity.get_angular_acceleration() for entity in self.entities],
+ ),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ else:
+ self.ps.gpu_fetch_rigid_body_data(
+ data=self._ang_acc,
+ gpu_indices=self.gpu_indices,
+ data_type=RigidBodyGPUAPIReadType.ANGULAR_ACCELERATION,
+ )
+ return self._ang_acc
+
+ @property
+ def acc(self) -> torch.Tensor:
+ """Get the linear and angular accelerations of the rigid bodies.
+
+ Returns:
+ torch.Tensor: The linear and angular accelerations concatenated, with shape (N, 6).
+ """
+ return torch.cat((self.lin_acc, self.ang_acc), dim=-1)
+
@property
def com_pose(self) -> torch.Tensor:
"""Get the center of mass pose of the rigid bodies.
@@ -410,10 +460,6 @@ def set_local_pose(
gpu_indices=indices,
data_type=RigidBodyGPUAPIWriteType.POSE,
)
- if is_rt_enabled() is False:
- self._world.sync_poses_gpu_to_cpu(
- rigid_pose=CudaArray(pose), rigid_gpu_indices=CudaArray(indices)
- )
def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor:
"""Get local pose of the rigid object.
@@ -888,12 +934,9 @@ def set_body_scale(
f"Length of env_ids {len(local_env_ids)} does not match scale length {len(scale)}."
)
- if self.device.type == "cpu":
- for i, env_idx in enumerate(local_env_ids):
- scale_np = scale[i].cpu().numpy()
- self._entities[env_idx].set_body_scale(*scale_np)
- else:
- logger.log_error(f"Setting body scale on GPU is not supported yet.")
+ for i, env_idx in enumerate(local_env_ids):
+ scale_np = scale[i].cpu().numpy()
+ self._entities[env_idx].set_body_scale(*scale_np)
def set_com_pose(
self, com_pose: torch.Tensor, env_ids: Sequence[int] | None = None
diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py
index e6dac158..07273e80 100644
--- a/embodichain/lab/sim/objects/robot.py
+++ b/embodichain/lab/sim/objects/robot.py
@@ -934,6 +934,9 @@ def init_solver(self, cfg: Union[SolverCfg, Dict[str, SolverCfg]]) -> None:
):
solver_cfg.joint_names = self.cfg.control_parts[part_name]
self._solvers[name] = solver_cfg.init_solver(device=self.device)
+ joint_ids = self.get_joint_ids(name=part_name)
+ joint_limits = self._data.qpos_limits[0][joint_ids]
+ self._solvers[name].update_with_robot_limit(joint_limits)
def get_solver(self, name: str | None = None) -> BaseSolver | None:
"""Get the kinematic solver for a specific control part.
diff --git a/embodichain/lab/sim/planners/motion_generator.py b/embodichain/lab/sim/planners/motion_generator.py
index 0682c492..220deeca 100644
--- a/embodichain/lab/sim/planners/motion_generator.py
+++ b/embodichain/lab/sim/planners/motion_generator.py
@@ -33,7 +33,6 @@
from .utils import MovePart, MoveType, PlanState, PlanResult
from .utils import calculate_point_allocations, interpolate_xpos
-
__all__ = ["MotionGenerator", "MotionGenCfg", "MotionGenOptions"]
@@ -508,7 +507,11 @@ def interpolate_trajectory(
qpos_seed = options.start_qpos
if qpos_seed is None and qpos_list is not None:
+ # first waypoint as seed
qpos_seed = qpos_list[0]
+ if qpos_seed is None:
+ # fallback to current robot state as seed
+ qpos_seed = self.robot.get_qpos(name=control_part)[0]
# Generate trajectory
interpolate_qpos_list = []
@@ -551,9 +554,14 @@ def interpolate_trajectory(
# compute_batch_ik expects (n_envs, n_batch, 7) or (n_envs, n_batch, 4, 4)
# Here we assume n_envs = 1 or we want to apply this to all envs if available.
# Since MotionGenerator usually works with self.robot.device, we use its batching capabilities.
+ qpos_seed_repeat = (
+ qpos_seed.unsqueeze(0)
+ .repeat(total_interpolated_poses.shape[0], 1)
+ .unsqueeze(0)
+ )
success_batch, qpos_batch = self.robot.compute_batch_ik(
pose=total_interpolated_poses.unsqueeze(0),
- joint_seed=None, # Or use qpos_seed if properly shaped
+ joint_seed=qpos_seed_repeat, # Or use qpos_seed if properly shaped
name=control_part,
)
diff --git a/embodichain/lab/sim/planners/toppra_planner.py b/embodichain/lab/sim/planners/toppra_planner.py
index 0c20ccf9..218d17ed 100644
--- a/embodichain/lab/sim/planners/toppra_planner.py
+++ b/embodichain/lab/sim/planners/toppra_planner.py
@@ -191,11 +191,9 @@ def plan(
)
# Build waypoints
- waypoints = []
- for target in target_states:
- waypoints.append(np.array(target.qpos))
-
- waypoints = np.array(waypoints)
+ waypoints = np.array(
+ [target.qpos.to("cpu").numpy() for target in target_states]
+ )
# Create spline interpolation
# NOTE: Suitable for dense waypoints
ss = np.linspace(0, 1, len(waypoints))
diff --git a/embodichain/lab/sim/planners/utils.py b/embodichain/lab/sim/planners/utils.py
index 6e8e4ceb..cfeee443 100644
--- a/embodichain/lab/sim/planners/utils.py
+++ b/embodichain/lab/sim/planners/utils.py
@@ -23,7 +23,6 @@
from embodichain.utils import logger
-
__all__ = [
"TrajectorySampleMethod",
"MovePart",
diff --git a/embodichain/lab/sim/robots/cobotmagic.py b/embodichain/lab/sim/robots/cobotmagic.py
index 1ffdcd71..ca8e7f6c 100644
--- a/embodichain/lab/sim/robots/cobotmagic.py
+++ b/embodichain/lab/sim/robots/cobotmagic.py
@@ -181,11 +181,17 @@ def build_pk_serial_chain(
if __name__ == "__main__":
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+ from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.sim.robots import CobotMagicCfg
torch.set_printoptions(precision=5, sci_mode=False)
- config = SimulationManagerCfg(headless=False, sim_device="cuda", num_envs=2)
+ config = SimulationManagerCfg(
+ headless=False,
+ sim_device="cpu",
+ num_envs=2,
+ render_cfg=RenderCfg(renderer="fast-rt"),
+ )
sim = SimulationManager(config)
config = {
@@ -195,7 +201,6 @@ def build_pk_serial_chain(
cfg = CobotMagicCfg.from_dict(config)
robot = sim.add_robot(cfg=cfg)
- sim.init_gpu_physics()
print("CobotMagic added to the simulation.")
from IPython import embed
diff --git a/embodichain/lab/sim/robots/dexforce_w1/cfg.py b/embodichain/lab/sim/robots/dexforce_w1/cfg.py
index c6586b4e..40f95b09 100644
--- a/embodichain/lab/sim/robots/dexforce_w1/cfg.py
+++ b/embodichain/lab/sim/robots/dexforce_w1/cfg.py
@@ -159,7 +159,7 @@ def _build_default_solver_cfg(is_industrial: bool) -> SolverCfg:
end_link_name="right_ee",
root_link_name="right_arm_base",
dh_params=w1_right_arm_params.dh_params,
- qpos_limits=w1_right_arm_params.qpos_limits,
+ user_qpos_limits=w1_right_arm_params.qpos_limits,
T_e_oe=w1_right_arm_params.T_e_oe,
T_b_ob=w1_right_arm_params.T_b_ob,
link_lengths=w1_right_arm_params.link_lengths,
@@ -170,7 +170,7 @@ def _build_default_solver_cfg(is_industrial: bool) -> SolverCfg:
end_link_name="left_ee",
root_link_name="left_arm_base",
dh_params=w1_left_arm_params.dh_params,
- qpos_limits=w1_left_arm_params.qpos_limits,
+ user_qpos_limits=w1_left_arm_params.qpos_limits,
T_e_oe=w1_left_arm_params.T_e_oe,
T_b_ob=w1_left_arm_params.T_b_ob,
link_lengths=w1_left_arm_params.link_lengths,
diff --git a/embodichain/lab/sim/robots/dexforce_w1/utils.py b/embodichain/lab/sim/robots/dexforce_w1/utils.py
index c5ebbd0d..58fcbe70 100644
--- a/embodichain/lab/sim/robots/dexforce_w1/utils.py
+++ b/embodichain/lab/sim/robots/dexforce_w1/utils.py
@@ -28,7 +28,6 @@
from embodichain.lab.sim.solvers import SolverCfg
from embodichain.lab.sim.cfg import RobotCfg, URDFCfg
-
all = [
"ChassisManager",
"TorsoManager",
diff --git a/embodichain/lab/sim/sensors/camera.py b/embodichain/lab/sim/sensors/camera.py
index c5baed17..e672532e 100644
--- a/embodichain/lab/sim/sensors/camera.py
+++ b/embodichain/lab/sim/sensors/camera.py
@@ -17,19 +17,15 @@
from __future__ import annotations
import dexsim
-import math
import torch
import dexsim.render as dr
-import warp as wp
from functools import cached_property
-from typing import Union, Tuple, Sequence, List
+from typing import Tuple, Sequence, List
from embodichain.lab.sim.sensors import BaseSensor, SensorCfg
from embodichain.utils.math import matrix_from_quat, quat_from_matrix, look_at_to_pose
-from embodichain.utils.warp.kernels import reshape_tiled_image
from embodichain.utils import logger, configclass
-from embodichain.lab.sim.utility.sim_utils import is_rt_enabled
@configclass
@@ -97,17 +93,12 @@ def get_view_attrib(self) -> dr.ViewFlags:
The view attributes for the camera.
"""
view_attrib: dr.ViewFlags = dr.ViewFlags.COLOR
- # TODO: change for fast-rt renderer backend.
if self.enable_color:
view_attrib |= dr.ViewFlags.COLOR
if self.enable_depth:
- if is_rt_enabled() is False:
- view_attrib |= dr.ViewFlags.NORMAL
view_attrib |= dr.ViewFlags.DEPTH
if self.enable_mask:
view_attrib |= dr.ViewFlags.MASK
- if is_rt_enabled() is False:
- view_attrib |= dr.ViewFlags.DEPTH
if self.enable_normal:
view_attrib |= dr.ViewFlags.NORMAL
if self.enable_position:
@@ -152,55 +143,25 @@ def _build_sensor_from_config(
arenas = [env]
num_instances = len(arenas)
- if self.is_rt_enabled:
- self._frame_buffer = self._world.create_camera_group(
- [config.width, config.height], num_instances, True
- )
-
- view_attrib = config.get_view_attrib()
- for i, arena in enumerate(arenas):
- view_name = f"{self.uid}_view{i + 1}"
- view = arena.create_camera(
- view_name,
- config.width,
- config.height,
- True,
- view_attrib,
- self._frame_buffer,
- )
- view.set_intrinsic(config.intrinsics)
- view.set_near(config.near)
- view.set_far(config.far)
- self._entities[i] = view
+ self._frame_buffer = self._world.create_camera_group(
+ [config.width, config.height], num_instances, True
+ )
- else:
- self._grid_size = math.ceil(math.sqrt(num_instances))
- frame_width = self._grid_size * config.width
- frame_height = self._grid_size * config.height
- view_attrib = config.get_view_attrib()
- # Create the data frame
- self._frame_buffer = self._world.create_frame_buffer(
- [frame_width, frame_height], view_attrib, True
+ view_attrib = config.get_view_attrib()
+ for i, arena in enumerate(arenas):
+ view_name = f"{self.uid}_view{i + 1}"
+ view = arena.create_camera(
+ view_name,
+ config.width,
+ config.height,
+ True,
+ view_attrib,
+ self._frame_buffer,
)
- self._frame_buffer.set_read_able(view_attrib)
-
- # Create camera views
- for i, arena in enumerate(arenas):
- col = i // self._grid_size
- row = i % self._grid_size
- x = row * config.width
- y = col * config.height
- view_name = f"{self.uid}_view{i + 1}"
-
- view = arena.create_camera_view(
- view_name, (x, y), (config.width, config.height), self._frame_buffer
- )
- view.set_intrinsic(config.intrinsics)
- view.set_near(config.near)
- view.set_far(config.far)
- view.enable_postprocessing(True)
-
- self._entities[i] = view
+ view.set_intrinsic(config.intrinsics)
+ view.set_near(config.near)
+ view.set_far(config.far)
+ self._entities[i] = view
# Define a mapping of data types to their respective shapes and dtypes
buffer_specs = {
@@ -239,15 +200,6 @@ def _build_sensor_from_config(
if self.cfg.extrinsics.parent is not None:
self._attach_to_entity()
- @cached_property
- def is_rt_enabled(self) -> bool:
- """Check if Ray Tracing rendering backend is enabled in the default dexsim world.
-
- Returns:
- bool: True if Ray Tracing rendering is enabled, False otherwise.
- """
- return is_rt_enabled()
-
@cached_property
def group_id(self) -> int:
"""Get the camera group ID in the dexsim world.
@@ -255,13 +207,7 @@ def group_id(self) -> int:
Returns:
int: The camera group ID.
"""
- if self.is_rt_enabled:
- return self._frame_buffer.get_group_id()
- else:
- logger.log_warning(
- "Camera group ID is only available for Ray Tracing renderer. Returning -1 for non-RT renderer."
- )
- return -1
+ return self._frame_buffer.get_group_id()
@property
def is_attached(self) -> bool:
@@ -284,81 +230,38 @@ def update(self, **kwargs) -> None:
Args:
**kwargs: Additional keyword arguments for sensor update.
- - fetch_only (bool): If True, only fetch the data from dexsim internal frame buffer without performing rendering.
"""
fetch_only = kwargs.get("fetch_only", False)
if not fetch_only:
- if self.is_rt_enabled:
- self._frame_buffer.apply()
- else:
- self._frame_buffer.apply_frame()
-
+ self._frame_buffer.apply()
self.cfg: CameraCfg
- # TODO: support fetch data from gpu buffer directly.
+
if self.cfg.enable_color:
- if self.is_rt_enabled:
- self._data_buffer["color"] = self._frame_buffer.get_rgb_gpu_buffer().to(
- self.device
- )
- else:
- data = self._frame_buffer.get_color_gpu_buffer().to(self.device)
- self._update_buffer_impl(data, self._data_buffer["color"])
+ self._data_buffer["color"] = self._frame_buffer.get_rgb_gpu_buffer().to(
+ self.device
+ )
if self.cfg.enable_depth:
- data = self._frame_buffer.get_depth_gpu_buffer().to(self.device)
- if self.is_rt_enabled:
- self._data_buffer["depth"] = data
- else:
- self._update_buffer_impl(
- data, self._data_buffer["depth"].unsqueeze_(-1)
- )
- self._data_buffer["depth"].squeeze_(-1)
+ self._data_buffer["depth"] = self._frame_buffer.get_depth_gpu_buffer().to(
+ self.device
+ )
if self.cfg.enable_mask:
- if self.is_rt_enabled:
- data = self._frame_buffer.get_visible_mask_gpu_buffer().to(
- self.device, torch.int32
- )
- self._data_buffer["mask"] = data
- else:
- data = self._frame_buffer.get_visible_gpu_buffer().to(
- self.device, torch.int32
- )
- self._update_buffer_impl(data, self._data_buffer["mask"].unsqueeze_(-1))
- self._data_buffer["mask"].squeeze_(-1)
+ self._data_buffer[
+ "mask"
+ ] = self._frame_buffer.get_visible_mask_gpu_buffer().to(
+ self.device, torch.int32
+ )
if self.cfg.enable_normal:
- data = self._frame_buffer.get_normal_gpu_buffer().to(self.device)
- if self.is_rt_enabled:
- self._data_buffer["normal"] = data
- else:
- self._update_buffer_impl(data, self._data_buffer["normal"])
+ self._data_buffer["normal"] = self._frame_buffer.get_normal_gpu_buffer().to(
+ self.device
+ )[..., :3]
if self.cfg.enable_position:
- data = self._frame_buffer.get_position_gpu_buffer().to(self.device)
- if self.is_rt_enabled:
- self._data_buffer["position"] = data
- else:
- self._update_buffer_impl(data, self._data_buffer["position"])
-
- def _update_buffer_impl(
- self, data_buffer: torch.Tensor, data_buffer_out: torch.Tensor
- ) -> None:
- device = str(self.device)
- channel = data_buffer.shape[-1] if data_buffer.dim() >= 3 else 1
- wp.launch(
- kernel=reshape_tiled_image,
- dim=(self.num_instances, self.cfg.height, self.cfg.width),
- inputs=[
- wp.from_torch(data_buffer).flatten(),
- wp.from_torch(data_buffer_out),
- self.cfg.height,
- self.cfg.width,
- channel,
- self._grid_size,
- ],
- device="cuda:0" if device == "cuda" else device,
- )
+ self._data_buffer["position"] = (
+ self._frame_buffer.get_position_gpu_buffer().to(self.device)[..., :3]
+ )
def _attach_to_entity(self) -> None:
"""Attach the sensor to the parent entity in each environment."""
diff --git a/embodichain/lab/sim/sensors/stereo.py b/embodichain/lab/sim/sensors/stereo.py
index dfea8a86..999bedca 100644
--- a/embodichain/lab/sim/sensors/stereo.py
+++ b/embodichain/lab/sim/sensors/stereo.py
@@ -17,21 +17,16 @@
from __future__ import annotations
import dexsim
-import math
import torch
import numpy as np
-import warp as wp
import dexsim.render as dr
from typing import Dict, Tuple, List, Sequence
-from tensordict import TensorDict
from dexsim.utility import inv_transform
from embodichain.lab.sim.sensors import Camera, CameraCfg
-from embodichain.utils.warp.kernels import reshape_tiled_image
from embodichain.utils.math import matrix_from_euler
from embodichain.utils import logger, configclass
-from embodichain.lab.sim.utility.sim_utils import is_rt_enabled
@configclass
@@ -177,97 +172,46 @@ def _build_sensor_from_config(
arenas = [env]
num_instances = len(arenas)
- if self.is_rt_enabled:
- self._frame_buffer = self._world.create_camera_group(
- [config.width, config.height], num_instances * 2, True
+ self._frame_buffer = self._world.create_camera_group(
+ [config.width, config.height], num_instances * 2, True
+ )
+ view_attrib = config.get_view_attrib()
+ left_list = []
+ right_list = []
+ for i, arena in enumerate(arenas):
+ left_view_name = f"{self.uid}_left_view{i + 1}"
+ left_view = arena.create_camera(
+ left_view_name,
+ config.width,
+ config.height,
+ True,
+ view_attrib,
+ self._frame_buffer,
)
- view_attrib = config.get_view_attrib()
- left_list = []
- right_list = []
- for i, arena in enumerate(arenas):
- left_view_name = f"{self.uid}_left_view{i + 1}"
- left_view = arena.create_camera(
- left_view_name,
- config.width,
- config.height,
- True,
- view_attrib,
- self._frame_buffer,
- )
- left_view.set_intrinsic(config.intrinsics)
- left_view.set_near(config.near)
- left_view.set_far(config.far)
- left_list.append(left_view)
-
- for i, arena in enumerate(arenas):
- right_view_name = f"{self.uid}_right_view{i + 1}"
- right_view = arena.create_camera(
- right_view_name,
- config.width,
- config.height,
- True,
- view_attrib,
- self._frame_buffer,
- )
- right_view.set_intrinsic(config.intrinsics_right)
- right_view.set_near(config.near)
- right_view.set_far(config.far)
- right_list.append(right_view)
-
- for i in range(num_instances):
- self._entities[i] = PairCameraView(
- left_list[i], right_list[i], config.left_to_right.cpu().numpy()
- )
-
- else:
- self._grid_size = math.ceil(math.sqrt(num_instances))
-
- # stereo camera has two views, we append the right camera to the left camera's view list
- frame_width = self._grid_size * config.width * 2
- frame_height = self._grid_size * config.height
- view_attrib = config.get_view_attrib()
-
- # Create the data frame
- self._frame_buffer = self._world.create_frame_buffer(
- [frame_width, frame_height], view_attrib, True
+ left_view.set_intrinsic(config.intrinsics)
+ left_view.set_near(config.near)
+ left_view.set_far(config.far)
+ left_list.append(left_view)
+
+ for i, arena in enumerate(arenas):
+ right_view_name = f"{self.uid}_right_view{i + 1}"
+ right_view = arena.create_camera(
+ right_view_name,
+ config.width,
+ config.height,
+ True,
+ view_attrib,
+ self._frame_buffer,
+ )
+ right_view.set_intrinsic(config.intrinsics_right)
+ right_view.set_near(config.near)
+ right_view.set_far(config.far)
+ right_list.append(right_view)
+
+ for i in range(num_instances):
+ self._entities[i] = PairCameraView(
+ left_list[i], right_list[i], config.left_to_right.cpu().numpy()
)
- self._frame_buffer.set_read_able(view_attrib)
-
- # Create camera views
- for i, arena in enumerate(arenas):
- col = i // self._grid_size
- row = i % self._grid_size
- x = row * config.width * 2
- y = col * config.height
- left_view_name = f"{self.uid}_left_view{i + 1}"
-
- left_view = arena.create_camera_view(
- left_view_name,
- (x, y),
- (config.width, config.height),
- self._frame_buffer,
- )
-
- left_view.set_intrinsic(config.intrinsics)
- left_view.set_near(config.near)
- left_view.set_far(config.far)
- left_view.enable_postprocessing(True)
-
- right_view_name = f"{self.uid}_right_view{i + 1}"
- right_view = arena.create_camera_view(
- right_view_name,
- (x + config.width, y),
- (config.width, config.height),
- self._frame_buffer,
- )
- right_view.set_intrinsic(config.intrinsics_right)
- right_view.set_near(config.near)
- right_view.set_far(config.far)
- right_view.enable_postprocessing(True)
-
- self._entities[i] = PairCameraView(
- left_view, right_view, config.left_to_right.cpu().numpy()
- )
# Define a mapping of data types to their respective shapes and dtypes
buffer_specs = {
@@ -348,66 +292,38 @@ def update(self, **kwargs) -> None:
- disparity: Disparity images with shape (B, H, W, 1) and dtype torch.float32
Args:
**kwargs: Additional keyword arguments for sensor update.
- - fetch_only (bool): If True, only fetch the data from dexsim internal frame buffer without performing rendering.
"""
-
fetch_only = kwargs.get("fetch_only", False)
if not fetch_only:
- if self.is_rt_enabled:
- self._frame_buffer.apply()
- else:
- self._frame_buffer.apply_frame()
+ self._frame_buffer.apply()
self.cfg: StereoCameraCfg
if self.cfg.enable_color:
- if self.is_rt_enabled:
- data = self._frame_buffer.get_rgb_gpu_buffer().to(self.device)
- self._data_buffer["color"] = data[: self.num_instances, ...]
- self._data_buffer[f"color_right"] = data[self.num_instances :, ...]
- else:
- data = self._frame_buffer.get_color_gpu_buffer().to(self.device)
- self._update_buffer_impl(data, self._data_buffer_stereo["color"])
+ data = self._frame_buffer.get_rgb_gpu_buffer().to(self.device)
+ self._data_buffer["color"] = data[: self.num_instances, ...]
+ self._data_buffer[f"color_right"] = data[self.num_instances :, ...]
if self.cfg.enable_depth:
data = self._frame_buffer.get_depth_gpu_buffer().to(self.device)
- if self.is_rt_enabled:
- self._data_buffer["depth"] = data[: self.num_instances, ...].unsqueeze_(
- -1
- )
- self._data_buffer[f"depth_right"] = data[
- self.num_instances :, ...
- ].unsqueeze_(-1)
- else:
- self._update_buffer_impl(data, self._data_buffer_stereo["depth"])
+ self._data_buffer["depth"] = data[: self.num_instances, ...].unsqueeze_(-1)
+ self._data_buffer[f"depth_right"] = data[
+ self.num_instances :, ...
+ ].unsqueeze_(-1)
if self.cfg.enable_mask:
- if self.is_rt_enabled:
- data = self._frame_buffer.get_visible_mask_gpu_buffer().to(
- self.device, torch.int32
- )
- self._data_buffer["mask"] = data[: self.num_instances, ...].unsqueeze_(
- -1
- )
- self._data_buffer[f"mask_right"] = data[
- self.num_instances :, ...
- ].unsqueeze_(-1)
- else:
- data = self._frame_buffer.get_visible_gpu_buffer().to(
- self.device, torch.int32
- )
- self._update_buffer_impl(data, self._data_buffer_stereo["mask"])
+ data = self._frame_buffer.get_visible_mask_gpu_buffer().to(
+ self.device, torch.int32
+ )
+ self._data_buffer["mask"] = data[: self.num_instances, ...].unsqueeze_(-1)
+ self._data_buffer[f"mask_right"] = data[
+ self.num_instances :, ...
+ ].unsqueeze_(-1)
if self.cfg.enable_normal:
- data = self._frame_buffer.get_normal_gpu_buffer().to(self.device)
- if self.is_rt_enabled:
- self._data_buffer["normal"] = data[: self.num_instances, ...]
- self._data_buffer[f"normal_right"] = data[self.num_instances :, ...]
- else:
- self._update_buffer_impl(data, self._data_buffer_stereo["normal"])
+ data = self._frame_buffer.get_normal_gpu_buffer().to(self.device)[..., :3]
+ self._data_buffer["normal"] = data[: self.num_instances, ...]
+ self._data_buffer[f"normal_right"] = data[self.num_instances :, ...]
if self.cfg.enable_position:
- data = self._frame_buffer.get_position_gpu_buffer().to(self.device)
- if self.is_rt_enabled:
- self._data_buffer["position"] = data[: self.num_instances, ...]
- self._data_buffer[f"position_right"] = data[self.num_instances :, ...]
- else:
- self._update_buffer_impl(data, self._data_buffer_stereo["position"])
+ data = self._frame_buffer.get_position_gpu_buffer().to(self.device)[..., :3]
+ self._data_buffer["position"] = data[: self.num_instances, ...]
+ self._data_buffer[f"position_right"] = data[self.num_instances :, ...]
if self.cfg.enable_disparity:
disparity = self._data_buffer["disparity"]
disparity.fill_(0.0)
@@ -421,25 +337,6 @@ def update(self, **kwargs) -> None:
self.cfg.fx * distance / depth[valid_depth_mask]
)
- def _update_buffer_impl(
- self, data_buffer: torch.Tensor, data_buffer_out: torch.Tensor
- ) -> None:
- device = str(self.device)
- channel = data_buffer.shape[-1] if data_buffer.dim() >= 3 else 1
- wp.launch(
- kernel=reshape_tiled_image,
- dim=(self.num_instances, self.cfg.height, self.cfg.width * 2),
- inputs=[
- wp.from_torch(data_buffer).flatten(),
- wp.from_torch(data_buffer_out),
- self.cfg.height,
- self.cfg.width * 2,
- channel,
- self._grid_size,
- ],
- device="cuda:0" if device == "cuda" else device,
- )
-
def get_left_right_arena_pose(self) -> torch.Tensor:
"""Get the local pose of the left and right cameras.
diff --git a/embodichain/lab/sim/sim_manager.py b/embodichain/lab/sim/sim_manager.py
index bee36aa5..9aa08911 100644
--- a/embodichain/lab/sim/sim_manager.py
+++ b/embodichain/lab/sim/sim_manager.py
@@ -17,7 +17,11 @@
from __future__ import annotations
import os
+import gc
import sys
+import queue
+import time
+import threading
import dexsim
import torch
import numpy as np
@@ -26,6 +30,7 @@
from tqdm import tqdm
from pathlib import Path
from copy import deepcopy
+from datetime import datetime
from functools import cached_property
from typing import List, Union, Dict, Union, Sequence
from dataclasses import dataclass, asdict, field, MISSING
@@ -45,6 +50,7 @@
RigidBodyGPUAPIReadType,
ArticulationGPUAPIReadType,
)
+from dexsim.core import TASK_RETURN
from dexsim.engine import CudaArray, Material
from dexsim.models import MeshObject
from dexsim.render import Light as _Light, LightType, Windows
@@ -68,9 +74,11 @@
ContactSensor,
)
from embodichain.lab.sim.cfg import (
+ RenderCfg,
PhysicsCfg,
MarkerCfg,
GPUMemoryCfg,
+ WindowRecordCfg,
LightCfg,
RigidObjectCfg,
SoftObjectCfg,
@@ -105,14 +113,8 @@ class SimulationManagerCfg:
headless: bool = False
"""Whether to run the simulation in headless mode (no Window)."""
- enable_rt: bool = False
- """Whether to enable ray tracing rendering."""
-
- enable_denoiser: bool = True
- """Whether to enable denoising for ray tracing rendering."""
-
- spp: int = 64
- """Samples per pixel for ray tracing rendering. This parameter is only valid when ray tracing is enabled and enable_denoiser is False."""
+ render_cfg: RenderCfg = field(default_factory=RenderCfg)
+ """The rendering configuration parameters."""
gpu_id: int = 0
"""The gpu index that the simulation engine will be used.
@@ -147,6 +149,26 @@ class SimulationManagerCfg:
gpu_memory_config: GPUMemoryCfg = field(default_factory=GPUMemoryCfg)
"""The GPU memory configuration parameters."""
+ window_record: WindowRecordCfg = field(default_factory=WindowRecordCfg)
+ """Viewer window recording settings (hotkey, paths, FPS, memory budget)."""
+
+
+@dataclass
+class _WindowRecordState:
+ """Internal state for viewer-window recording."""
+
+ time_step: float
+ max_memory_bytes: int
+ output_dir: str
+ video_name: str
+ save_kwargs: dict[str, object]
+ record_camera: object | None = None
+ frames: list[np.ndarray] = field(default_factory=list)
+ current_memory_bytes: int = 0
+ last_capture_time: float = field(default_factory=time.time)
+ task_status: int = TASK_RETURN.TASK_LOOP
+ loop_handle: object | None = None
+
class SimulationManager:
r"""Global Embodied AI simulation manager.
@@ -166,6 +188,8 @@ class SimulationManager:
_instances = {}
+ _cleanup_queue: queue.Queue = queue.Queue()
+
SUPPORTED_SENSOR_TYPES = {
"Camera": Camera,
"StereoCamera": StereoCamera,
@@ -189,11 +213,6 @@ def __init__(
# Mark as initialized
self.instance_id = instance_id
- if sim_config.enable_rt and instance_id > 0:
- logger.log_error(
- f"Ray Tracing rendering backend is only supported for single instance (instance_id=0). "
- )
-
# Cache paths
self._sim_cache_dir = SIM_CACHE_DIR
self._material_cache_dir = MATERIAL_CACHE_DIR
@@ -220,11 +239,22 @@ def __init__(
self._window: Windows | None = None
self._is_registered_window_control = False
+ self._window_record_state: _WindowRecordState | None = None
+ self._window_record_camera: object | None = None
+ wr = sim_config.window_record
+ self._window_record_hotkey_cfg: dict[str, object] | None = (
+ {
+ "save_path": wr.save_path,
+ "fps": wr.fps,
+ "max_memory": wr.max_memory,
+ "video_prefix": wr.video_prefix,
+ }
+ if wr.enable_hotkey
+ else None
+ )
+ self._window_record_input_control: ObjectManipulator | None = None
+ self._window_record_save_threads: list[threading.Thread] = []
- fps = int(1.0 / sim_config.physics_dt)
- self._world.set_physics_fps(fps)
-
- self._world.set_time_scale(1.0)
self._world.set_delta_time(sim_config.physics_dt)
self._world.show_coordinate_axis(False)
@@ -239,13 +269,6 @@ def __init__(
self._env = self._world.get_env()
- # set unique material path to accelerate material creation.
- # TODO: This will be removed.
- if self.sim_config.enable_rt is False:
- self._env.set_unique_mat_path(
- os.path.join(self._material_cache_dir, "default_mat")
- )
-
# arena is used as a standalone space for robots to simulate in.
self._arenas: List[dexsim.environment.Arena] = []
@@ -284,7 +307,7 @@ def __init__(
if sim_config.headless is False:
self._window = self._world.get_windows()
- self._register_default_window_control()
+ # self._register_default_window_control()
@classmethod
def get_instance(cls, instance_id: int = 0) -> SimulationManager:
@@ -334,7 +357,7 @@ def is_instantiated(cls, instance_id: int = 0) -> bool:
"""
return instance_id in cls._instances
- @property
+ @cached_property
def num_envs(self) -> int:
"""Get the number of arenas in the simulation.
@@ -343,16 +366,10 @@ def num_envs(self) -> int:
"""
return len(self._arenas) if len(self._arenas) > 0 else 1
- @cached_property
+ @property
def is_use_gpu_physics(self) -> bool:
"""Check if the physics simulation is using GPU."""
- world_config = dexsim.get_world_config()
- return self.device.type == "cuda" and world_config.enable_gpu_sim
-
- @property
- def is_rt_enabled(self) -> bool:
- """Check if Ray Tracing rendering backend is enabled."""
- return self.sim_config.enable_rt
+ return self.device.type == "cuda"
@property
def is_physics_manually_update(self) -> bool:
@@ -395,11 +412,10 @@ def _convert_sim_config(
world_config.length_tolerance = sim_config.physics_config.length_tolerance
world_config.speed_tolerance = sim_config.physics_config.speed_tolerance
- if sim_config.enable_rt:
- world_config.renderer = dexsim.types.Renderer.FASTRT
- if sim_config.enable_denoiser is False:
- world_config.raytrace_config.spp = sim_config.spp
- world_config.raytrace_config.open_denoise = False
+ world_config.renderer = sim_config.render_cfg.to_dexsim_flags()
+ if sim_config.render_cfg.enable_denoiser is False:
+ world_config.raytrace_config.spp = sim_config.render_cfg.spp
+ world_config.raytrace_config.open_denoise = False
if type(sim_config.sim_device) is str:
self.device = torch.device(sim_config.sim_device)
@@ -458,28 +474,6 @@ def init_gpu_physics(self) -> None:
if self._is_initialized_gpu_physics:
return
- # init rigid body.
- rigid_body_num = (
- 0
- if self._get_non_static_rigid_obj_num() == 0
- else len(self._ps.get_gpu_rigid_indices())
- )
- self._rigid_body_pose = torch.zeros(
- (rigid_body_num, 7), dtype=torch.float32, device=self.device
- )
-
- # init articulation.
- articulation_num = (
- 0
- if len(self._articulations) == 0 and len(self._robots) == 0
- else len(self._ps.get_gpu_articulation_indices())
- )
- max_link_count = self._ps.gpu_get_articulation_max_link_count()
- self._link_pose = torch.zeros(
- (articulation_num, max_link_count, 7),
- dtype=torch.float32,
- device=self.device,
- )
for art in self._articulations.values():
art.reallocate_body_data()
for robot in self._robots.values():
@@ -498,12 +492,7 @@ def render_camera_group(self, group_ids: list[int]) -> None:
Note: This interface is only valid when Ray Tracing rendering backend is enabled.
"""
- if self.is_rt_enabled:
- self._world.render_camera_group(group_ids)
- else:
- logger.log_warning(
- "This interface is only valid when Ray Tracing rendering backend is enabled."
- )
+ self._world.render_camera_group(group_ids)
def update(self, physics_dt: float | None = None, step: int = 10) -> None:
"""Update the physics.
@@ -524,43 +513,9 @@ def update(self, physics_dt: float | None = None, step: int = 10) -> None:
for i in range(step):
self._world.update(physics_dt)
- if self.sim_config.enable_rt is False:
- self._sync_gpu_data()
-
else:
logger.log_warning("Physics simulation is not manually updated.")
- def _sync_gpu_data(self) -> None:
- if not self.is_use_gpu_physics:
- return
-
- if not self._is_initialized_gpu_physics:
- logger.log_warning(
- "GPU physics is not initialized. Skipping GPU data synchronization."
- )
- return
-
- if self.is_window_opened or self._sensors:
- if len(self._rigid_body_pose) > 0:
- self._ps.gpu_fetch_rigid_body_data(
- data=CudaArray(self._rigid_body_pose),
- gpu_indices=self._ps.get_gpu_rigid_indices(),
- data_type=RigidBodyGPUAPIReadType.POSE,
- )
-
- if len(self._link_pose) > 0:
- self._ps.gpu_fetch_link_data(
- data=CudaArray(self._link_pose),
- gpu_indices=self._ps.get_gpu_articulation_indices(),
- data_type=ArticulationGPUAPIReadType.LINK_GLOBAL_POSE,
- )
-
- # TODO: might be optimized.
- self._world.sync_poses_gpu_to_cpu(
- rigid_pose=CudaArray(self._rigid_body_pose),
- link_pose=CudaArray(self._link_pose),
- )
-
def get_env(self, arena_index: int = -1) -> dexsim.environment.Arena:
"""Get the arena or env by index.
@@ -589,12 +544,23 @@ def open_window(self) -> None:
"""Open the simulation window."""
self._world.open_window()
self._window = self._world.get_windows()
- self._register_default_window_control()
+
+ # TODO: will open these features after fix the related blocking issues.
+ # self._register_default_window_control()
+ # if (
+ # self._window_record_hotkey_cfg is not None
+ # and self._window_record_input_control is None
+ # ):
+ # self.enable_window_record_hotkey(**self._window_record_hotkey_cfg)
self.is_window_opened = True
def close_window(self) -> None:
"""Close the simulation window."""
+ if self.is_window_recording():
+ self.stop_window_record()
self._world.close_window()
+ self._window = None
+ self._window_record_input_control = None
self.is_window_opened = False
def _build_multiple_arenas(self, num: int, space: float | None = None) -> None:
@@ -662,6 +628,7 @@ def _create_default_plane(self):
plane_collision = self._env.create_cube(
default_length, default_length, default_length / 10
)
+ plane_collision.set_visible(False)
plane_collision_pose = np.eye(4, dtype=float)
plane_collision_pose[2, 3] = -default_length / 20 - 0.001
plane_collision.set_local_pose(plane_collision_pose)
@@ -682,13 +649,11 @@ def set_default_background(self) -> None:
uid=mat_name,
base_color_texture=color_texture,
roughness_texture=roughness_texture,
+ roughness=0.7,
)
)
- if self.sim_config.enable_rt:
- self.set_emission_light([1.0, 1.0, 1.0], 80.0)
- else:
- self.set_indirect_lighting("lab_day")
+ self.set_emission_light([1.5, 1.5, 1.5], 150.0)
self._default_plane.set_material(mat.get_instance("plane_mat").mat)
self._visual_materials[mat_name] = mat
@@ -1064,17 +1029,20 @@ def arena_offsets(self) -> torch.Tensor:
)
return arena_offsets
- def _get_non_static_rigid_obj_num(self) -> int:
- """Get the number of non-static rigid objects in the scene.
+ def has_non_static_rigid_object(self) -> bool:
+ """Check if there is any non-static rigid object in the simulation.
Returns:
- int: The number of non-static rigid objects.
+ bool: True if there is at least one non-static rigid object, False otherwise.
"""
- count = 0
- for obj in self._rigid_objects.values():
- if obj.cfg.body_type != "static":
- count += 1
- return count
+ for rigid_obj in self._rigid_objects.values():
+ if rigid_obj.body_type != "static":
+ return True
+
+ if len(self._rigid_object_groups) > 0:
+ return True
+
+ return False
def add_articulation(
self,
@@ -1105,7 +1073,9 @@ def add_articulation(
if len(env_list) > 1:
logger.log_error(f"Currently not supporting multiple arenas for USD.")
env = self._env
- results = env.import_from_usd_file(cfg.fpath, return_object=True)
+ results = env.import_from_usd_file(
+ cfg.fpath, return_object=True, cache_dir=self._convex_decomp_dir
+ )
# print("USD import results:", results)
articulations_found = []
@@ -1558,6 +1528,13 @@ def draw_marker(
return False
draw_xpos = deepcopy(cfg.axis_xpos)
+ if isinstance(draw_xpos, torch.Tensor):
+ draw_xpos = draw_xpos.detach().cpu().numpy()
+ elif isinstance(draw_xpos, (list, tuple)):
+ draw_xpos = [
+ item.detach().cpu().numpy() if isinstance(item, torch.Tensor) else item
+ for item in draw_xpos
+ ]
draw_xpos = np.array(draw_xpos)
if draw_xpos.ndim == 2:
if draw_xpos.shape == (4, 4):
@@ -1657,11 +1634,6 @@ def _register_default_window_control(self) -> None:
"""Register default window controls for better simulation interaction."""
from dexsim.types import InputKey
- # TODO: window control has stucking issue with extra sensor under Raster renderer backend.
- # Will be fixed in next dexsim release.
- if self.is_rt_enabled is False:
- return
-
if self._is_registered_window_control:
return
@@ -1699,6 +1671,230 @@ def add_custom_window_control(self, controls: list[ObjectManipulator]) -> None:
for control in controls:
self._window.add_input_control(control)
+ def _build_window_record_output(
+ self, save_path: str | None, video_prefix: str
+ ) -> tuple[str, str]:
+ """Resolve the output directory and file name for viewer recording."""
+ if save_path is None:
+ output_dir = os.path.join(os.getcwd(), "outputs", "videos")
+ timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+ video_name = f"{video_prefix}_{timestamp}"
+ else:
+ output_dir = os.path.dirname(save_path) or os.getcwd()
+ video_name = Path(os.path.basename(save_path)).stem
+ return output_dir, video_name
+
+ def is_window_recording(self) -> bool:
+ """Check whether the viewer window is currently recording."""
+ return self._window_record_state is not None
+
+ def _step_window_record(self, state: _WindowRecordState) -> int:
+ """Capture frames in the render thread without blocking the UI loop."""
+ if state.task_status != TASK_RETURN.TASK_LOOP:
+ return state.task_status
+
+ now = time.time()
+ if now - state.last_capture_time < state.time_step:
+ return state.task_status
+
+ state.last_capture_time = now
+ frame: np.ndarray | None = None
+ if self._window is not None and state.record_camera is not None:
+ pose = np.asarray(self._window.get_pose_matrix(), dtype=np.float32)
+ state.record_camera.set_world_pose(pose)
+ state.record_camera.render()
+ rgb = np.asarray(state.record_camera.get_rgb_map())
+ if rgb.size != 0:
+ frame = np.ascontiguousarray(rgb[..., :3])
+ if frame is None:
+ return state.task_status
+
+ state.frames.append(frame)
+ state.current_memory_bytes += frame.nbytes
+ if state.current_memory_bytes > state.max_memory_bytes:
+ logger.log_warning(
+ "Viewer recording exceeded the configured memory budget. "
+ "Press 'r' again to flush the buffered frames to disk."
+ )
+ state.task_status = TASK_RETURN.TASK_EXIT
+
+ return state.task_status
+
+ def _save_window_record_worker(
+ self,
+ frames: list[np.ndarray],
+ output_dir: str,
+ video_name: str,
+ save_kwargs: dict[str, object],
+ ) -> None:
+ """Encode buffered frames into a video file in a background thread."""
+ from dexsim.utility import images_to_video
+
+ try:
+ os.makedirs(output_dir, exist_ok=True)
+ images_to_video(
+ images=frames,
+ output_dir=output_dir,
+ video_name=video_name,
+ **save_kwargs,
+ )
+ logger.log_info(
+ f"Viewer recording saved to {os.path.join(output_dir, video_name + '.mp4')}"
+ )
+ except Exception as exc:
+ logger.log_error(f"Failed to save viewer recording: {exc}")
+
+ def start_window_record(
+ self,
+ save_path: str | None = None,
+ fps: int = 20,
+ max_memory: int = 1024,
+ video_prefix: str = "viewer_record",
+ ) -> bool:
+ """Start asynchronously recording the viewer by buffering frames from a hidden camera
+ that follows the live window camera pose.
+ """
+ if self._window is None:
+ logger.log_warning("No simulation window available for viewer recording.")
+ return False
+ width = self.sim_config.width
+ height = self.sim_config.height
+ if self._window_record_camera is None:
+ camera_name = f"viewer_record_camera_{self.instance_id}"
+ self._window_record_camera = self._env.create_camera(
+ camera_name, width, height
+ )
+ record_camera = self._window_record_camera
+ if hasattr(record_camera, "is_open") and record_camera.is_open() is False:
+ record_camera.open_camera()
+
+ time_step = 1.0 / float(fps)
+ output_dir, video_name = self._build_window_record_output(
+ save_path, video_prefix
+ )
+ state = _WindowRecordState(
+ time_step=time_step,
+ max_memory_bytes=max_memory * 1024 * 1024,
+ output_dir=output_dir,
+ video_name=video_name,
+ save_kwargs={"fps": fps},
+ record_camera=record_camera,
+ last_capture_time=time.time() - time_step,
+ )
+
+ def _window_record_loop(_: float) -> int:
+ return self._step_window_record(state)
+
+ state.loop_handle = self._world.thread_rt().add_loop(
+ _window_record_loop, time_step
+ )
+ self._window_record_state = state
+
+ logger.log_info(
+ f"Viewer recording started. Press 'r' again to stop and save to "
+ f"{os.path.join(output_dir, video_name + '.mp4')}"
+ )
+ return True
+
+ def stop_window_record(self, save_path: str | None = None) -> bool:
+ """Stop the active viewer recording and save frames in the background."""
+ if self._window_record_state is None:
+ logger.log_warning("No active viewer recording session found.")
+ return False
+
+ state = self._window_record_state
+ state.task_status = TASK_RETURN.TASK_EXIT
+ if save_path is not None:
+ output_dir, video_name = self._build_window_record_output(
+ save_path, "viewer_record"
+ )
+ else:
+ output_dir, video_name = state.output_dir, state.video_name
+
+ if state.record_camera is not None and hasattr(state.record_camera, "is_open"):
+ if state.record_camera.is_open():
+ state.record_camera.close_camera()
+
+ frames = list(state.frames)
+ self._window_record_state = None
+ if len(frames) == 0:
+ logger.log_warning(
+ "Viewer recording stopped, but no frames were captured. Skipping video export."
+ )
+ return False
+
+ self._window_record_save_threads = [
+ thread for thread in self._window_record_save_threads if thread.is_alive()
+ ]
+ save_thread = threading.Thread(
+ target=self._save_window_record_worker,
+ args=(frames, output_dir, video_name, dict(state.save_kwargs)),
+ daemon=False,
+ )
+ save_thread.start()
+ self._window_record_save_threads.append(save_thread)
+ logger.log_info(
+ "Viewer recording stopped. Saving video to "
+ f"{os.path.join(output_dir, video_name + '.mp4')} in background."
+ )
+ return True
+
+ def toggle_window_record(
+ self,
+ save_path: str | None = None,
+ fps: int = 20,
+ max_memory: int = 1024,
+ video_prefix: str = "viewer_record",
+ ) -> bool:
+ """Toggle viewer recording on or off."""
+ if self.is_window_recording():
+ return self.stop_window_record(save_path=save_path)
+ return self.start_window_record(
+ save_path=save_path,
+ fps=fps,
+ max_memory=max_memory,
+ video_prefix=video_prefix,
+ )
+
+ def enable_window_record_hotkey(
+ self,
+ save_path: str | None = None,
+ fps: int = 20,
+ max_memory: int = 1024,
+ video_prefix: str = "viewer_record",
+ ) -> bool:
+ """Register the ``r`` key to start/stop viewer recording."""
+ self._window_record_hotkey_cfg = {
+ "save_path": save_path,
+ "fps": fps,
+ "max_memory": max_memory,
+ "video_prefix": video_prefix,
+ }
+ if self._window is None:
+ logger.log_warning(
+ "No simulation window available yet. The viewer record hotkey will be registered after `open_window()`."
+ )
+ return False
+ if self._window_record_input_control is not None:
+ return True
+
+ from dexsim.types import InputKey
+
+ sim = self
+ hotkey_cfg = dict(self._window_record_hotkey_cfg)
+
+ class WindowRecordEvent(ObjectManipulator):
+ def on_key_down(self, key):
+ if key == InputKey.SCANCODE_R.value:
+ sim.toggle_window_record(**hotkey_cfg)
+
+ self._window_record_input_control = WindowRecordEvent()
+ self._window.add_input_control(self._window_record_input_control)
+ logger.log_info(
+ "Viewer record hotkey registered. Press 'r' to start/stop recording."
+ )
+ return True
+
def create_visual_material(self, cfg: VisualMaterialCfg) -> VisualMaterial:
"""Create a visual material with given configuration.
@@ -1735,7 +1931,8 @@ def get_visual_material(self, uid: str) -> VisualMaterial:
def clean_materials(self):
self._visual_materials = {}
- self._env.clean_materials()
+ if self._env:
+ self._env.clean_materials()
def reset_objects_state(
self,
@@ -1785,15 +1982,136 @@ def export_usd(self, fpath: str) -> bool:
logger.log_error(f"Failed to export simulation scene to USD: {e}")
return False
+ @staticmethod
+ def wait_scene_destruction(timeout_ms: int = 10000) -> None:
+ """A public helper to wait for the underlying C++ scenes (dexsim.World) to destruct completely."""
+ import dexsim
+ import gc
+
+ # Force garbage collection to break cycle references
+ gc.collect()
+
+ import time
+
+ wait_times = 0
+ scene_count = dexsim.get_world_num()
+ max_loops = timeout_ms // 10
+ while scene_count > 0 and wait_times < max_loops:
+ time.sleep(0.01)
+ scene_count = dexsim.get_world_num()
+ wait_times += 1
+ if wait_times % 50 == 0:
+ from embodichain.utils import logger
+
+ logger.log_info(
+ f"Waiting for dexsim.World scenes to destruct. Remaining scenes: {scene_count}"
+ )
+ if scene_count > 0:
+ from embodichain.utils import logger
+
+ logger.log_warning(
+ f"Scene destruction wait timeout, {scene_count} C++ scene(s) still alive!"
+ )
+
def destroy(self) -> None:
+ """
+ No longer destructs C++ objects in place due to lingering deep local variables;
+ instead, packages itself into a destruction task, submits to the cleanup queue,
+ and waits for top-level delayed consumption.
+ """
+ self._is_pending_kill = True
+
+ # Transfer the actual destruction logic to the cleanup queue
+ SimulationManager._cleanup_queue.put(self._deferred_destroy)
+
+ def _deferred_destroy(self) -> None:
"""Destroy all simulated assets and release resources."""
# Clean up all gizmos before destroying the simulation
for uid in list(self._gizmos.keys()):
self.disable_gizmo(uid)
+ import sys, gc
+
self.clean_materials()
- self._env.clean()
- self._world.quit()
+ if self._env:
+ self._env.clean()
+ if self._world:
+ self._world.quit()
+
+ # REMOVE INSTANCE FROM POOL
+ instance_id = getattr(self, "instance_id", 0)
+ SimulationManager.reset(instance_id)
+
+ # Helper to aggressively decouple C++ wrapped objects
+ def _sever_wrapper_refs(obj_registry):
+ if not hasattr(self, obj_registry):
+ return
+ registry = getattr(self, obj_registry)
+ if not isinstance(registry, dict):
+ return
+ for uid, obj in registry.items():
+ if hasattr(obj, "_world"):
+ obj._world = None
+ if hasattr(obj, "_ps"):
+ obj._ps = None
+ if hasattr(obj, "_env"):
+ obj._env = None
+ if hasattr(obj, "_entities"):
+ obj._entities = []
+ registry.clear()
+
+ _sever_wrapper_refs("_gizmos")
+ _sever_wrapper_refs("_markers")
+ _sever_wrapper_refs("_rigid_objects")
+ _sever_wrapper_refs("_rigid_object_groups")
+ _sever_wrapper_refs("_soft_objects")
+ _sever_wrapper_refs("_cloth_objects")
+ _sever_wrapper_refs("_articulations")
+ _sever_wrapper_refs("_robots")
+ _sever_wrapper_refs("_sensors")
+ _sever_wrapper_refs("_lights")
+
+ # Explicitly clear Python references to trigger C++ object destructors
+ self._ps = None
+ self._env = None
+ self._world = None
+ self._default_plane = None
+
+ # Try to break ANY possible frame cycle
+ gc.collect()
+
+ self._visual_materials.clear()
+ self._texture_cache.clear()
+ self._arenas.clear()
+ self._markers.clear()
+ self._gizmos.clear()
SimulationManager.reset(self.instance_id)
+
+ # Forcefully drop underlying C++ object wrappers
+ self._env = None
+ self._world = None
+
+ gc.collect()
+
+ @staticmethod
+ def flush_cleanup_queue():
+ """Dequeue executor and synchronization barrier provided for top-level main loop / Pytest Fixture calls"""
+ import gc
+
+ while not SimulationManager._cleanup_queue.empty():
+ task = SimulationManager._cleanup_queue.get_nowait()
+ try:
+ task()
+ except Exception as e:
+ from embodichain.utils import logger
+
+ logger.log_error(f"Error during delayed destruction: {e}")
+ pass
+
+ # After the queue is emptied, perform a top-level full GC to thoroughly reclaim dead objects that haven't released their RefPtrs yet
+ gc.collect()
+
+ # At this point, wait for the C++ Scene to return to zero, since the stack is at the top level, there will definitely be no deadlock
+ SimulationManager.wait_scene_destruction()
diff --git a/embodichain/lab/sim/solvers/base_solver.py b/embodichain/lab/sim/solvers/base_solver.py
index 143e3a89..c7fc70f2 100644
--- a/embodichain/lab/sim/solvers/base_solver.py
+++ b/embodichain/lab/sim/solvers/base_solver.py
@@ -72,6 +72,13 @@ class SolverCfg:
when multiple solutions are available.
"""
+ user_qpos_limits: List[float] | None = None
+ """
+ User defined Joint position limits [2, DOF] for the solver.
+ If not provided (None), this value will replace by joint limits defined in urdf when solver init from robot.
+ If provided, the solver will use the intersection of user defined limits and urdf limits as the final joint limits.
+ """
+
@abstractmethod
def init_solver(self, device: torch.device, **kwargs) -> "BaseSolver":
pass
@@ -165,6 +172,14 @@ def __init__(self, cfg: SolverCfg = None, device: str = None, **kwargs):
device=self.device,
)
+ self.compiled_fk = torch.compile(
+ self.pk_serial_chain.forward_kinematics_tensor,
+ fullgraph=True,
+ dynamic=True,
+ )
+
+ self._init_qpos_limits()
+
def set_ik_nearest_weight(
self, ik_weight: np.ndarray, joint_ids: np.ndarray | None = None
) -> bool:
@@ -223,51 +238,126 @@ def get_ik_nearest_weight(self):
"""
return self.ik_nearest_weight
- def set_position_limits(
+ def _init_qpos_limits(self):
+ self.lower_qpos_limits = None
+ self.upper_qpos_limits = None
+ if self.cfg.user_qpos_limits is not None:
+ # robot qpos limits from config, expected shape [DOF, 2]
+ user_qpos_limits = torch.tensor(
+ self.cfg.user_qpos_limits, dtype=torch.float32, device=self.device
+ )
+ if user_qpos_limits.shape == (2, self.dof):
+ self.set_qpos_limits(
+ lower_qpos_limits=user_qpos_limits[0],
+ upper_qpos_limits=user_qpos_limits[1],
+ )
+ elif user_qpos_limits.shape == (self.dof, 2):
+ self.set_qpos_limits(
+ lower_qpos_limits=user_qpos_limits[:, 0],
+ upper_qpos_limits=user_qpos_limits[:, 1],
+ )
+ else:
+ logger.log_error(
+ f"user_qpos_limits must have shape (2, {self.dof}) or ({self.dof}, 2), but got {user_qpos_limits.shape}."
+ )
+ elif self.pk_serial_chain is not None:
+ self.set_qpos_limits(
+ lower_qpos_limits=self.pk_serial_chain.low,
+ upper_qpos_limits=self.pk_serial_chain.high,
+ )
+
+ def update_with_robot_limit(self, robot_qpos_limits: torch.Tensor):
+ """Update with robot joint limits.
+ Make sure the solver's joint limits are within the robot's joint limits.
+
+ Args:
+ robot_qpos_limits (torch.Tensor): [DOF, 2] tensor of joint limits from the robot data
+ """
+ robot_lower_limits = robot_qpos_limits[:, 0]
+ robot_upper_limits = robot_qpos_limits[:, 1]
+ if self.lower_qpos_limits is not None:
+ if torch.any(self.lower_qpos_limits < robot_lower_limits):
+ logger.log_warning(
+ "Solver lower_qpos_limits are smaller than robot limits. Clamping to robot limits."
+ )
+ self.lower_qpos_limits = torch.max(
+ self.lower_qpos_limits, robot_lower_limits
+ )
+ else:
+ self.lower_qpos_limits = robot_lower_limits
+ if self.upper_qpos_limits is not None:
+ if torch.any(self.upper_qpos_limits > robot_upper_limits):
+ logger.log_warning(
+ "Solver upper_qpos_limits are larger than robot limits. Clamping to robot limits."
+ )
+ self.upper_qpos_limits = torch.min(
+ self.upper_qpos_limits, robot_upper_limits
+ )
+ else:
+ self.upper_qpos_limits = robot_upper_limits
+
+ def set_qpos_limits(
self,
- lower_position_limits: List[float],
- upper_position_limits: List[float],
+ lower_qpos_limits: List[float],
+ upper_qpos_limits: List[float],
) -> bool:
r"""Sets the upper and lower joint position limits.
Parameters:
- lower_position_limits (List[float]): A list of lower limits for each joint.
- upper_position_limits (List[float]): A list of upper limits for each joint.
+ lower_qpos_limits (List[float]): A list of lower limits for each joint.
+ upper_qpos_limits (List[float]): A list of upper limits for each joint.
Returns:
bool: True if limits are successfully set, False if the input is invalid.
"""
- if (
- len(lower_position_limits) != self.model.nq
- or len(upper_position_limits) != self.model.nq
- ):
- logger.log_warning("Length of limits must match the number of joints.")
- return False
if any(
- lower > upper
- for lower, upper in zip(lower_position_limits, upper_position_limits)
+ lower > upper for lower, upper in zip(lower_qpos_limits, upper_qpos_limits)
):
logger.log_warning(
"Each lower limit must be less than or equal to the corresponding upper limit."
)
return False
- self.lower_position_limits = np.array(lower_position_limits)
- self.upper_position_limits = np.array(upper_position_limits)
+ if isinstance(lower_qpos_limits, list) or isinstance(
+ lower_qpos_limits, np.ndarray
+ ):
+ self.lower_qpos_limits = torch.tensor(
+ lower_qpos_limits, dtype=float, device=self.device
+ )
+ elif isinstance(lower_qpos_limits, torch.Tensor):
+ self.lower_qpos_limits = lower_qpos_limits.clone().to(device=self.device)
+ else:
+ logger.log_error(
+ f"Invalid type for lower_qpos_limits: {type(lower_qpos_limits)}. Must be list, np.ndarray, or torch.Tensor."
+ )
+
+ if isinstance(upper_qpos_limits, list) or isinstance(
+ upper_qpos_limits, np.ndarray
+ ):
+ self.upper_qpos_limits = torch.tensor(
+ upper_qpos_limits, dtype=float, device=self.device
+ )
+ elif isinstance(upper_qpos_limits, torch.Tensor):
+ self.upper_qpos_limits = upper_qpos_limits.clone().to(device=self.device)
+ else:
+ logger.log_error(
+ f"Invalid type for upper_qpos_limits: {type(upper_qpos_limits)}. Must be list, np.ndarray, or torch.Tensor."
+ )
+
return True
- def get_position_limits(self) -> dict:
+ def get_qpos_limits(self) -> dict:
r"""Returns the current joint position limits.
Returns:
dict: A dictionary containing:
- - lower_position_limits (List[float]): The current lower limits for each joint.
- - upper_position_limits (List[float]): The current upper limits for each joint.
+ - lower_qpos_limits (List[float]): The current lower limits for each joint.
+ - upper_qpos_limits (List[float]): The current upper limits for each joint.
"""
return {
- "lower_position_limits": self.lower_position_limits.tolist(),
- "upper_position_limits": self.upper_position_limits.tolist(),
+ "lower_qpos_limits": self.lower_qpos_limits.tolist(),
+ "upper_qpos_limits": self.upper_qpos_limits.tolist(),
}
def set_tcp(self, xpos: np.ndarray):
@@ -339,35 +429,18 @@ def get_fk(self, qpos: torch.tensor, **kwargs) -> torch.Tensor:
)
qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device)
+ if self.pk_serial_chain is None:
+ logger.log_error("Kinematic chain is not initialized.")
+ return torch.eye(4, device=self.device)
# Compute forward kinematics
- result = self.pk_serial_chain.forward_kinematics(
- qpos, end_only=(self.end_link_name is None)
- )
-
- # Extract transformation matrices
- if isinstance(result, dict):
- matrices = result[self.end_link_name].get_matrix()
- elif isinstance(result, list):
- matrices = torch.stack([xpos.get_matrix().squeeze() for xpos in result])
- else:
- matrices = result.get_matrix()
-
- # Ensure batch format
- if matrices.dim() == 2:
- matrices = matrices.unsqueeze(0)
-
- # Create result tensor with proper homogeneous coordinates
- result = (
- torch.eye(4, device=self.device).expand(matrices.shape[0], 4, 4).clone()
- )
- result[:, :3, :] = matrices[:, :3, :]
+ ee_link_xpos = self.compiled_fk(qpos)[-1, :, :, :]
# Ensure batch format for TCP
- batch_size = result.shape[0]
+ batch_size = qpos.shape[0]
tcp_xpos_batch = tcp_xpos.unsqueeze(0).expand(batch_size, -1, -1)
# Apply TCP transformation
- return torch.bmm(result, tcp_xpos_batch)
+ return torch.bmm(ee_link_xpos, tcp_xpos_batch)
def get_jacobian(
self,
diff --git a/embodichain/lab/sim/solvers/differential_solver.py b/embodichain/lab/sim/solvers/differential_solver.py
index fc6e596b..12e51bcb 100644
--- a/embodichain/lab/sim/solvers/differential_solver.py
+++ b/embodichain/lab/sim/solvers/differential_solver.py
@@ -25,7 +25,6 @@
compute_pose_error,
)
-
if TYPE_CHECKING:
from typing import Self
diff --git a/embodichain/lab/sim/solvers/opw_solver.py b/embodichain/lab/sim/solvers/opw_solver.py
index 4d8f9047..e64cc99c 100644
--- a/embodichain/lab/sim/solvers/opw_solver.py
+++ b/embodichain/lab/sim/solvers/opw_solver.py
@@ -29,12 +29,11 @@
OPWparam,
opw_fk_kernel,
opw_ik_kernel,
- opw_best_ik_kernel,
+ opw_ik_select_kernel,
wp_vec6f,
)
from embodichain.utils.device_utils import standardize_device_string
-
if TYPE_CHECKING:
from typing import Self
@@ -72,6 +71,9 @@ class OPWSolverCfg(SolverCfg):
# Parameters for the inverse-kinematics method.
ik_params: dict | None = None
+ # safe margin for joint limits, in radians
+ safe_margin: float = 0.0 # 5.0 * np.pi / 180.0
+
def init_solver(
self, device: torch.device = torch.device("cpu"), **kwargs
) -> "OPWSolver":
@@ -247,23 +249,44 @@ def get_ik_warp(
N_SOL = 8
DOF = 6
n_sample = target_xpos.shape[0]
+ kernel_device = standardize_device_string(self.device)
if target_xpos.shape == (4, 4):
- target_xpos_batch = target_xpos[None, :, :]
+ target_xpos_batch = target_xpos[None, :, :].to(kernel_device)
else:
- target_xpos_batch = target_xpos
+ target_xpos_batch = target_xpos.to(kernel_device)
target_xpos_wp = wp.from_torch(target_xpos_batch.reshape(-1))
all_qpos_wp = wp.zeros(
n_sample * N_SOL * DOF,
dtype=float,
- device=standardize_device_string(self.device),
+ device=standardize_device_string(kernel_device),
)
all_ik_valid_wp = wp.zeros(
- n_sample * N_SOL, dtype=int, device=standardize_device_string(self.device)
+ n_sample * N_SOL, dtype=int, device=standardize_device_string(kernel_device)
)
# TODO: whether require gradient
+ offsets_ = self.offsets.to(standardize_device_string(kernel_device))
+ sign_corrections_ = self.sign_corrections.to(
+ standardize_device_string(kernel_device)
+ )
+ lower_limits_ = wp_vec6f(
+ self.lower_qpos_limits[0],
+ self.lower_qpos_limits[1],
+ self.lower_qpos_limits[2],
+ self.lower_qpos_limits[3],
+ self.lower_qpos_limits[4],
+ self.lower_qpos_limits[5],
+ )
+ upper_limits_ = wp_vec6f(
+ self.upper_qpos_limits[0],
+ self.upper_qpos_limits[1],
+ self.upper_qpos_limits[2],
+ self.upper_qpos_limits[3],
+ self.upper_qpos_limits[4],
+ self.upper_qpos_limits[5],
+ )
wp.launch(
kernel=opw_ik_kernel,
dim=(n_sample),
@@ -271,26 +294,42 @@ def get_ik_warp(
target_xpos_wp,
self._tcp_inv_warp,
self.params,
- self.offsets,
- self.sign_corrections,
+ offsets_,
+ sign_corrections_,
+ lower_limits_,
+ upper_limits_,
+ self.cfg.safe_margin,
),
outputs=[all_qpos_wp, all_ik_valid_wp],
- device=standardize_device_string(self.device),
+ device=standardize_device_string(kernel_device),
)
if return_all_solutions:
all_qpos = wp.to_torch(all_qpos_wp).reshape(n_sample, N_SOL, DOF)
all_ik_valid = wp.to_torch(all_ik_valid_wp).reshape(n_sample, N_SOL)
return all_ik_valid, all_qpos
-
if qpos_seed is not None:
- qpos_seed_wp = wp.from_torch(qpos_seed.reshape(-1))
+ if qpos_seed.shape == (
+ n_sample,
+ DOF,
+ ):
+ qpos_seed_ = qpos_seed.to(kernel_device)
+ elif qpos_seed.shape == (DOF,):
+ qpos_seed_ = (
+ qpos_seed.unsqueeze(0).repeat(n_sample, 1).to(kernel_device)
+ )
+ else:
+ logger.log_error(
+ f"Invalid shape for qpos_seed: {qpos_seed.shape}. Expected ({n_sample}, {DOF}) or ({DOF},)."
+ )
+ qpos_seed_wp = wp.from_torch(qpos_seed_)
else:
- qpos_seed_wp = wp.zeros(
- n_sample * DOF,
- dtype=float,
- device=standardize_device_string(self.device),
+ qpos_seed = torch.zeros(
+ (n_sample, DOF), dtype=torch.float32, device=kernel_device
)
+ qpos_seed_wp = wp.from_torch(qpos_seed)
+ all_qpos_wp = all_qpos_wp.reshape((n_sample, N_SOL, DOF))
+ all_ik_valid_wp = all_ik_valid_wp.reshape((n_sample, N_SOL))
joint_weight = kwargs.get("joint_weight", torch.ones(size=(DOF,), dtype=float))
joint_weight_wp = wp_vec6f(
joint_weight[0],
@@ -301,13 +340,13 @@ def get_ik_warp(
joint_weight[5],
)
best_ik_result_wp = wp.zeros(
- n_sample * 6, dtype=float, device=standardize_device_string(self.device)
+ (n_sample, 6), dtype=float, device=standardize_device_string(kernel_device)
)
best_ik_valid_wp = wp.zeros(
- n_sample, dtype=int, device=standardize_device_string(self.device)
+ n_sample, dtype=int, device=standardize_device_string(kernel_device)
)
wp.launch(
- kernel=opw_best_ik_kernel,
+ kernel=opw_ik_select_kernel,
dim=(n_sample),
inputs=[
all_qpos_wp,
@@ -315,11 +354,17 @@ def get_ik_warp(
qpos_seed_wp,
joint_weight_wp,
],
- outputs=[best_ik_result_wp, best_ik_valid_wp],
- device=standardize_device_string(self.device),
+ outputs=[
+ best_ik_result_wp,
+ best_ik_valid_wp,
+ ],
+ device=standardize_device_string(kernel_device),
+ )
+
+ best_ik_result = (
+ wp.to_torch(best_ik_result_wp).reshape(n_sample, 1, 6).to(self.device)
)
- best_ik_result = wp.to_torch(best_ik_result_wp).reshape(n_sample, 1, 6)
- best_ik_valid = wp.to_torch(best_ik_valid_wp)
+ best_ik_valid = wp.to_torch(best_ik_valid_wp).to(self.device)
return best_ik_valid, best_ik_result
def _calculate_dynamic_weights(
diff --git a/embodichain/lab/sim/solvers/pinocchio_solver.py b/embodichain/lab/sim/solvers/pinocchio_solver.py
index ec7e345a..9ddde65b 100644
--- a/embodichain/lab/sim/solvers/pinocchio_solver.py
+++ b/embodichain/lab/sim/solvers/pinocchio_solver.py
@@ -35,7 +35,6 @@
compute_pinocchio_fk,
)
-
if TYPE_CHECKING:
from typing import Self
@@ -129,9 +128,6 @@ def __init__(self, cfg: PinocchioSolverCfg, **kwargs):
self.robot.model.njoints - 1
) # Degrees of freedom of reduced robot joints
- self.upper_position_limits = self.robot.model.upperPositionLimit
- self.lower_position_limits = self.robot.model.lowerPositionLimit
-
self.ik_nearest_weight = np.ones(self.dof)
# TODO: The Casadi-based solver is currently disabled due to stability issues.
@@ -325,12 +321,14 @@ def qpos_to_limits(
# Generate possible values for each joint
dof_num = len(q)
+ lower_limits = self.lower_qpos_limits.to("cpu").numpy()
+ upper_limits = self.upper_qpos_limits.to("cpu").numpy()
for i in range(dof_num):
current_possible_values = []
# Calculate how many 2π fits into the adjustment to the limits
- lower_adjustment = (q[i] - self.lower_position_limits[i]) // (2 * np.pi)
- upper_adjustment = (self.upper_position_limits[i] - q[i]) // (2 * np.pi)
+ lower_adjustment = (q[i] - lower_limits[i]) // (2 * np.pi)
+ upper_adjustment = (upper_limits[i] - q[i]) // (2 * np.pi)
# Consider the current value and its periodic adjustments
for offset in range(
@@ -339,15 +337,11 @@ def qpos_to_limits(
adjusted_value = q[i] + offset * (2 * np.pi)
# Check if the adjusted value is within limits
- if (
- self.lower_position_limits[i]
- <= adjusted_value
- <= self.upper_position_limits[i]
- ):
+ if lower_limits[i] <= adjusted_value <= upper_limits[i]:
current_possible_values.append(adjusted_value)
# Also check the original value
- if self.lower_position_limits[i] <= q[i] <= self.upper_position_limits[i]:
+ if lower_limits[i] <= q[i] <= upper_limits[i]:
current_possible_values.append(q[i])
if not current_possible_values:
diff --git a/embodichain/lab/sim/solvers/pytorch_solver.py b/embodichain/lab/sim/solvers/pytorch_solver.py
index cdcdc562..c0fcf465 100644
--- a/embodichain/lab/sim/solvers/pytorch_solver.py
+++ b/embodichain/lab/sim/solvers/pytorch_solver.py
@@ -170,13 +170,11 @@ def __init__(
max_iterations=self._max_iterations,
lr=self._dt,
num_retries=1,
+ use_compile=True,
)
self.dof = self.pk_serial_chain.n_joints
- self.upper_position_limits = self.pk_serial_chain.high
- self.lower_position_limits = self.pk_serial_chain.low
-
def get_iteration_params(self) -> dict:
r"""Returns the current iteration parameters.
@@ -247,6 +245,7 @@ def set_iteration_params(
max_iterations=self._max_iterations,
lr=self._dt,
num_retries=1,
+ use_compile=True,
)
return True
@@ -284,106 +283,40 @@ def _compute_inverse_kinematics(
self.pik.initial_config = joint_seed
result = self.pik.solve(tf)
+ return result.converged_any, result.solutions[:, 0, :].squeeze(0)
- if result.converged_any.any().item():
- return result.converged_any, result.solutions[:, 0, :].squeeze(0)
-
- return False, torch.empty(0)
-
- @staticmethod
- def _qpos_to_limits_single(
- q: torch.Tensor,
- joint_seed: torch.Tensor,
- lower_position_limits: torch.Tensor,
- upper_position_limits: torch.Tensor,
- ik_nearest_weight: torch.Tensor,
- periodic_mask: torch.Tensor = None, # Optional mask for periodic joints
- ) -> torch.Tensor:
- """
- Adjusts the given joint positions (q) to fit within the specified limits while minimizing the difference to the seed position.
-
- Args:
- q (torch.Tensor): The initial joint positions.
- joint_seed (torch.Tensor): The seed joint positions for comparison.
- lower_position_limits (torch.Tensor): The lower bounds for the joint positions.
- upper_position_limits (torch.Tensor): The upper bounds for the joint positions.
- ik_nearest_weight (torch.Tensor): The weights for the inverse kinematics nearest calculation.
- periodic_mask (torch.Tensor, optional): Boolean mask indicating which joints are periodic.
-
- Returns:
- torch.Tensor: The adjusted joint positions that fit within the limits.
- """
- device = q.device
- joint_seed = joint_seed.to(device)
- lower = lower_position_limits.to(device)
- upper = upper_position_limits.to(device)
- weight = ik_nearest_weight.to(device)
-
- # If periodic_mask is not provided, assume all joints are periodic
- if periodic_mask is None:
- periodic_mask = torch.ones_like(q, dtype=torch.bool, device=device)
-
- # Only enumerate [-2π, 0, 2π] for periodic joints, single value for non-periodic
- offsets = torch.tensor([-2 * torch.pi, 0, 2 * torch.pi], device=device)
- candidate_list = []
- for i in range(q.size(0)):
- if periodic_mask[i]:
- candidate_list.append(q[i] + offsets)
- else:
- candidate_list.append(q[i].unsqueeze(0))
- # Generate all possible combinations
- mesh = torch.meshgrid(*candidate_list, indexing="ij")
- candidates = torch.stack([m.reshape(-1) for m in mesh], dim=1)
- # Filter candidates that are out of limits
- mask = (candidates >= lower) & (candidates <= upper)
- valid_mask = mask.all(dim=1)
- valid_candidates = candidates[valid_mask]
- if valid_candidates.shape[0] == 0:
- return torch.tensor([]).to(device)
- # Compute weighted distance to seed and select the closest
- diffs = torch.abs(valid_candidates - joint_seed) * weight
- distances = torch.sum(diffs, dim=1)
- min_idx = torch.argmin(distances)
- return valid_candidates[min_idx]
-
- def _qpos_to_limits(
- self, qpos_list_split: torch.Tensor, joint_seed: torch.Tensor
- ) -> torch.Tensor:
- r"""Adjusts a batch of joint positions to fit within joint limits and minimize the weighted distance to the seed position.
+ def _qpos_map_to_limits(
+ self, qpos: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ r"""Maps a batch of joint positions to fit within joint limits and computes the distance to the seed position.
Args:
- qpos_list_split (torch.Tensor): Batch of candidate joint positions, shape (N, dof).
- joint_seed (torch.Tensor): The reference joint positions for comparison, shape (dof,).
-
+ qpos (torch.Tensor): Batch of candidate joint positions, shape (N, dof).
Returns:
- torch.Tensor: Batch of adjusted joint positions that fit within the limits, shape (M, dof),
- where M <= N (invalid candidates are filtered out).
+ tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - torch.Tensor: whether qpos exactly within joint limit, shape (N).
+ - torch.Tensor: qpos that roughly mapped into joint limit, shape (N, dof).
"""
-
- periodic_mask = torch.ones_like(
- qpos_list_split[0], dtype=torch.bool, device=self.device
+ two_pi = 2.0 * torch.pi
+ k = torch.ceil((self.lower_qpos_limits - qpos) / two_pi)
+ qpos_mapped = qpos + k * two_pi
+ is_within_limits = (qpos_mapped >= self.lower_qpos_limits) & (
+ qpos_mapped <= self.upper_qpos_limits
)
- adjusted_qpos_list = [
- self._qpos_to_limits_single(
- q,
- joint_seed,
- self.lower_position_limits,
- self.upper_position_limits,
- self.ik_nearest_weight,
- periodic_mask,
- )
- for q in qpos_list_split
+ # if qpos_mapped is valid near zero, use it
+ k_zero = torch.ceil(
+ (-torch.pi - qpos) / two_pi
+ ) # [-pi, pi] is the valid range near zero
+ qpos_mapped_near_zero = qpos + k_zero * two_pi
+ is_within_limits_near_zero = (
+ qpos_mapped_near_zero >= self.lower_qpos_limits
+ ) & (qpos_mapped_near_zero <= self.upper_qpos_limits)
+ qpos_mapped[is_within_limits_near_zero] = qpos_mapped_near_zero[
+ is_within_limits_near_zero
]
- # Filter out empty results
- adjusted_qpos_list = [q for q in adjusted_qpos_list if q.numel() > 0]
-
- return (
- torch.stack(adjusted_qpos_list).to(qpos_list_split.device)
- if adjusted_qpos_list
- else torch.tensor([], device=self.device)
- )
+ return is_within_limits.all(dim=1), qpos_mapped
@ensure_pose_shape
def get_ik(
@@ -433,27 +366,28 @@ def get_ik(
qpos_seed = torch.as_tensor(qpos_seed, device=self.device)
# Check qpos_seed dimensions
- if qpos_seed.dim() == 1:
- qpos_seed = qpos_seed.unsqueeze(0)
- qpos_seed_ndim = 1
- elif qpos_seed.dim() == 2:
- qpos_seed_ndim = 2
- if qpos_seed.shape[0] != target_xpos.shape[0]:
- raise ValueError(
- "Batch size of qpos_seed must match batch size of target_xpos when qpos_seed is a 2D tensor."
- )
+ n_batch = target_xpos.shape[0]
+ if qpos_seed.shape == (n_batch, self.dof):
+ qpos_seed = qpos_seed
+ elif qpos_seed.shape == (self.dof,):
+ qpos_seed = qpos_seed.unsqueeze(0).repeat(n_batch, 1)
else:
- raise ValueError("`qpos_seed` must be a tensor of shape (n,) or (n, n).")
+ logger.log_error(
+ f"Invalid qpos_seed shape {qpos_seed.shape} for batch_size {n_batch} and dof {self.dof}",
+ ValueError,
+ )
+ # output qpos_seed shape: (batch_size, dof)
# Transform target_xpos by TCP
tcp_xpos = torch.as_tensor(
- deepcopy(self.tcp_xpos), device=self.device, dtype=torch.float32
+ self.tcp_xpos, device=self.device, dtype=torch.float32
)
- target_xpos = target_xpos @ torch.inverse(tcp_xpos)
+ tcp_xpos_inv = tcp_xpos.clone()
+ tcp_xpos_inv[:3, :3] = tcp_xpos_inv[:3, :3].T
+ tcp_xpos_inv[:3, 3] = -tcp_xpos_inv[:3, :3] @ tcp_xpos_inv[:3, 3]
+ target_xpos = target_xpos @ tcp_xpos_inv
# Get joint limits and ensure shape matches dof
- upper_limits = self.upper_position_limits.float()
- lower_limits = self.lower_position_limits.float()
batch_size = target_xpos.shape[0]
@@ -461,79 +395,43 @@ def get_ik(
num_samples=self._num_samples, dof=self.dof, device=self.device
)
random_qpos_seeds = sampler.sample(
- qpos_seed, lower_limits, upper_limits, batch_size
+ qpos_seed,
+ self.lower_qpos_limits,
+ self.upper_qpos_limits,
+ batch_size,
)
target_xpos_repeated = sampler.repeat_target_xpos(
target_xpos, self._num_samples
)
# Compute IK solutions for all samples
- res_list, qpos_list = self._compute_inverse_kinematics(
+ is_ik_success, ik_qpos = self._compute_inverse_kinematics(
target_xpos_repeated, random_qpos_seeds
)
-
- if not isinstance(res_list, torch.Tensor) or not res_list.any():
- logger.log_warning(
- "Pk: No valid solutions found for the given target poses and joint seeds."
- )
- return torch.zeros(
- batch_size, dtype=torch.bool, device=self.device
- ), torch.zeros((batch_size, self.dof), device=self.device)
-
- # Split res_list and qpos_list according to self._num_samples
- res_list_split = torch.split(res_list, self._num_samples)
- qpos_list_split = torch.split(qpos_list, self._num_samples)
-
- # Initialize the final results and the closest joint positions
- final_results = []
- final_qpos = []
-
- # For each batch, select the closest valid solution to qpos_seed
- for i in range(batch_size):
- target_qpos_seed = qpos_seed[i] if qpos_seed_ndim == 2 else qpos_seed
-
- if not res_list_split[i].any():
- final_results.append(False)
- final_qpos.append(torch.zeros((1, self.dof), device=self.device))
- continue
-
- result_qpos_limit = self._qpos_to_limits(
- qpos_list_split[i], target_qpos_seed
- )
-
- if result_qpos_limit.shape[0] == 0:
- final_results.append(False)
- final_qpos.append(torch.zeros((self.dof), device=self.device))
- continue
-
- distances = torch.norm(result_qpos_limit - target_qpos_seed, dim=1)
- sorted_indices = torch.argsort(distances)
- # shape: (N, dof)
- sorted_qpos_array = result_qpos_limit[sorted_indices]
- final_qpos.append(sorted_qpos_array)
- final_results.append(True)
-
- # Pad all batches to the same number of solutions for stacking
- max_solutions = max([q.shape[0] for q in final_qpos]) if final_qpos else 1
- final_qpos_tensor = torch.zeros(
- (batch_size, max_solutions, self.dof), device=self.device
- )
- for i, q in enumerate(final_qpos):
- n = q.shape[0]
- final_qpos_tensor[i, :n, :] = q
-
- final_results = torch.tensor(
- final_results, dtype=torch.bool, device=self.device
- )
+ if is_ik_success.any().item() is False:
+ logger.log_warning("No IK solutions found for any of the target poses.")
+ failed_state = is_ik_success.reshape(batch_size, self._num_samples)[:, 0]
+ failed_qpos = ik_qpos.reshape(batch_size, self._num_samples, self.dof)[
+ :, 0, :
+ ]
+ return failed_state, failed_qpos
+ # map ik_qpos to within limits and check validity
+ is_mask_valid, ik_qpos_mapped = self._qpos_map_to_limits(ik_qpos)
+ is_success = torch.logical_and(is_ik_success, is_mask_valid)
+
+ all_is_success = is_success.reshape(batch_size, self._num_samples)
+ all_results = ik_qpos_mapped.reshape(batch_size, self._num_samples, self.dof)
if return_all_solutions:
- # Return all sorted solutions for each batch (shape: batch_size, max_solutions, dof)
- return final_results, final_qpos_tensor
-
- # Only return the closest solution for each batch (shape: batch_size, 1, dof)
- # If multiple solutions, take the first (closest)
- final_qpos_tensor = final_qpos_tensor[:, :1, :]
- return final_results, final_qpos_tensor
+ return all_is_success.any(dim=1), all_results
+ qpos_seed_repeat = qpos_seed.unsqueeze(1).repeat(1, self._num_samples, 1)
+ weighed_diff = self.ik_nearest_weight * (all_results - qpos_seed_repeat)
+ qpos_seed_dis = torch.norm(weighed_diff, dim=2)
+ # Tricky: mask out invalid solutions by setting distance to inf, so they won't be selected as closest
+ qpos_seed_dis[~all_is_success] = float("inf")
+ closest_indices = torch.argmin(qpos_seed_dis, dim=1)
+ closest_qpos = all_results[torch.arange(batch_size), closest_indices]
+ return all_is_success.any(dim=1), closest_qpos[:, None, :]
def get_all_fk(self, qpos: torch.tensor) -> torch.tensor:
r"""Get the forward kinematics for all links from root to end link.
diff --git a/embodichain/lab/sim/solvers/qpos_seed_sampler.py b/embodichain/lab/sim/solvers/qpos_seed_sampler.py
index c6a4ef30..03674506 100644
--- a/embodichain/lab/sim/solvers/qpos_seed_sampler.py
+++ b/embodichain/lab/sim/solvers/qpos_seed_sampler.py
@@ -15,6 +15,7 @@
# ----------------------------------------------------------------------------
import torch
+from embodichain.utils import logger
class QposSeedSampler:
@@ -52,22 +53,29 @@ def sample(
Returns:
torch.Tensor: (batch_size * num_samples, dof) joint seeds.
"""
- joint_seeds_list = []
- for i in range(batch_size):
- current_seed = (
- qpos_seed[i].unsqueeze(0)
- if qpos_seed.shape[0] == batch_size
- else qpos_seed
+ if qpos_seed.shape == (batch_size, self.dof):
+ seed_head = qpos_seed[:, None, :]
+ elif qpos_seed.shape == (self.dof,):
+ seed_head = qpos_seed.unsqueeze(0).repeat(batch_size, 1)[:, None, :]
+ else:
+ logger.log_error(
+ f"Invalid qpos_seed shape {qpos_seed.shape} for batch_size {batch_size} and dof {self.dof}",
+ ValueError,
)
- if self.num_samples > 1:
- rand_part = lower_limits + (upper_limits - lower_limits) * torch.rand(
- (self.num_samples - 1, self.dof), device=self.device
- )
- else:
- rand_part = torch.empty((0, self.dof), device=self.device)
- seeds = torch.cat([current_seed, rand_part], dim=0)
- joint_seeds_list.append(seeds)
- return torch.cat(joint_seeds_list, dim=0)
+ n_random_samples = self.num_samples - 1
+
+ # seed_random = torch.rand(
+ # size=(batch_size, n_random_samples, self.dof), device=self.device
+ # )
+
+ # save sampling time, repeat for each batch and sample in one go
+ seed_random = torch.rand(
+ size=(1, n_random_samples, self.dof), device=self.device
+ )
+ seed_random = seed_random.repeat(batch_size, 1, 1)
+ seed_random = lower_limits + (upper_limits - lower_limits) * seed_random
+ joint_seeds = torch.cat([seed_head, seed_random], dim=1)
+ return joint_seeds.reshape(-1, self.dof)
def repeat_target_xpos(
self, target_xpos: torch.Tensor, num_samples: int
@@ -81,8 +89,6 @@ def repeat_target_xpos(
Returns:
torch.Tensor: (batch_size * num_samples, 4, 4) or (batch_size * num_samples, 3, 3)
"""
- repeated_list = [
- target_xpos[i].unsqueeze(0).repeat(num_samples, 1, 1)
- for i in range(target_xpos.shape[0])
- ]
- return torch.cat(repeated_list, dim=0)
+
+ target_xpos_repeated = target_xpos.unsqueeze(1).repeat(1, num_samples, 1, 1)
+ return target_xpos_repeated.reshape(-1, 4, 4)
diff --git a/embodichain/lab/sim/solvers/srs_solver.py b/embodichain/lab/sim/solvers/srs_solver.py
index 64c4f492..d68f470b 100644
--- a/embodichain/lab/sim/solvers/srs_solver.py
+++ b/embodichain/lab/sim/solvers/srs_solver.py
@@ -51,9 +51,6 @@ class SRSSolverCfg(SolverCfg):
dh_params = []
"""Denavit-Hartenberg parameters for the robot's kinematic chain."""
- qpos_limits = []
- """Joint position limits for the robot."""
-
T_b_ob = np.eye(4)
"""Base to observed base transform."""
@@ -107,9 +104,7 @@ def __init__(self, cfg: SRSSolverCfg, device: torch.device):
self.device = device
self.dofs = 7
self.dh_params = cfg.dh_params
- self.qpos_limits = cfg.qpos_limits
self.tcp_xpos = np.eye(4)
-
# Initialize transformation matrices
self._parse_params()
@@ -122,7 +117,6 @@ def _parse_params(self):
# Convert configuration parameters to numpy arrays for efficient computation.
self.dh_params_np = np.asarray(self.cfg.dh_params)
- self.qpos_limits_np = np.asarray(self.cfg.qpos_limits)
self.link_lengths_np = np.asarray(self.cfg.link_lengths)
self.rotation_directions_np = np.asarray(self.cfg.rotation_directions)
@@ -628,11 +622,6 @@ def _parse_params(self):
dtype=float,
device=standardize_device_string(self.device),
)
- self.qpos_limits_wp = wp.array(
- self.qpos_limits_np,
- dtype=wp.vec2,
- device=standardize_device_string(self.device),
- )
self.link_lengths_wp = wp.array(
self.link_lengths_np.flatten(),
dtype=float,
@@ -1197,6 +1186,21 @@ def __init__(self, cfg: SRSSolverCfg, num_envs: int, device: str, **kwargs):
else:
self.impl = _CPUSRSSolverImpl(cfg, self.device)
+ self._update_impl_qpos_limits()
+
+ def _update_impl_qpos_limits(self):
+ qpos_limits = torch.vstack([self.lower_qpos_limits, self.upper_qpos_limits]).T
+ self.impl.qpos_limits_np = qpos_limits.cpu().numpy()
+ self.impl.qpos_limits_wp = wp.array(
+ self.impl.qpos_limits_np,
+ dtype=wp.vec2,
+ device=standardize_device_string(self.device),
+ )
+
+ def update_with_robot_limit(self, robot_qpos_limits):
+ super().update_with_robot_limit(robot_qpos_limits)
+ self._update_impl_qpos_limits()
+
def get_ik(
self,
target_xpos: torch.Tensor,
diff --git a/embodichain/lab/sim/types.py b/embodichain/lab/sim/types.py
index 0a7f0c22..c727ea83 100644
--- a/embodichain/lab/sim/types.py
+++ b/embodichain/lab/sim/types.py
@@ -20,7 +20,6 @@
from typing import Sequence, Union
from tensordict import TensorDict
-
Array = Union[torch.Tensor, np.ndarray, Sequence]
Device = Union[str, torch.device]
diff --git a/embodichain/lab/sim/utility/keyboard_utils.py b/embodichain/lab/sim/utility/keyboard_utils.py
index f0646b25..d64eca18 100644
--- a/embodichain/lab/sim/utility/keyboard_utils.py
+++ b/embodichain/lab/sim/utility/keyboard_utils.py
@@ -14,6 +14,8 @@
# limitations under the License.
# ----------------------------------------------------------------------------
+from __future__ import annotations
+
import select
import sys
import tty
@@ -24,8 +26,11 @@
import numpy as np
from scipy.spatial.transform import Rotation as R
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from embodichain.lab.sim.sensors import Camera
-from embodichain.lab.sim.sensors import Camera
from embodichain.utils.logger import log_info, log_error, log_warning
@@ -47,12 +52,6 @@ def run_keyboard_control_for_camera(
sim = SimulationManager.get_instance()
- if vis_pose and sim.is_rt_enabled:
- log_warning(
- "'vis_pose' is not fully supported with ray tracing enabled. Will be fixed in future updates."
- )
- return
-
if isinstance(sensor, str):
sensor = sim.get_sensor(uid=sensor)
@@ -269,12 +268,6 @@ def run_keyboard_control_for_light(
sim = SimulationManager.get_instance()
- if vis_pose and sim.is_rt_enabled:
- log_warning(
- "'vis_pose' is not fully supported with ray tracing enabled. Will be fixed in future updates."
- )
- return
-
if isinstance(light, str):
light: Light = sim.get_light(uid=light)
diff --git a/embodichain/lab/sim/utility/sim_utils.py b/embodichain/lab/sim/utility/sim_utils.py
index 088709c3..9a3f1eea 100644
--- a/embodichain/lab/sim/utility/sim_utils.py
+++ b/embodichain/lab/sim/utility/sim_utils.py
@@ -152,7 +152,11 @@ def is_rt_enabled() -> bool:
"""
config = dexsim.get_world_config()
- return config.renderer == dexsim.types.Renderer.FASTRT
+ return (
+ config.renderer == dexsim.types.Renderer.FASTRT
+ or config.renderer == dexsim.types.Renderer.HYBRID
+ or config.renderer == dexsim.types.Renderer.OFFLINERT
+ )
def create_cube(
diff --git a/embodichain/lab/sim/utility/solver_utils.py b/embodichain/lab/sim/utility/solver_utils.py
index 9cdf1bc4..b6eac155 100644
--- a/embodichain/lab/sim/utility/solver_utils.py
+++ b/embodichain/lab/sim/utility/solver_utils.py
@@ -109,7 +109,7 @@ def create_pk_serial_chain(
else:
return pk.SerialChain(
chain=chain, end_frame_name=end_link_name, root_frame_name=root_link_name
- )
+ ).to(device=device)
def build_reduced_pinocchio_robot(
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/caches/base_cache.py b/embodichain/lab/sim/utility/workspace_analyzer/caches/base_cache.py
index 63e40349..20eb407e 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/caches/base_cache.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/caches/base_cache.py
@@ -18,7 +18,6 @@
from typing import List
import numpy as np
-
all = [
"BaseCache",
]
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/caches/cache_manager.py b/embodichain/lab/sim/utility/workspace_analyzer/caches/cache_manager.py
index 40fb56a2..13397246 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/caches/cache_manager.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/caches/cache_manager.py
@@ -25,7 +25,6 @@
CacheConfig,
)
-
all = [
"CacheManager",
]
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/configs/__init__.py b/embodichain/lab/sim/utility/workspace_analyzer/configs/__init__.py
index f07ad587..549bc124 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/configs/__init__.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/configs/__init__.py
@@ -36,7 +36,6 @@
DensityConfig,
)
-
__all__ = [
"CacheConfig",
"DimensionConstraint",
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/constraints/base_constraint.py b/embodichain/lab/sim/utility/workspace_analyzer/constraints/base_constraint.py
index a2e59704..8eb55a9d 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/constraints/base_constraint.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/constraints/base_constraint.py
@@ -21,7 +21,6 @@
from embodichain.utils import logger
-
__all__ = [
"IConstraintChecker",
"BaseConstraintChecker",
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/constraints/workspace_constraint.py b/embodichain/lab/sim/utility/workspace_analyzer/constraints/workspace_constraint.py
index aa564cfb..0e9f8d5e 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/constraints/workspace_constraint.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/constraints/workspace_constraint.py
@@ -24,7 +24,6 @@
DimensionConstraint,
)
-
__all__ = [
"WorkspaceConstraintChecker",
]
@@ -139,6 +138,19 @@ def check_collision(
return valid
+ def check_constraints(
+ self, points: torch.Tensor | np.ndarray
+ ) -> torch.Tensor | np.ndarray:
+ """Check all constraints (bounds + collision) in a single call.
+
+ Args:
+ points: Array of shape (N, 3) containing 3D point positions.
+
+ Returns:
+ Boolean array of shape (N,) indicating which points satisfy all constraints.
+ """
+ return self.check_bounds(points) & self.check_collision(points)
+
def filter_points(
self, points: torch.Tensor | np.ndarray
) -> torch.Tensor | np.ndarray:
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/metrics/density_metric.py b/embodichain/lab/sim/utility/workspace_analyzer/metrics/density_metric.py
index 8b82d857..f91236fe 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/metrics/density_metric.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/metrics/density_metric.py
@@ -92,35 +92,42 @@ def compute(
def _compute_local_density(self, points: np.ndarray) -> np.ndarray:
"""Compute local density for each point.
+ Uses scipy.spatial.cKDTree for O(N log N) performance instead of
+ the O(N^2) brute-force approach. Falls back to brute-force if
+ scipy is unavailable.
+
Args:
points: Point cloud, shape (N, 3).
Returns:
Local densities, shape (N,).
"""
- n_points = len(points)
- densities = np.zeros(n_points)
-
- # Use radius-based density estimation for better performance
radius = self.config.radius
-
- for i in range(n_points):
- # Compute distances to all other points
- distances = np.linalg.norm(points - points[i], axis=1)
-
- # Count neighbors within radius
- num_neighbors = np.sum(distances <= radius) - 1 # Exclude self
-
- # Density = neighbors / volume of sphere
- volume = (4.0 / 3.0) * np.pi * (radius**3)
- densities[i] = num_neighbors / volume if volume > 0 else 0.0
-
- return densities
+ volume = (4.0 / 3.0) * np.pi * (radius**3)
+
+ try:
+ from scipy.spatial import cKDTree
+
+ tree = cKDTree(points)
+ # Count neighbors within radius for all points at once
+ counts = tree.query_ball_point(points, r=radius, return_length=True)
+ # Subtract 1 to exclude self
+ densities = (counts - 1) / volume if volume > 0 else np.zeros(len(points))
+ return densities
+ except ImportError:
+ # Fallback: brute-force O(N^2)
+ n_points = len(points)
+ densities = np.zeros(n_points)
+ for i in range(n_points):
+ distances = np.linalg.norm(points - points[i], axis=1)
+ num_neighbors = np.sum(distances <= radius) - 1
+ densities[i] = num_neighbors / volume if volume > 0 else 0.0
+ return densities
def _compute_knn_density(self, points: np.ndarray) -> np.ndarray:
"""Compute k-nearest neighbors density.
- Alternative method using k-nearest neighbors instead of fixed radius.
+ Uses scipy.spatial.cKDTree for O(N log N) performance.
Args:
points: Point cloud, shape (N, 3).
@@ -134,19 +141,25 @@ def _compute_knn_density(self, points: np.ndarray) -> np.ndarray:
if k <= 0:
return np.zeros(n_points)
- densities = np.zeros(n_points)
-
- for i in range(n_points):
- # Compute distances to all other points
- distances = np.linalg.norm(points - points[i], axis=1)
-
- # Find k-nearest neighbors (excluding self)
- distances[i] = np.inf
- knn_distances = np.partition(distances, k)[:k]
-
- # Density = k / volume of sphere containing k neighbors
- max_distance = knn_distances.max()
- volume = (4.0 / 3.0) * np.pi * (max_distance**3)
- densities[i] = k / volume if volume > 0 else 0.0
-
- return densities
+ try:
+ from scipy.spatial import cKDTree
+
+ tree = cKDTree(points)
+ # Query k+1 nearest (includes self)
+ distances, _ = tree.query(points, k=k + 1)
+ # Use the k-th nearest distance (index k, since 0 is self)
+ max_distances = distances[:, -1]
+ max_distances = np.maximum(max_distances, 1e-10)
+ volumes = (4.0 / 3.0) * np.pi * (max_distances**3)
+ densities = k / volumes
+ return densities
+ except ImportError:
+ densities = np.zeros(n_points)
+ for i in range(n_points):
+ distances = np.linalg.norm(points - points[i], axis=1)
+ distances[i] = np.inf
+ knn_distances = np.partition(distances, k)[:k]
+ max_distance = knn_distances.max()
+ volume = (4.0 / 3.0) * np.pi * (max_distance**3)
+ densities[i] = k / volume if volume > 0 else 0.0
+ return densities
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/metrics/manipulability_metric.py b/embodichain/lab/sim/utility/workspace_analyzer/metrics/manipulability_metric.py
index 16c71c5f..5b0e8d0b 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/metrics/manipulability_metric.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/metrics/manipulability_metric.py
@@ -95,6 +95,9 @@ def compute(
valid_mask = manipulability_scores >= self.config.jacobian_threshold
valid_scores = manipulability_scores[valid_mask]
+ if len(valid_scores) == 0:
+ valid_scores = np.array([0.0])
+
self.results = {
"mean_manipulability": float(valid_scores.mean()),
"std_manipulability": float(valid_scores.std()),
@@ -112,40 +115,46 @@ def compute(
return self.results
def _compute_manipulability_index(self, jacobians: np.ndarray) -> np.ndarray:
- """Compute Yoshikawa manipulability index.
+ """Compute Yoshikawa manipulability index with batched operations.
Args:
- jacobians: Jacobian matrices, shape (N, 6, num_joints).
+ jacobians: Jacobian matrices, shape (N, rows, cols).
Returns:
Manipulability indices, shape (N,).
"""
- # Manipulability index: sqrt(det(J * J^T))
- manipulability = np.zeros(len(jacobians))
+ # Batch matrix multiply: J @ J^T for all samples
+ JJT = np.matmul(jacobians, np.swapaxes(jacobians, -2, -1))
- for i, J in enumerate(jacobians):
- JJT = J @ J.T
- det = np.linalg.det(JJT)
- manipulability[i] = np.sqrt(max(det, 0))
+ # Batch determinant
+ dets = np.linalg.det(JJT)
- return manipulability
+ # sqrt(max(0, det))
+ return np.sqrt(np.maximum(dets, 0.0))
def _compute_condition_numbers(self, jacobians: np.ndarray) -> np.ndarray:
- """Compute condition numbers of Jacobian matrices.
+ """Compute condition numbers of Jacobian matrices with batched SVD.
Args:
- jacobians: Jacobian matrices, shape (N, 6, num_joints).
+ jacobians: Jacobian matrices, shape (N, rows, cols).
Returns:
Condition numbers, shape (N,).
"""
- condition_numbers = np.zeros(len(jacobians))
-
- for i, J in enumerate(jacobians):
- try:
- condition_numbers[i] = np.linalg.cond(J)
- except np.linalg.LinAlgError:
- # Singular matrix, use infinity as condition number
- condition_numbers[i] = np.inf
-
- return condition_numbers
+ try:
+ _, singular_values, _ = np.linalg.svd(jacobians, full_matrices=False)
+ # Condition number = max singular value / min singular value
+ max_sv = singular_values[:, 0]
+ min_sv = singular_values[:, -1]
+ # Avoid division by zero
+ min_sv = np.maximum(min_sv, 1e-15)
+ return max_sv / min_sv
+ except np.linalg.LinAlgError:
+ # Fallback to per-matrix computation if batch SVD fails
+ condition_numbers = np.zeros(len(jacobians))
+ for i, J in enumerate(jacobians):
+ try:
+ condition_numbers[i] = np.linalg.cond(J)
+ except np.linalg.LinAlgError:
+ condition_numbers[i] = np.inf
+ return condition_numbers
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/metrics/reachability_metric.py b/embodichain/lab/sim/utility/workspace_analyzer/metrics/reachability_metric.py
index f20f0e1c..39721f7c 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/metrics/reachability_metric.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/metrics/reachability_metric.py
@@ -112,7 +112,7 @@ def compute(
def _voxelize_points(
self, points: np.ndarray, voxel_size: float
) -> Dict[tuple, int]:
- """Convert points to voxel grid.
+ """Convert points to voxel grid using vectorized operations.
Args:
points: Point cloud, shape (N, 3).
@@ -124,14 +124,14 @@ def _voxelize_points(
# Convert points to voxel indices
voxel_indices = np.floor(points / voxel_size).astype(int)
- # Count points in each voxel
- voxel_grid = {}
- for idx in voxel_indices:
- key = tuple(idx)
- voxel_grid[key] = voxel_grid.get(key, 0) + 1
+ # Use np.unique for vectorized counting
+ unique_indices, counts = np.unique(voxel_indices, axis=0, return_counts=True)
- # Filter by minimum points threshold
+ # Filter by minimum points threshold and build dict
min_points = self.config.min_points_per_voxel
- voxel_grid = {k: v for k, v in voxel_grid.items() if v >= min_points}
+ voxel_grid = {}
+ for idx, count in zip(unique_indices, counts):
+ if count >= min_points:
+ voxel_grid[tuple(idx)] = int(count)
return voxel_grid
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/samplers/base_sampler.py b/embodichain/lab/sim/utility/workspace_analyzer/samplers/base_sampler.py
index 2685e5ec..30a1bf97 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/samplers/base_sampler.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/samplers/base_sampler.py
@@ -21,7 +21,6 @@
from embodichain.utils import logger
-
__all__ = [
"ISampler",
"BaseSampler",
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/samplers/halton_sampler.py b/embodichain/lab/sim/utility/workspace_analyzer/samplers/halton_sampler.py
index 01b005f8..c00c991a 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/samplers/halton_sampler.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/samplers/halton_sampler.py
@@ -176,7 +176,7 @@ def __init__(
self.bases = bases
self.skip = skip
- def sample(
+ def _sample_from_bounds(
self, bounds: torch.Tensor | np.ndarray, num_samples: int
) -> torch.Tensor:
"""Generate Halton sequence samples within the given bounds.
@@ -190,13 +190,6 @@ def sample(
Raises:
ValueError: If bounds are invalid or num_samples is non-positive.
-
- Examples:
- >>> sampler = HaltonSampler(skip=100)
- >>> bounds = torch.tensor([[-1.0, 1.0], [-1.0, 1.0]], dtype=torch.float32)
- >>> samples = sampler.sample(bounds, num_samples=100)
- >>> samples.shape
- torch.Size([100, 2])
"""
bounds = self._validate_bounds(bounds)
@@ -220,14 +213,8 @@ def sample(
)
bases = self.bases[:n_dims]
- # Generate Halton sequence
- samples_unit = np.zeros((num_samples, n_dims), dtype=np.float32)
-
- for dim in range(n_dims):
- base = bases[dim]
- for i in range(num_samples):
- index = i + self.skip + 1 # Start from 1, apply skip
- samples_unit[i, dim] = self._halton_number(index, base)
+ # Generate Halton sequence with vectorized van der Corput
+ samples_unit = self._generate_halton_vectorized(num_samples, n_dims, bases)
# Convert to tensor and scale to bounds
samples_unit_tensor = self._to_tensor(samples_unit)
@@ -238,30 +225,53 @@ def sample(
return samples
- @staticmethod
- def _halton_number(index: int, base: int) -> float:
- """Compute a single Halton number.
+ def _generate_halton_vectorized(
+ self, num_samples: int, n_dims: int, bases: list[int]
+ ) -> np.ndarray:
+ """Generate Halton sequence using vectorized van der Corput computation.
+
+ Args:
+ num_samples: Number of samples to generate.
+ n_dims: Number of dimensions.
+ bases: Prime bases for each dimension.
+
+ Returns:
+ Array of shape (num_samples, n_dims) with values in [0, 1].
+ """
+ indices = np.arange(1, num_samples + 1) + self.skip # (num_samples,)
+ samples = np.zeros((num_samples, n_dims), dtype=np.float32)
+
+ for dim in range(n_dims):
+ samples[:, dim] = self._van_der_corput_vectorized(indices, bases[dim])
- The Halton sequence is generated by reversing the base-n representation
- of the index.
+ return samples
+
+ @staticmethod
+ def _van_der_corput_vectorized(indices: np.ndarray, base: int) -> np.ndarray:
+ """Compute van der Corput sequence for multiple indices at once.
Args:
- index: Sequence index (starting from 1).
+ indices: Array of sequence indices.
base: Prime base for this dimension.
Returns:
- Halton number in [0, 1].
+ Array of van der Corput values in [0, 1].
"""
- result = 0.0
- f = 1.0 / base
- i = index
+ # Determine maximum number of digits needed
+ max_idx = int(indices.max())
+ n_digits = int(np.ceil(np.log(max_idx + 1) / np.log(base))) + 1
+
+ result = np.zeros(len(indices), dtype=np.float64)
+ i_vals = indices.astype(np.float64).copy()
+ current_f = 1.0 / base
- while i > 0:
- result += f * (i % base)
- i //= base
- f /= base
+ for _ in range(n_digits):
+ remainders = i_vals % base
+ result += current_f * remainders
+ i_vals = np.floor(i_vals / base)
+ current_f /= base
- return result
+ return result.astype(np.float32)
def get_strategy_name(self) -> str:
"""Get the name of the sampling strategy.
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/samplers/iniform_sampler.py b/embodichain/lab/sim/utility/workspace_analyzer/samplers/iniform_sampler.py
index 8f536817..1db2ce4b 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/samplers/iniform_sampler.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/samplers/iniform_sampler.py
@@ -75,8 +75,8 @@ def _sample_from_bounds(
bounds: Tensor/Array of shape (n_dims, 2) containing [lower, upper] bounds for each dimension.
num_samples: Total number of samples to generate. This is used to calculate
samples_per_dim if not explicitly provided during initialization.
- Note: The actual number of samples may differ slightly from this value
- to maintain a uniform grid.
+ Note: The actual number of samples (samples_per_dim^n_dims) will not
+ exceed this value, but may be less to maintain a uniform grid.
Returns:
Tensor of shape (actual_num_samples, n_dims) containing the sampled points.
@@ -99,7 +99,8 @@ def _sample_from_bounds(
# Calculate samples per dimension if not provided
if self.samples_per_dim is None:
# Compute samples_per_dim to approximate the desired num_samples
- samples_per_dim = max(2, int(np.ceil(num_samples ** (1.0 / n_dims))))
+ # Use floor to ensure actual grid size never exceeds num_samples
+ samples_per_dim = max(2, int(num_samples ** (1.0 / n_dims)))
else:
samples_per_dim = self.samples_per_dim
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/visualizers/base_visualizer.py b/embodichain/lab/sim/utility/workspace_analyzer/visualizers/base_visualizer.py
index 42541098..4c27bc94 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/visualizers/base_visualizer.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/visualizers/base_visualizer.py
@@ -40,7 +40,6 @@
VisualizationConfig,
)
-
__all__ = [
"IVisualizer",
"BaseVisualizer",
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/visualizers/sphere_visualizer.py b/embodichain/lab/sim/utility/workspace_analyzer/visualizers/sphere_visualizer.py
index 401cedbc..08bb3c2c 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/visualizers/sphere_visualizer.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/visualizers/sphere_visualizer.py
@@ -33,7 +33,6 @@
from embodichain.utils import logger
-
__all__ = ["SphereVisualizer"]
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/visualizers/voxel_visualizer.py b/embodichain/lab/sim/utility/workspace_analyzer/visualizers/voxel_visualizer.py
index 47b46fd4..1cfc0647 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/visualizers/voxel_visualizer.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/visualizers/voxel_visualizer.py
@@ -33,7 +33,6 @@
from embodichain.utils import logger
-
__all__ = ["VoxelVisualizer"]
diff --git a/embodichain/lab/sim/utility/workspace_analyzer/workspace_analyzer.py b/embodichain/lab/sim/utility/workspace_analyzer/workspace_analyzer.py
index 38937ea7..ef523c49 100644
--- a/embodichain/lab/sim/utility/workspace_analyzer/workspace_analyzer.py
+++ b/embodichain/lab/sim/utility/workspace_analyzer/workspace_analyzer.py
@@ -302,6 +302,7 @@ def _create_sampler(self) -> BaseSampler:
return factory.create_sampler(
strategy=self.config.sampling.strategy,
seed=self.config.sampling.seed,
+ device=self.device,
)
# Note: Geometric constraint creation methods temporarily removed
@@ -893,6 +894,9 @@ def compute_workspace_points(
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute end-effector positions for given joint configurations.
+ Uses batched FK computation via ``robot.compute_batch_fk`` for
+ significant speedup on large sample counts.
+
Args:
joint_configs: Joint configurations, shape (num_samples, num_joints).
batch_size: Batch size for FK computation. If None, uses config value.
@@ -903,56 +907,66 @@ def compute_workspace_points(
- valid_configs: Valid joint configurations, shape (num_valid, num_joints)
"""
num_samples = len(joint_configs)
+ batch_size = batch_size or self.config.sampling.batch_size
+ # Cap batch size to total samples
+ batch_size = min(batch_size, num_samples)
+
+ logger.log_info(
+ f"Computing FK for {num_samples} samples (batch_size={batch_size})..."
+ )
+ # Pre-allocate lists for results
workspace_points_list = []
valid_configs_list = []
-
- logger.log_info(f"Computing FK for {num_samples} samples...")
-
- # Track valid points for progress bar
total_valid = 0
- # Robot expects one configuration at a time (batch_size from robot environments, not samples)
- # Process each configuration individually
pbar = self._create_optimized_tqdm(
- range(num_samples),
- desc="Forward Kinematics",
- unit="cfg",
+ range(0, num_samples, batch_size),
+ desc="Forward Kinematics (batched)",
+ unit="batch",
color="cyan",
emoji="🤖",
)
- for i in pbar:
- qpos = joint_configs[i : i + 1] # Keep batch dimension
+
+ for batch_start in pbar:
+ batch_end = min(batch_start + batch_size, num_samples)
+
+ # Reshape to (n_envs=1, batch_size, num_joints) for compute_batch_fk
+ qpos_batch = joint_configs[batch_start:batch_end].unsqueeze(0)
try:
- # Compute forward kinematics
- pose = self.robot.compute_fk(
- qpos=qpos,
+ # Batched FK: (1, batch, num_joints) -> (1, batch, 4, 4)
+ poses = self.robot.compute_batch_fk(
+ qpos=qpos_batch,
name=self.control_part_name,
to_matrix=True,
)
- # Extract position (x, y, z)
- position = pose[:, :3, 3] # Shape: (1, 3)
+ # Extract positions: (1, batch, 4, 4) -> (batch, 3)
+ positions = poses[0, :, :3, 3]
- # Filter by constraints (bounds + collision check)
- valid_bounds = self.constraint_checker.check_bounds(position)
- valid_collision = self.constraint_checker.check_collision(position)
- valid_mask = valid_bounds & valid_collision
+ # Vectorized constraint check for entire batch
+ valid_mask = self.constraint_checker.check_constraints(positions)
- # Store valid results
if valid_mask.any():
- workspace_points_list.append(position[valid_mask])
- valid_configs_list.append(qpos[valid_mask])
- total_valid += 1
+ workspace_points_list.append(positions[valid_mask])
+ valid_configs_list.append(
+ joint_configs[batch_start:batch_end][valid_mask]
+ )
+ total_valid += valid_mask.sum().item()
- # Update progress bar with intelligent statistics
self._update_progress_with_stats(
- pbar, i, total_valid, metric_name="valid", show_rate=True
+ pbar,
+ batch_end - 1,
+ total_valid,
+ metric_name="valid",
+ show_rate=True,
)
except Exception as e:
- logger.log_warning(f"FK computation failed for sample {i}: {e}")
+ logger.log_warning(
+ f"FK computation failed for batch [{batch_start}:{batch_end}]: {e}"
+ )
continue
# Concatenate all results
@@ -963,19 +977,19 @@ def compute_workspace_points(
workspace_points = torch.empty((0, 3), device=self.device)
valid_configs = torch.empty((0, self.num_joints), device=self.device)
- # Performance summary for FK computation
- pbar.close() # Ensure progress bar is closed
- success_rate = len(workspace_points) / num_samples * 100
+ pbar.close()
+ success_rate = (
+ len(workspace_points) / num_samples * 100 if num_samples > 0 else 0
+ )
- # Performance indicator based on success rate
if success_rate >= 90:
- perf_icon = "🏆" # Trophy for excellent performance
+ perf_icon = "🏆"
elif success_rate >= 75:
- perf_icon = "✅" # Check mark for good performance
+ perf_icon = "✅"
elif success_rate >= 50:
- perf_icon = "🟡" # Yellow circle for moderate performance
+ perf_icon = "🟡"
else:
- perf_icon = "⚠️" # Warning for low performance
+ perf_icon = "⚠️"
logger.log_info(
f"{perf_icon} FK Results: {len(workspace_points)}/{num_samples} valid points "
@@ -987,7 +1001,13 @@ def compute_workspace_points(
def compute_reachability(
self, cartesian_points: torch.Tensor, batch_size: int | None = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- """Compute reachability for Cartesian points using IK.
+ """Compute reachability for Cartesian points using batched IK.
+
+ All ``ik_samples_per_point`` random seeds for a batch of points are
+ merged into the batch dimension and resolved with a **single**
+ ``robot.compute_batch_ik`` call (shape ``(1, n_valid * K, 4, 4)``).
+ This avoids the Python loop overhead and lets the solver process all
+ seeds in one vectorised pass.
Args:
cartesian_points: Cartesian positions, shape (num_samples, 3).
@@ -1003,208 +1023,125 @@ def compute_reachability(
"""
num_samples = len(cartesian_points)
ik_samples_per_point = self.config.ik_samples_per_point
+ batch_size = batch_size or self.config.sampling.batch_size
+ batch_size = min(batch_size, num_samples)
- # Pre-filter Cartesian points by workspace constraints
- # This eliminates points that are outside bounds or in collision zones
- valid_cartesian_mask = self.constraint_checker.check_bounds(
+ # Pre-filter by workspace constraints (vectorized)
+ valid_cartesian_mask = self.constraint_checker.check_constraints(
cartesian_points
- ) & self.constraint_checker.check_collision(cartesian_points)
+ )
logger.log_info(
f"Pre-filtered Cartesian points: {valid_cartesian_mask.sum()}/{num_samples} "
f"points pass workspace constraints ({(valid_cartesian_mask.sum()/num_samples*100):.1f}%)"
)
- # Store results for all points (including invalid ones for consistent indexing)
+ # Get reference end-effector pose for IK target orientation
+ current_ee_pose = self._get_reference_pose()
+
+ # Initialize result arrays
all_success_rates = torch.zeros(num_samples, device=self.device)
reachable_points_list = []
best_configs_list = []
+ total_reachable = 0
- logger.log_info(
- f"Computing IK for {num_samples} Cartesian samples "
- f"({ik_samples_per_point} seeds per point)..."
- )
-
- # Create a random sampler for generating IK seeds (avoid UniformSampler issues)
+ # Prepare random seeds for all attempts
from embodichain.lab.sim.utility.workspace_analyzer.samplers import (
RandomSampler,
)
- random_sampler = RandomSampler(seed=self.config.sampling.seed)
-
- # Get reference end-effector pose for IK target orientation
- # Priority: use reference_pose if provided, otherwise compute from current joint configuration
- if (
- hasattr(self.config, "reference_pose")
- and self.config.reference_pose is not None
- ):
- # Use provided reference pose (should be 4x4 transformation matrix)
- reference_pose = self.config.reference_pose
- if isinstance(reference_pose, np.ndarray):
- reference_pose = torch.from_numpy(reference_pose).to(self.device)
- if reference_pose.dim() == 2: # Shape: (4, 4) -> (1, 4, 4)
- reference_pose = reference_pose.unsqueeze(0)
- current_ee_pose = reference_pose # Shape: (1, 4, 4)
- logger.log_info("Using provided reference pose for IK target orientation")
- else:
- # Fallback: compute current end-effector pose from joint configuration
- try:
- # Using first environment (index 0) for qpos retrieval
- current_qpos = self.robot.get_qpos()[0][
- self.robot.get_joint_ids(self.control_part_name)
- ]
- current_ee_pose = self.robot.compute_fk(
- name=self.control_part_name,
- qpos=current_qpos.unsqueeze(0),
- to_matrix=True,
- ) # Shape: (1, 4, 4)
- logger.log_info(
- "Computing reference pose from current robot configuration"
- )
- except Exception as e:
- logger.log_warning(f"Failed to compute current robot pose: {e}")
- # Create identity pose as fallback
- current_ee_pose = torch.eye(4, device=self.device).unsqueeze(0)
- current_ee_pose[0, :3, 3] = torch.tensor(
- [0.5, 0.0, 1.0], device=self.device
- ) # Default position
- logger.log_info("Using default identity pose as fallback")
-
- # Print current joint configuration and computed pose
- pose_np = current_ee_pose[0].cpu().numpy()
- position = pose_np[:3, 3]
- rotation_matrix = pose_np[:3, :3]
-
- # Convert rotation matrix to Euler angles
- import scipy.spatial.transform as spt
-
- euler_angles = spt.Rotation.from_matrix(rotation_matrix).as_euler(
- "xyz", degrees=True
- )
-
- # Print detailed reference pose information
- pose_np = current_ee_pose[0].cpu().numpy()
- position = pose_np[:3, 3]
- rotation_matrix = pose_np[:3, :3]
-
- # Convert rotation matrix to Euler angles (ZYX convention)
- import scipy.spatial.transform as spt
-
- euler_angles = spt.Rotation.from_matrix(rotation_matrix).as_euler(
- "xyz", degrees=True
+ random_sampler = RandomSampler(
+ seed=self.config.sampling.seed, device=self.device
)
- # Format matrix with proper indentation
- matrix_lines = np.array2string(pose_np, precision=4, suppress_small=True).split(
- "\n"
- )
- matrix_str = "\n".join(f"\t {line}" for line in matrix_lines)
logger.log_info(
- f"🎯 Using provided reference pose for IK target orientation:\n"
- f"\t Position: [{position[0]:.4f}, {position[1]:.4f}, {position[2]:.4f}] m\n"
- f"\t Rotation (XYZ Euler): [{euler_angles[0]:.2f}°, {euler_angles[1]:.2f}°, {euler_angles[2]:.2f}°]\n"
- f"\t Matrix:\n{matrix_str}"
+ f"Computing IK for {num_samples} Cartesian samples "
+ f"(batch_size={batch_size}, {ik_samples_per_point} seeds per point)..."
)
- # Track statistics for progress bar
- total_reachable = 0
-
- # Process each point individually (robot expects batch_size from environments, not samples)
pbar = self._create_optimized_tqdm(
- range(num_samples),
- desc="Inverse Kinematics",
- unit="pt",
+ range(0, num_samples, batch_size),
+ desc="Inverse Kinematics (batched)",
+ unit="batch",
color="magenta",
emoji="🎯",
)
- for i in pbar:
- position = cartesian_points[i] # Shape: (3,)
-
- # Skip points that don't satisfy workspace constraints
- if not valid_cartesian_mask[i]:
- # Mark as unreachable due to constraint violation
- all_success_rates[i] = 0.0
- # Update progress bar
- reachability_rate = total_reachable / (i + 1) * 100
- if reachability_rate >= 70:
- reach_color = "\033[32m" # Green for high reachability
- elif reachability_rate >= 40:
- reach_color = "\033[33m" # Yellow for medium reachability
- else:
- reach_color = "\033[31m" # Red for low reachability
- pbar.set_postfix_str(
- f"🎯 Reachable: {total_reachable}/{i+1} | {reach_color}{reachability_rate:.1f}%\033[0m rate (❌ constraint)"
- )
+ for batch_start in pbar:
+ batch_end = min(batch_start + batch_size, num_samples)
+ batch_valid_mask = valid_cartesian_mask[batch_start:batch_end]
+ n_valid = batch_valid_mask.sum().item()
+
+ if n_valid == 0:
continue
- # Create target pose: use current orientation, replace position with sampled position
- pose = current_ee_pose.clone()
- pose[0, :3, 3] = position
-
- # Try multiple random seeds for this point
- success_count = 0
- best_qpos = None
-
- logger.set_log_level("ERROR") # Suppress warnings during IK attempts
- for seed_idx in range(ik_samples_per_point):
- # Generate random joint seed using RandomSampler
- random_seed = random_sampler.sample(
- bounds=self.qpos_limits, num_samples=1
- ) # Shape: (1, num_joints)
-
- try:
- # Compute IK
- ret, qpos = self.robot.compute_ik(
- pose=pose,
- joint_seed=random_seed,
- name=self.control_part_name,
- )
+ # Get valid positions (n_valid, 3)
+ valid_positions = cartesian_points[batch_start:batch_end][batch_valid_mask]
- # Count successes
- if ret is not None and ret[0]:
- success_count += 1
- # Store first successful configuration
- if best_qpos is None:
- best_qpos = qpos[0] # Extract from batch dimension
+ # Build target poses for all seeds in one shot.
+ # Each position is repeated ik_samples_per_point times so that a single
+ # compute_batch_ik call covers all (n_valid * K) targets at once.
+ # Shape: (1, n_valid * K, 4, 4)
+ base_pose = current_ee_pose.unsqueeze(1).expand(1, n_valid, 4, 4).clone()
+ base_pose[0, :, :3, 3] = valid_positions
+ target_poses = base_pose.repeat_interleave(ik_samples_per_point, dim=1)
- except Exception as e:
- logger.log_warning(
- f"IK computation failed for sample {i}, seed {seed_idx}: {e}"
- )
- continue
- logger.set_log_level("INFO") # Restore log level
-
- # Calculate success rate for this point
- success_rate = success_count / ik_samples_per_point
- all_success_rates[i] = success_rate
-
- # Filter by success threshold for reachable points
- if success_rate and best_qpos is not None:
- reachable_points_list.append(position.unsqueeze(0)) # Add batch dim
- best_configs_list.append(best_qpos.unsqueeze(0)) # Add batch dim
- total_reachable += 1
-
- # Update progress bar with reachability statistics
- reachability_rate = total_reachable / (i + 1) * 100
- # Use color coding for the reachability rate
- if reachability_rate >= 70:
- reach_color = "\033[32m" # Green for high reachability
- elif reachability_rate >= 40:
- reach_color = "\033[33m" # Yellow for medium reachability
- else:
- reach_color = "\033[31m" # Red for low reachability
+ # Generate all random seeds at once: (1, n_valid * K, num_joints)
+ all_seeds = random_sampler.sample(
+ bounds=self.qpos_limits, num_samples=n_valid * ik_samples_per_point
+ ).unsqueeze(0)
- # Add success rate indicator for this specific point
- if success_rate:
- point_status = "✅ IK"
- elif success_rate > 0:
- point_status = f"🟡 IK({success_rate:.1f})"
- else:
- point_status = "❌ IK"
+ try:
+ logger.set_log_level("ERROR")
+ success, qpos = self.robot.compute_batch_ik(
+ pose=target_poses,
+ joint_seed=all_seeds,
+ name=self.control_part_name,
+ )
+ logger.set_log_level("INFO")
+
+ # Reshape results from flat batch to (n_valid, K)
+ success_2d = success[0].reshape(n_valid, ik_samples_per_point)
+ qpos_3d = qpos[0].reshape(
+ n_valid, ik_samples_per_point, self.num_joints
+ )
+
+ # Success rate: fraction of seeds that solved IK for each point
+ success_rates_batch = success_2d.float().mean(dim=1) # (n_valid,)
+
+ # Pick the joint config from the first successful seed per point
+ any_success = success_2d.any(dim=1) # (n_valid,)
+ first_success_idx = success_2d.float().argmax(dim=1) # (n_valid,)
+ best_qpos = qpos_3d[
+ torch.arange(n_valid, device=self.device), first_success_idx
+ ] # (n_valid, num_joints)
+
+ except Exception as e:
+ logger.set_log_level("INFO")
+ logger.log_warning(
+ f"IK computation failed for batch [{batch_start}:{batch_end}]: {e}"
+ )
+ success_rates_batch = torch.zeros(n_valid, device=self.device)
+ any_success = torch.zeros(n_valid, dtype=torch.bool, device=self.device)
+ best_qpos = torch.zeros(n_valid, self.num_joints, device=self.device)
+
+ # Map results back to original (pre-filter) indices
+ valid_local_indices = batch_valid_mask.nonzero(as_tuple=True)[0]
+ global_indices = batch_start + valid_local_indices
+ all_success_rates[global_indices] = success_rates_batch
+
+ # Collect reachable points
+ if any_success.any():
+ reachable_points_list.append(valid_positions[any_success])
+ best_configs_list.append(best_qpos[any_success])
+ total_reachable += any_success.sum().item()
- pbar.set_postfix_str(
- f"🎯 Reachable: {total_reachable}/{i+1} | {reach_color}{reachability_rate:.1f}%\033[0m rate | {point_status}"
+ self._update_progress_with_stats(
+ pbar,
+ batch_end - 1,
+ total_reachable,
+ metric_name="reachable",
+ show_rate=True,
)
# Concatenate reachable results
@@ -1215,24 +1152,23 @@ def compute_reachability(
reachable_points = torch.empty((0, 3), device=self.device)
best_configs = torch.empty((0, self.num_joints), device=self.device)
- # Create reachability mask
reachability_mask = all_success_rates > 0
- # Performance summary for IK computation
- pbar.close() # Ensure progress bar is closed
- reachability = len(reachable_points) / num_samples * 100
+ pbar.close()
+ reachability = (
+ len(reachable_points) / num_samples * 100 if num_samples > 0 else 0
+ )
- # Reachability performance indicator
if reachability >= 80:
- reach_icon = "🏆" # Trophy for high reachability
+ reach_icon = "🏆"
elif reachability >= 60:
- reach_icon = "🚀" # Rocket for good reachability
+ reach_icon = "🚀"
elif reachability >= 40:
- reach_icon = "🟡" # Yellow for moderate reachability
+ reach_icon = "🟡"
elif reachability >= 20:
- reach_icon = "🟠" # Orange for low reachability
+ reach_icon = "🟠"
else:
- reach_icon = "⚠️" # Warning for very low reachability
+ reach_icon = "⚠️"
logger.log_info(
f"{reach_icon} IK Results: {len(reachable_points)}/{num_samples} reachable points "
@@ -1247,6 +1183,42 @@ def compute_reachability(
best_configs,
)
+ def _get_reference_pose(self) -> torch.Tensor:
+ """Get reference end-effector pose for IK target orientation.
+
+ Returns:
+ Reference pose tensor of shape (1, 4, 4).
+ """
+ if (
+ hasattr(self.config, "reference_pose")
+ and self.config.reference_pose is not None
+ ):
+ reference_pose = self.config.reference_pose
+ if isinstance(reference_pose, np.ndarray):
+ reference_pose = torch.from_numpy(reference_pose).to(self.device)
+ if reference_pose.dim() == 2:
+ reference_pose = reference_pose.unsqueeze(0)
+ logger.log_info("Using provided reference pose for IK target orientation")
+ return reference_pose
+
+ try:
+ current_qpos = self.robot.get_qpos()[0][
+ self.robot.get_joint_ids(self.control_part_name)
+ ]
+ current_ee_pose = self.robot.compute_fk(
+ name=self.control_part_name,
+ qpos=current_qpos.unsqueeze(0),
+ to_matrix=True,
+ )
+ logger.log_info("Computing reference pose from current robot configuration")
+ return current_ee_pose
+ except Exception as e:
+ logger.log_warning(f"Failed to compute current robot pose: {e}")
+ default_pose = torch.eye(4, device=self.device).unsqueeze(0)
+ default_pose[0, :3, 3] = torch.tensor([0.5, 0.0, 1.0], device=self.device)
+ logger.log_info("Using default identity pose as fallback")
+ return default_pose
+
def analyze(
self,
num_samples: int | None = None,
diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py
index f6389ff8..9ec009bc 100644
--- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py
+++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py
@@ -73,7 +73,7 @@ class GraspGeneratorCfg:
number of sampled surface points, ray perturbation angle, and gripper jaw
distance limits. See :class:`AntipodalSamplerCfg` for details."""
- max_deviation_angle: float = np.pi / 12
+ max_deviation_angle: float = np.pi / 6
"""Maximum allowed angle (in radians) between the specified approach
direction and the axis connecting an antipodal point pair. Pairs that
deviate more than this threshold from perpendicular to the approach are
@@ -583,7 +583,7 @@ def get_grasp_poses(
approach_direction: torch.Tensor,
visualize_collision: bool = False,
visualize_pose: bool = False,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> tuple[bool, torch.Tensor, float]:
"""Get grasp pose given approach direction.
Uses the antipodal point pairs stored in ``self._hit_point_pairs``
@@ -603,19 +603,20 @@ def get_grasp_poses(
after computation.
Returns:
- A tuple ``(best_grasp_pose, best_open_length)`` where
- ``best_grasp_pose`` is a ``(4, 4)`` homogeneous matrix and
- ``best_open_length`` is a scalar.
+ is_success (bool): Whether a valid grasp pose is found.
+ best_grasp_pose (torch.Tensor): If a valid grasp pose is found, a tensor of shape (4, 4) representing the homogeneous transformation matrix of the best grasp pose in the world frame. Otherwise, an identity matrix.
+ best_open_length (float): If a valid grasp pose is found, a scalar representing the optimal gripper opening length. Otherwise, a zero tensor.
Raises:
RuntimeError: If :meth:`generate` or :meth:`annotate` has not
been called yet.
"""
if self._hit_point_pairs is None:
- raise RuntimeError(
+ logger.log_warning(
"No antipodal point pairs available. "
"Call generate() or annotate() first."
)
+ return False, torch.eye(4, device=self.device), 0.0
origin_points = self._hit_point_pairs[:, 0, :]
hit_points = self._hit_point_pairs[:, 1, :]
origin_points_ = self._apply_transform(origin_points, object_pose)
@@ -632,6 +633,10 @@ def get_grasp_poses(
valid_mask = (
positive_angle - torch.pi / 2
).abs() <= self.cfg.max_deviation_angle
+ if valid_mask.sum() == 0:
+ logger.log_warning("No valid antipodal pairs after angle filtering.")
+ return False, torch.eye(4, device=self.device), 0.0
+
valid_grasp_x = grasp_x[valid_mask]
valid_centers = centers[valid_mask]
@@ -650,6 +655,9 @@ def get_grasp_poses(
is_visual=visualize_collision,
collision_threshold=0.0,
)
+ if is_colliding.logical_not().sum() == 0:
+ logger.log_warning("No valid antipodal pairs after angle filtering.")
+ return False, torch.eye(4, device=self.device), 0.0
# get best grasp pose
valid_grasp_poses = valid_grasp_poses[~is_colliding]
valid_open_lengths = valid_open_lengths[~is_colliding]
@@ -674,7 +682,7 @@ def get_grasp_poses(
grasp_pose=best_grasp_pose,
open_length=best_open_length.item(),
)
- return best_grasp_pose, best_open_length
+ return True, best_grasp_pose, best_open_length
@staticmethod
def _grasp_pose_from_approach_direction(
diff --git a/embodichain/toolkits/urdf_assembly/component.py b/embodichain/toolkits/urdf_assembly/component.py
index 211ecf18..ae027224 100644
--- a/embodichain/toolkits/urdf_assembly/component.py
+++ b/embodichain/toolkits/urdf_assembly/component.py
@@ -25,7 +25,7 @@
URDFAssemblyLogger,
)
from embodichain.toolkits.urdf_assembly.mesh import URDFMeshManager
-
+from embodichain.toolkits.urdf_assembly.name_normalizer import NameNormalizer
__all__ = ["ComponentRegistry", "URDFComponent", "URDFComponentManager"]
@@ -83,12 +83,40 @@ def __post_init__(self):
class URDFComponentManager:
- """Responsible for loading, renaming, and processing meshes for a single component."""
+ """Responsible for loading, renaming, and processing meshes for a single component.
+
+ This manager normalizes link and joint names according to a configurable
+ case policy so that the overall assembly naming scheme can be controlled
+ centrally (e.g. all links lowercase, all joints uppercase).
+ """
+
+ def __init__(
+ self,
+ mesh_manager: URDFMeshManager,
+ name_case: dict[str, str] | None = None,
+ ):
+ """Create a component manager.
+
+ Args:
+ mesh_manager (URDFMeshManager): Mesh manager used for copying and
+ rewriting mesh references.
+ name_case (dict[str, str] | None): Optional mapping controlling
+ how joint and link names are normalized. Supported keys are
+ ``"joint"`` and ``"link"`` with values ``"upper``,
+ ``"lower"`` or ``"none"``. When omitted, joints are
+ uppercased and links are lowercased (the previous default
+ behavior).
+ """
- def __init__(self, mesh_manager: URDFMeshManager):
self.mesh_manager = mesh_manager
self.logger = URDFAssemblyLogger.get_logger("component_manager")
+ self.name_normalizer = NameNormalizer(name_case)
+
+ def _apply_case(self, kind: str, name: str | None) -> str | None:
+ """Normalize a name using the NameNormalizer."""
+ return self.name_normalizer.normalize(kind, name)
+
def process_component(
self,
comp: str,
@@ -119,12 +147,12 @@ def process_component(
# Safe way to get link and joint names, handling None values
global_link_names = {
- link.get("name").lower()
+ self._apply_case("link", link.get("name"))
for link in links
if link.get("name") is not None
}
global_joint_names = {
- joint.get("name").upper()
+ self._apply_case("joint", joint.get("name"))
for joint in joints
if joint.get("name") is not None
}
@@ -143,15 +171,19 @@ def process_component(
# Generate unique name
if prefix:
- new_name = self._generate_unique_name(
- orig_name, prefix, global_link_names
- ).lower()
+ new_name = self._apply_case(
+ "link",
+ self._generate_unique_name(
+ orig_name, prefix, global_link_names
+ ),
+ )
else:
# For components without prefix, ensure names are unique
- if orig_name.lower() in global_link_names:
- new_name = f"{comp}_{orig_name}".lower()
+ normalized_orig = self._apply_case("link", orig_name)
+ if normalized_orig in global_link_names:
+ new_name = self._apply_case("link", f"{comp}_{orig_name}")
else:
- new_name = orig_name.lower()
+ new_name = normalized_orig
global_link_names.add(new_name)
@@ -160,7 +192,7 @@ def process_component(
base_points[comp] = new_name
first_link_flag = False
- # Update link name mapping and set link name to lowercase
+ # Update link name mapping and set link name according to policy
name_mapping[(comp, orig_name)] = new_name
link.set("name", new_name)
links.append(link)
@@ -176,9 +208,12 @@ def process_component(
if orig_joint_name is None:
continue
- new_joint_name = self._generate_unique_name(
- orig_joint_name, prefix, global_joint_names
- ).upper()
+ new_joint_name = self._apply_case(
+ "joint",
+ self._generate_unique_name(
+ orig_joint_name, prefix, global_joint_names
+ ),
+ )
global_joint_names.add(new_joint_name)
# Build the complete mapping table
@@ -192,16 +227,16 @@ def process_component(
# Set the new joint name
joint.set("name", new_joint_name)
- # Update parent and child links to lowercase - with None checks
+ # Update parent and child links with case normalization - with None checks
parent_elem = joint.find("parent")
child_elem = joint.find("child")
if parent_elem is not None:
parent = parent_elem.get("link")
if parent is not None:
- new_parent_name = name_mapping.get(
- (comp, parent), parent
- ).lower()
+ new_parent_name = self._apply_case(
+ "link", name_mapping.get((comp, parent), parent)
+ )
parent_elem.set("link", new_parent_name)
else:
self.logger.warning(
@@ -211,7 +246,9 @@ def process_component(
if child_elem is not None:
child = child_elem.get("link")
if child is not None:
- new_child_name = name_mapping.get((comp, child), child).lower()
+ new_child_name = self._apply_case(
+ "link", name_mapping.get((comp, child), child)
+ )
child_elem.set("link", new_child_name)
else:
self.logger.warning(
@@ -270,10 +307,14 @@ def _generate_unique_name(
if orig_name is None:
orig_name = "unnamed"
+ # For uniqueness checks we always operate on a normalized form that is
+ # consistent with the link case policy. This keeps collisions and
+ # generated names aligned with how names are written back to the URDF.
+ base_name = orig_name
if prefix and not orig_name.lower().startswith(prefix.lower()):
- new_name = f"{prefix}{orig_name}".lower()
- else:
- new_name = orig_name.lower()
+ base_name = f"{prefix}{orig_name}"
+
+ new_name = base_name
# Ensure the new name is unique
if new_name in existing_names:
diff --git a/embodichain/toolkits/urdf_assembly/connection.py b/embodichain/toolkits/urdf_assembly/connection.py
index 4dad94a1..7309118c 100644
--- a/embodichain/toolkits/urdf_assembly/connection.py
+++ b/embodichain/toolkits/urdf_assembly/connection.py
@@ -14,30 +14,259 @@
# limitations under the License.
# ----------------------------------------------------------------------------
+from __future__ import annotations
+
import xml.etree.ElementTree as ET
+from typing import Any
from scipy.spatial.transform import Rotation as R
-from embodichain.toolkits.urdf_assembly.logging_utils import (
- URDFAssemblyLogger,
-)
+from embodichain.toolkits.urdf_assembly.logging_utils import URDFAssemblyLogger
+from embodichain.toolkits.urdf_assembly.name_normalizer import NameNormalizer
__all__ = ["URDFConnectionManager"]
class URDFConnectionManager:
- r"""
- Responsible for managing connection rules between components and sensor attachments.
- """
+ r"""Responsible for managing connection rules between components and sensor attachments."""
+
+ _DEFAULT_ORIGIN = {"xyz": "0 0 0", "rpy": "0 0 0"}
- def __init__(self, base_link_name: str):
- r"""Initialize the URDFConnectionManager.
+ def __init__(self, base_link_name: str, name_case: dict[str, str] | None = None):
+ """Initialize the URDFConnectionManager.
Args:
- base_link_name (str): The name of the base link to which the chassis or other components may be attached.
+ base_link_name: The name of the base link to which the chassis or other
+ components may be attached.
+ name_case: Optional mapping controlling how joint and link names are
+ normalized. Supported keys are ``"joint"`` and ``"link"`` with
+ values ``"upper"``, ``"lower"`` or ``"none"``.
+
+ When omitted, joints are uppercased and links are lowercased (the
+ previous default behavior).
"""
self.base_link_name = base_link_name
self.logger = URDFAssemblyLogger.get_logger("connection_manager")
+ self.name_normalizer = NameNormalizer(name_case)
+
+ def _apply_case(self, kind: str, name: str | None) -> str | None:
+ """Normalize a name using the NameNormalizer."""
+ return self.name_normalizer.normalize(kind, name)
+
+ @staticmethod
+ def _get_attr(obj: Any, key: str, default: Any = None) -> Any:
+ """Read attribute from object or key from dict."""
+ if obj is None:
+ return default
+ if isinstance(obj, dict):
+ return obj.get(key, default)
+ return getattr(obj, key, default)
+
+ @staticmethod
+ def _format_scalar(value: Any) -> str:
+ """Format scalar values for URDF attribute strings."""
+ try:
+ f = float(value)
+ except Exception:
+ return "0"
+
+ # Keep strings stable and compact (avoid long repr / numpy scalars).
+ s = f"{f:.6f}".rstrip("0").rstrip(".")
+ return s if s else "0"
+
+ def _format_vec3(self, vec3: Any) -> str:
+ """Format a 3D vector as URDF 'x y z' string."""
+ try:
+ x, y, z = vec3[0], vec3[1], vec3[2]
+ except Exception:
+ return "0 0 0"
+ return f"{self._format_scalar(x)} {self._format_scalar(y)} {self._format_scalar(z)}"
+
+ def _origin_kwargs_from_transform(self, transform: Any | None) -> dict[str, str]:
+ """Convert a 4x4 transform matrix to URDF origin attributes."""
+ if transform is None:
+ return dict(self._DEFAULT_ORIGIN)
+
+ try:
+ xyz = transform[:3, 3]
+ rotation = R.from_matrix(transform[:3, :3])
+ rpy = rotation.as_euler("xyz")
+ except Exception as exc:
+ self.logger.warning(f"Invalid transform, fallback to identity: {exc}")
+ return dict(self._DEFAULT_ORIGIN)
+
+ return {"xyz": self._format_vec3(xyz), "rpy": self._format_vec3(rpy)}
+
+ @staticmethod
+ def _make_unique(base: str, existing: set[str]) -> str:
+ """Make a unique name by appending suffixes when needed."""
+ if base not in existing:
+ return base
+ idx = 1
+ while f"{base}_{idx}" in existing:
+ idx += 1
+ return f"{base}_{idx}"
+
+ def _collect_existing_joint_names(self, joints: list) -> set[str]:
+ names: set[str] = set()
+ for joint in joints:
+ if not hasattr(joint, "get"):
+ continue
+ raw = joint.get("name")
+ if not raw:
+ continue
+ normalized = self._apply_case("joint", raw)
+ if normalized:
+ names.add(normalized)
+ return names
+
+ def _append_fixed_joint(
+ self,
+ joints: list,
+ existing_joint_names: set[str],
+ joint_name: str,
+ parent_link: str,
+ child_link: str,
+ origin_kwargs: dict[str, str] | None = None,
+ ) -> None:
+ """Append a fixed joint if it doesn't already exist."""
+ normalized_joint_name = self._apply_case("joint", joint_name)
+ if not normalized_joint_name:
+ self.logger.error(f"Empty joint name for joint_name={joint_name!r}")
+ return
+
+ if normalized_joint_name in existing_joint_names:
+ self.logger.warning(f"Duplicate joint: {normalized_joint_name}")
+ return
+
+ joint = ET.Element("joint", name=normalized_joint_name, type="fixed")
+ ET.SubElement(joint, "origin", **(origin_kwargs or dict(self._DEFAULT_ORIGIN)))
+ ET.SubElement(joint, "parent", link=parent_link)
+ ET.SubElement(joint, "child", link=child_link)
+
+ joints.append(joint)
+ existing_joint_names.add(normalized_joint_name)
+
+ def _normalize_link_or_none(self, link_name: str | None) -> str | None:
+ if not link_name:
+ return None
+ return self._apply_case("link", link_name)
+
+ def _connect_chassis_to_base(
+ self,
+ joints: list,
+ base_points: dict,
+ existing_joint_names: set[str],
+ chassis_component: str,
+ ) -> bool:
+ if chassis_component not in base_points:
+ return False
+
+ chassis_first_link = self._normalize_link_or_none(
+ base_points.get(chassis_component)
+ )
+ if not chassis_first_link:
+ self.logger.error("Invalid chassis base link (None)")
+ return True
+
+ self._append_fixed_joint(
+ joints=joints,
+ existing_joint_names=existing_joint_names,
+ joint_name=f"BASE_LINK_TO_{chassis_component}_CONNECTOR",
+ parent_link=self.base_link_name,
+ child_link=chassis_first_link,
+ )
+ self.logger.info(
+ f"[{chassis_component.capitalize()}] connected to [base_link] via ({chassis_first_link})"
+ )
+ return True
+
+ def _connect_orphan_components_to_base(
+ self,
+ joints: list,
+ base_points: dict,
+ connection_rules: list,
+ component_transforms: dict,
+ existing_joint_names: set[str],
+ ) -> None:
+ # Find components that don't have parents in connection_rules
+ components_with_parents = {child for parent, child in connection_rules}
+ orphan_components = [
+ comp for comp in base_points.keys() if comp not in components_with_parents
+ ]
+
+ for comp in orphan_components:
+ comp_first_link = self._normalize_link_or_none(base_points.get(comp))
+ if not comp_first_link:
+ self.logger.error(f"Invalid base link for component [{comp}]")
+ continue
+
+ origin_kwargs = self._origin_kwargs_from_transform(
+ component_transforms.get(comp)
+ )
+ if comp in component_transforms:
+ self.logger.info(
+ f"Applied transform to base connection {comp}: {origin_kwargs}"
+ )
+
+ self._append_fixed_joint(
+ joints=joints,
+ existing_joint_names=existing_joint_names,
+ joint_name=f"BASE_TO_{comp}_CONNECTOR",
+ parent_link=self.base_link_name,
+ child_link=comp_first_link,
+ origin_kwargs=origin_kwargs,
+ )
+
+ self.logger.info(
+ f"[{comp.capitalize()}] connected to [base_link] via ({comp_first_link})"
+ )
+
+ def _connect_component_pair(
+ self,
+ joints: list,
+ base_points: dict,
+ parent_attach_points: dict,
+ parent: str,
+ child: str,
+ component_transforms: dict,
+ existing_joint_names: set[str],
+ ) -> None:
+ if parent not in parent_attach_points or child not in base_points:
+ self.logger.error(f"Invalid connection rule: {parent} -> {child}")
+ return
+
+ parent_connect_link = self._normalize_link_or_none(
+ parent_attach_points.get(parent)
+ )
+ child_connect_link = self._normalize_link_or_none(base_points.get(child))
+
+ if not parent_connect_link or not child_connect_link:
+ self.logger.error(
+ f"Invalid link in connection: {parent} ({parent_connect_link}) -> {child} ({child_connect_link})"
+ )
+ return
+
+ self.logger.info(
+ f"Connecting [{parent}]-({parent_connect_link}) to [{child}]-({child_connect_link})"
+ )
+
+ origin_kwargs = self._origin_kwargs_from_transform(
+ component_transforms.get(child)
+ )
+ if child in component_transforms:
+ self.logger.info(
+ f"Applied transform to connection {parent} -> {child}: {origin_kwargs}"
+ )
+
+ self._append_fixed_joint(
+ joints=joints,
+ existing_joint_names=existing_joint_names,
+ joint_name=self._apply_case("joint", f"{parent}_TO_{child}_CONNECTOR"),
+ parent_link=parent_connect_link,
+ child_link=child_connect_link,
+ origin_kwargs=origin_kwargs,
+ )
def add_connections(
self,
@@ -45,168 +274,195 @@ def add_connections(
base_points: dict,
parent_attach_points: dict,
connection_rules: list,
- component_transforms: dict = None,
- ):
+ component_transforms: dict | None = None,
+ ) -> None:
r"""Add connection joints between robot components according to the specified rules.
Args:
- joints (list): A list to collect joint elements.
- base_points (dict): A mapping from component names to their child connection link names.
- parent_attach_points (dict): A mapping from component names to their parent connection link names.
- connection_rules (list): A list of (parent, child) tuples specifying connection relationships.
- component_transforms (dict): Optional mapping from component names to their transform matrices.
+ joints: A list to collect joint elements.
+ base_points: Mapping from component names to their child connection link names.
+ parent_attach_points: Mapping from component names to their parent connection link names.
+ connection_rules: A list of (parent, child) tuples specifying connection relationships.
+ component_transforms: Optional mapping from component names to their 4x4 transform matrices.
"""
chassis_component = "chassis"
component_transforms = component_transforms or {}
- existing_joint_names = {
- joint.get("name") for joint in joints if hasattr(joint, "get")
- }
+ existing_joint_names = self._collect_existing_joint_names(joints)
# chassis is always attached to base_link (no transform applied to this connection)
- if chassis_component in base_points:
- chassis_first_link = base_points[chassis_component]
- joint_name = f"BASE_LINK_TO_{chassis_component.upper()}_CONNECTOR"
- if joint_name not in existing_joint_names:
- joint = ET.Element("joint", name=joint_name, type="fixed")
- ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0")
- ET.SubElement(joint, "parent", link=self.base_link_name)
- ET.SubElement(joint, "child", link=chassis_first_link)
- joints.append(joint)
- existing_joint_names.add(joint_name)
- self.logger.info(
- f"[{chassis_component.capitalize()}] connected to [base_link] via ({chassis_first_link})"
- )
- else:
+ if not self._connect_chassis_to_base(
+ joints=joints,
+ base_points=base_points,
+ existing_joint_names=existing_joint_names,
+ chassis_component=chassis_component,
+ ):
# If no chassis, connect components directly to base_link with their transforms
self.logger.info(
"No chassis found, connecting components directly to base_link"
)
-
- # Find components that don't have parents in connection_rules
- components_with_parents = {child for parent, child in connection_rules}
- orphan_components = [
- comp
- for comp in base_points.keys()
- if comp not in components_with_parents
- ]
-
- for comp in orphan_components:
- comp_first_link = base_points[comp]
- joint_name = f"BASE_TO_{comp.upper()}_CONNECTOR"
-
- if joint_name not in existing_joint_names:
- joint = ET.Element("joint", name=joint_name, type="fixed")
-
- # Apply transform to this specific connection if the component has one
- if comp in component_transforms:
- transform = component_transforms[comp]
- xyz = transform[:3, 3] # Extract translation
- rotation = R.from_matrix(transform[:3, :3])
- rpy = rotation.as_euler("xyz")
-
- ET.SubElement(
- joint,
- "origin",
- xyz=f"{xyz[0]} {xyz[1]} {xyz[2]}",
- rpy=f"{rpy[0]} {rpy[1]} {rpy[2]}",
- )
- self.logger.info(
- f"Applied transform to base connection {comp}: xyz={xyz}, rpy={rpy}"
- )
- else:
- ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0")
-
- ET.SubElement(joint, "parent", link=self.base_link_name)
- ET.SubElement(joint, "child", link=comp_first_link)
- joints.append(joint)
- existing_joint_names.add(joint_name)
-
- self.logger.info(
- f"[{comp.capitalize()}] connected to [base_link] via ({comp_first_link})"
- )
+ self._connect_orphan_components_to_base(
+ joints=joints,
+ base_points=base_points,
+ connection_rules=connection_rules,
+ component_transforms=component_transforms,
+ existing_joint_names=existing_joint_names,
+ )
# Process other connection relationships
for parent, child in connection_rules:
- if parent in parent_attach_points and child in base_points:
- parent_connect_link = parent_attach_points[parent].lower()
- child_connect_link = base_points[child].lower()
+ self._connect_component_pair(
+ joints=joints,
+ base_points=base_points,
+ parent_attach_points=parent_attach_points,
+ parent=parent,
+ child=child,
+ component_transforms=component_transforms,
+ existing_joint_names=existing_joint_names,
+ )
- self.logger.info(
- f"Connecting [{parent}]-({parent_connect_link}) to [{child}]-({child_connect_link})"
- )
+ def add_sensor_attachments(
+ self, links: list, joints: list, attach_dict: dict, base_points: dict
+ ) -> None:
+ r"""Attach sensors by adding their URDF links/joints and creating a fixed connector.
- # Create a unique joint name
- base_joint_name = f"{parent.upper()}_TO_{child.upper()}_CONNECTOR"
- if base_joint_name not in existing_joint_names:
- joint = ET.Element("joint", name=base_joint_name, type="fixed")
-
- # Apply transform to this specific connection if the child component has one
- if child in component_transforms:
- transform = component_transforms[child]
- xyz = transform[:3, 3] # Extract translation
- rotation = R.from_matrix(transform[:3, :3])
- rpy = rotation.as_euler("xyz")
-
- ET.SubElement(
- joint,
- "origin",
- xyz=f"{xyz[0]} {xyz[1]} {xyz[2]}",
- rpy=f"{rpy[0]} {rpy[1]} {rpy[2]}",
- )
- self.logger.info(
- f"Applied transform to connection {parent} -> {child}: xyz={xyz}, rpy={rpy}"
- )
- else:
- ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0")
-
- ET.SubElement(joint, "parent", link=parent_connect_link)
- ET.SubElement(joint, "child", link=child_connect_link)
- joints.append(joint)
- existing_joint_names.add(base_joint_name)
- else:
- self.logger.warning(
- f"Duplicate connection rule: {parent} -> {child}"
- )
- else:
- self.logger.error(f"Invalid connection rule: {parent} -> {child}")
+ .. attention::
+ This is a legacy helper kept for backward compatibility. Newer code paths
+ use :class:`URDFSensorManager`.
+
+ Args:
+ links: Global list to collect sensor link elements.
+ joints: Global list to collect sensor joint elements.
+ attach_dict: Mapping from sensor names to attachment configs.
+ base_points: Mapping from component names to their base link names.
+ """
+ existing_link_names = {
+ self._apply_case("link", link.get("name"))
+ for link in links
+ if hasattr(link, "get") and link.get("name")
+ }
+ existing_link_names.discard(None)
+
+ existing_joint_names = self._collect_existing_joint_names(joints)
- def add_sensor_attachments(
- self, joints: list, attach_dict: dict, base_points: dict
- ):
- r"""Attach sensors to the robot by creating fixed joints."""
for sensor_name, attach in attach_dict.items():
- sensor_urdf = ET.parse(attach.sensor_urdf).getroot()
+ sensor_urdf_path = self._get_attr(attach, "sensor_urdf")
+ if not sensor_urdf_path:
+ self.logger.error(f"Sensor [{sensor_name}] has no sensor_urdf")
+ continue
- # Add sensor links and joints to the main lists
+ try:
+ sensor_urdf = ET.parse(sensor_urdf_path).getroot()
+ except Exception as exc:
+ self.logger.error(
+ f"Failed to parse sensor URDF for [{sensor_name}]: {exc}"
+ )
+ continue
+
+ link_name_map: dict[str, str] = {}
+ processed_link_names: list[str] = []
+
+ # Add sensor links to the links list (ensure lowercase + uniqueness)
for link in sensor_urdf.findall("link"):
- # Ensure sensor link names are lowercase
- link.set("name", link.get("name").lower())
- joints.append(link) # This should be added to links list instead
+ raw_name = link.get("name")
+ if not raw_name:
+ continue
+
+ normalized_raw = self._apply_case("link", raw_name)
+ if not normalized_raw:
+ continue
+
+ base_name = normalized_raw
+ sensor_suffix = str(sensor_name).lower()
+ if sensor_suffix and sensor_suffix not in base_name:
+ base_name = f"{base_name}_{sensor_suffix}"
+
+ unique_name = self._make_unique(base_name, existing_link_names)
+ link.set("name", unique_name)
+
+ link_name_map[normalized_raw] = unique_name
+ processed_link_names.append(unique_name)
+ existing_link_names.add(unique_name)
+ links.append(link)
+ # Add sensor joints to the joints list (ensure uppercase + update link references)
for joint in sensor_urdf.findall("joint"):
- # Ensure sensor joint names are uppercase and link references are lowercase
- joint.set("name", joint.get("name").upper())
+ raw_joint_name = joint.get("name") or "sensor_joint"
+
+ normalized_joint_name = self._apply_case(
+ "joint", f"{sensor_name}_{raw_joint_name}"
+ )
+ if not normalized_joint_name:
+ continue
+
+ normalized_joint_name = self._make_unique(
+ normalized_joint_name, existing_joint_names
+ )
+ joint.set("name", normalized_joint_name)
+
parent_elem = joint.find("parent")
child_elem = joint.find("child")
+
if parent_elem is not None:
- parent_elem.set("link", parent_elem.get("link").lower())
+ raw_parent = parent_elem.get("link")
+ normalized_parent = self._apply_case("link", raw_parent)
+ if normalized_parent and normalized_parent in link_name_map:
+ parent_elem.set("link", link_name_map[normalized_parent])
+ elif normalized_parent:
+ parent_elem.set("link", normalized_parent)
+
if child_elem is not None:
- child_elem.set("link", child_elem.get("link").lower())
+ raw_child = child_elem.get("link")
+ normalized_child = self._apply_case("link", raw_child)
+ if normalized_child and normalized_child in link_name_map:
+ child_elem.set("link", link_name_map[normalized_child])
+ elif normalized_child:
+ child_elem.set("link", normalized_child)
+
joints.append(joint)
+ existing_joint_names.add(normalized_joint_name)
+
+ if not processed_link_names:
+ self.logger.error(f"Sensor [{sensor_name}] has no elements")
+ continue
- parent_link = base_points.get(
- attach.parent_component, attach.parent_component
- ).lower() # Ensure lowercase
+ # Determine parent link: prefer explicit parent_link if provided.
+ parent_component = self._get_attr(attach, "parent_component")
+ raw_parent_link = self._get_attr(attach, "parent_link")
+ if raw_parent_link:
+ parent_link = self._apply_case("link", raw_parent_link)
+ else:
+ parent_link = self._apply_case(
+ "link",
+ base_points.get(parent_component, parent_component),
+ )
- # Create connection joint with uppercase name
- joint_name = (
- f"{attach.parent_component.upper()}_TO_{sensor_name.upper()}_CONNECTOR"
+ if not parent_link:
+ self.logger.error(
+ f"Invalid parent link for sensor [{sensor_name}] on component [{parent_component}]"
+ )
+ continue
+
+ # Create connector joint (apply transform if provided by attachment).
+ origin_kwargs = self._origin_kwargs_from_transform(
+ self._get_attr(attach, "transform")
)
- joint = ET.Element("joint", name=joint_name, type="fixed")
- ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0")
- ET.SubElement(joint, "parent", link=parent_link)
- ET.SubElement(
- joint, "child", link=sensor_urdf.find("link").get("name").lower()
+
+ connector_joint_name = self._make_unique(
+ self._apply_case(
+ "joint", f"{parent_component}_TO_{sensor_name}_CONNECTOR"
+ )
+ or self._apply_case(
+ "joint", f"{parent_component}_TO_{sensor_name}_CONNECTOR".upper()
+ ),
+ existing_joint_names,
+ )
+
+ self._append_fixed_joint(
+ joints=joints,
+ existing_joint_names=existing_joint_names,
+ joint_name=connector_joint_name,
+ parent_link=parent_link,
+ child_link=processed_link_names[0],
+ origin_kwargs=origin_kwargs,
)
- joints.append(joint)
diff --git a/embodichain/toolkits/urdf_assembly/file_writer.py b/embodichain/toolkits/urdf_assembly/file_writer.py
index 4ddcd3fe..f1898f58 100644
--- a/embodichain/toolkits/urdf_assembly/file_writer.py
+++ b/embodichain/toolkits/urdf_assembly/file_writer.py
@@ -127,7 +127,7 @@ def generate_header(
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Calculate proper spacing for centered content
- header_width = 80
+ header_width = 120
separator_line = ""
def center_comment(text: str) -> str:
diff --git a/embodichain/toolkits/urdf_assembly/name_normalizer.py b/embodichain/toolkits/urdf_assembly/name_normalizer.py
new file mode 100644
index 00000000..ffd9ee16
--- /dev/null
+++ b/embodichain/toolkits/urdf_assembly/name_normalizer.py
@@ -0,0 +1,77 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+
+class NameNormalizer:
+ """Handles name normalization for different entity types."""
+
+ VALID_KEYS = {"joint", "link"}
+ VALID_MODES = {"upper", "lower", "none"}
+
+ def __init__(self, default_case: dict[str, str] | None = None):
+ """Initialize the NameNormalizer with default cases.
+
+ Args:
+ default_case (dict[str, str] | None): Default normalization modes for "joint" and "link".
+ """
+ self._name_case = {
+ "joint": "upper",
+ "link": "lower",
+ }
+ if default_case:
+ for key, mode in default_case.items():
+ if key in self.VALID_KEYS and mode in self.VALID_MODES:
+ self._name_case[key] = mode
+ else:
+ raise ValueError(
+ f"Invalid default_case entry {key}={mode}. "
+ f"Allowed keys: {self.VALID_KEYS}, allowed modes: {self.VALID_MODES}."
+ )
+
+ def set_case(self, key: str, mode: str):
+ """Set the normalization mode for a specific key.
+
+ Args:
+ key (str): The entity type ("joint" or "link").
+ mode (str): The normalization mode ("upper", "lower", "none").
+ """
+ if key in self.VALID_KEYS and mode in self.VALID_MODES:
+ self._name_case[key] = mode
+ else:
+ raise ValueError(
+ f"Invalid key or mode: {key}={mode}. "
+ f"Allowed keys: {self.VALID_KEYS}, allowed modes: {self.VALID_MODES}."
+ )
+
+ def normalize(self, kind: str, name: str | None) -> str | None:
+ """Normalize a name according to the configured case policy.
+
+ Args:
+ kind (str): One of "joint" or "link".
+ name (str | None): The original name.
+
+ Returns:
+ str | None: The normalized name, or the original value if kind is unknown or mode is "none".
+ """
+ if name is None:
+ return None
+
+ mode = self._name_case.get(kind, "none")
+ if mode == "lower":
+ return name.lower()
+ if mode == "upper":
+ return name.upper()
+ return name
diff --git a/embodichain/toolkits/urdf_assembly/signature.py b/embodichain/toolkits/urdf_assembly/signature.py
index 3ebbd73a..27a56521 100644
--- a/embodichain/toolkits/urdf_assembly/signature.py
+++ b/embodichain/toolkits/urdf_assembly/signature.py
@@ -62,6 +62,12 @@ def calculate_assembly_signature(self, urdf_dict: dict, output_path: str) -> str
signature_data = {
"output_filename": os.path.basename(output_path),
"components": {},
+ # Optional metadata that can affect the assembly even if the
+ # component URDF files themselves do not change. For example,
+ # the processing order and name prefixes for each component,
+ # and the global casing policy for links/joints.
+ "component_order_and_prefix": [],
+ "name_case": {},
}
def to_serializable(obj):
@@ -85,8 +91,20 @@ def to_serializable(obj):
else:
return obj
- # Process each component
+ # Process each entry passed in from the assembly manager. Most entries
+ # are components (with URDF files), but some may be metadata such as
+ # the component_order_and_prefix or name_case used during assembly.
for comp_type, comp_obj in urdf_dict.items():
+ # Special key reserved for component order/prefix metadata
+ if comp_type == "__component_order_and_prefix__":
+ signature_data["component_order_and_prefix"] = to_serializable(comp_obj)
+ continue
+
+ # Special key reserved for global name_case policy (link/joint casing)
+ if comp_type == "__name_case__":
+ signature_data["name_case"] = to_serializable(comp_obj)
+ continue
+
if comp_obj is None:
continue
diff --git a/embodichain/toolkits/urdf_assembly/urdf_assembly_manager.py b/embodichain/toolkits/urdf_assembly/urdf_assembly_manager.py
index 9739faa9..4d9fb7b6 100644
--- a/embodichain/toolkits/urdf_assembly/urdf_assembly_manager.py
+++ b/embodichain/toolkits/urdf_assembly/urdf_assembly_manager.py
@@ -14,6 +14,7 @@
# limitations under the License.
# ----------------------------------------------------------------------------
+import copy
import os
import time
import logging
@@ -128,6 +129,15 @@ def __init__(
):
self.logger = setup_urdf_logging()
+ # Global name normalization strategy for this assembly. By default,
+ # this preserves the legacy behavior: link names are lowercase and
+ # joint names are uppercase. The same mapping is passed down to
+ # managers that deal with naming so that the policy stays consistent.
+ self._name_case: dict[str, str] = {
+ "joint": "upper",
+ "link": "lower",
+ }
+
# Use registries for components and sensors
self.component_registry = component_registry or ComponentRegistry()
self.sensor_registry = sensor_registry or SensorRegistry()
@@ -137,13 +147,13 @@ def __init__(
# Initialize managers for components and sensors
self.component_manager = component_manager or URDFComponentManager(
- self.mesh_manager
+ self.mesh_manager, name_case=self._name_case
)
self.sensor_manager = sensor_manager or URDFSensorManager(self.mesh_manager)
# Processing order for components with their name prefixes
# Tuple format: (component_name, prefix)
- self.component_order = [
+ self._component_order_and_prefix = [
("chassis", None),
("legs", None),
("torso", None),
@@ -205,6 +215,150 @@ def __init__(
# Initialize signature manager instead of cache manager
self.signature_manager = URDFAssemblySignatureManager()
+ @property
+ def name_case(self):
+ """Get the current name case policy for joints and links.
+
+ Returns:
+ dict[str, str]: A dictionary mapping 'joint' and 'link' to their respective case modes.
+ """
+ return self._name_case
+
+ @name_case.setter
+ def name_case(self, new_name_case: dict[str, str]):
+ """Set a new name case policy for joints and links.
+
+ This method updates the name case policy and propagates it to the component and sensor managers.
+
+ Args:
+ new_name_case (dict[str, str]): A dictionary mapping 'joint' and 'link' to their desired case modes (e.g., 'upper', 'lower', 'none').
+ """
+ if not isinstance(new_name_case, dict):
+ raise ValueError(
+ "name_case must be a dictionary mapping 'joint' and 'link' to case modes."
+ )
+ if "joint" not in new_name_case or "link" not in new_name_case:
+ raise ValueError("name_case must contain keys 'joint' and 'link'.")
+
+ self._name_case = new_name_case
+
+ def _apply_case(self, kind: str, name: str | None) -> str | None:
+ """Normalize a name according to the assembly-wide case policy.
+
+ This helper mirrors the behavior of the managers' own case helpers so
+ that any name sets computed here (e.g. for sensors) stay consistent
+ with how names are written into the URDF.
+
+ Args:
+ kind (str): One of ``"joint"`` or ``"link"``.
+ name (str | None): The original name.
+
+ Returns:
+ str | None: The normalized name, or the original value if the
+ kind is unknown or its mode is ``"none"``.
+ """
+
+ if name is None:
+ return None
+
+ mode = self._name_case.get(kind, "none")
+ if mode == "lower":
+ return name.lower()
+ if mode == "upper":
+ return name.upper()
+ return name
+
+ @property
+ def component_order_and_prefix(self):
+ """Get the internal component order with their name prefixes.
+
+ Note:
+ This exposes the internal list of ``(component_name, prefix)`` pairs
+ used when assembling URDFs. In most user code it is recommended to
+ use :attr:`component_prefix` instead, which focuses on configuring
+ prefixes rather than ordering.
+
+ Returns:
+ list[tuple[str, str | None]]: A list of tuples specifying component
+ names and their prefixes.
+ """
+ return self._component_order_and_prefix
+
+ @component_order_and_prefix.setter
+ def component_order_and_prefix(self, new_order):
+ """Set the internal component prefix configuration.
+ Args:
+ new_order: Value assigned directly to the internal
+ ``_component_order_and_prefix`` attribute, typically a list of
+ ``(component_name, prefix)`` tuples.
+ Note:
+ This setter performs no validation or patch-style merging; it
+ stores ``new_order`` as provided.
+ """
+ self._component_order_and_prefix = new_order
+
+ @property
+ def component_prefix(self):
+ """Configure name prefixes per component type.
+
+ This is a user-facing alias over :attr:`component_order_and_prefix`.
+
+ Semantics:
+ This setter is **patch-only**: it updates prefixes for components that
+ already exist in the current internal order and does **not** allow
+ introducing new component names.
+
+ Returns:
+ list[tuple[str, str | None]]: The internal list of
+ ``(component_name, prefix)`` pairs.
+ """
+
+ return self.component_order_and_prefix
+
+ @component_prefix.setter
+ def component_prefix(self, new_prefixes):
+ if not isinstance(new_prefixes, list) or not all(
+ isinstance(item, tuple) and len(item) == 2 for item in new_prefixes
+ ):
+ raise ValueError(
+ "component_prefix must be a list of (component_name, prefix) tuples."
+ )
+
+ # Treat new_prefixes as a patch on top of the existing/default order:
+ # - For components already present in self._component_order_and_prefix, update their prefix.
+ # - Preserve components that are not mentioned, keeping their relative order.
+ #
+ # Note: New/unknown component names are rejected to keep the assembly order
+ # controlled internally.
+
+ # Allowed components are exactly those already present in the default order.
+ existing_components = {comp for comp, _ in self._component_order_and_prefix}
+
+ # Build override map from the incoming list, but only for existing components.
+ override_map = {}
+ for comp, prefix in new_prefixes:
+ if not isinstance(comp, str):
+ raise ValueError("component name in component_prefix must be a string.")
+ if comp not in existing_components:
+ raise ValueError(
+ f"component_prefix cannot introduce new component '{comp}'. "
+ f"Allowed components: {sorted(existing_components)}"
+ )
+ override_map[comp] = prefix
+
+ merged_order: list[tuple[str, str | None]] = []
+
+ # First, walk the existing order and apply overrides where available.
+ # The relative order of components is kept internal and usually does
+ # not need to be changed by users.
+ for comp, prefix in self._component_order_and_prefix:
+ if comp in override_map:
+ merged_order.append((comp, override_map.pop(comp)))
+ else:
+ merged_order.append((comp, prefix))
+
+ self._component_order_and_prefix = merged_order
+
def add_component(
self,
component_type: str,
@@ -536,6 +690,40 @@ def _find_end_link(
break # No further links found in the chain
return current_link
+ def _log_names_once(
+ self,
+ kind: str,
+ elems: list[ET.Element],
+ *,
+ max_items: int = 300,
+ max_chars: int = 8000,
+ ) -> None:
+ """Log element names in a single line (truncated)."""
+ names: list[str] = []
+ for e in elems:
+ n = e.get("name")
+ if n:
+ names.append(n)
+
+ total = len(names)
+ shown_names = names[:max_items]
+ text = ", ".join(shown_names)
+
+ truncated_items = max(0, total - len(shown_names))
+ truncated_chars = 0
+ if len(text) > max_chars:
+ text = text[:max_chars] + "..."
+ truncated_chars = 1
+
+ suffix_parts: list[str] = []
+ if truncated_items:
+ suffix_parts.append(f"truncated_items={truncated_items}")
+ if truncated_chars:
+ suffix_parts.append("truncated_chars=1")
+ suffix = f" ({', '.join(suffix_parts)})" if suffix_parts else ""
+
+ self.logger.info(f"[merge_urdfs] {kind}: count={total} names=[{text}]{suffix}")
+
@performance_monitor
def merge_urdfs(
self,
@@ -563,6 +751,16 @@ def merge_urdfs(
]
self.logger.info(f"🔧 Preparing to merge components: {available_components}")
+ order_items = " ".join(
+ f"[{comp}]({prefix})" for comp, prefix in self.component_order_and_prefix
+ )
+ self.logger.info(f"[component_order_and_prefix] {order_items}")
+
+ case_keys = [k for k in ("joint", "link") if k in self.name_case]
+ case_keys += [k for k in sorted(self.name_case) if k not in case_keys]
+ case_items = " ".join(f"[{k}]({self.name_case[k]})" for k in case_keys)
+ self.logger.info(f"[name_case] {case_items}")
+
for comp in available_components:
comp_obj = self.component_registry.get(comp)
self.logger.info(f" [{comp}]: {comp_obj.urdf_path}")
@@ -572,9 +770,21 @@ def merge_urdfs(
self.logger.debug(f" Transform: applied")
if use_signature_check:
- # Calculate current assembly signature
+ # Calculate current assembly signature. In addition to the component
+ # registry contents, include the current component_order_and_prefix
+ # so that changes to name prefixes also invalidate the cache.
+ component_info = self.component_registry.all().copy()
+ component_info["__component_order_and_prefix__"] = list(
+ self.component_order_and_prefix
+ )
+ # Also include the assembly-wide name_case policy so that
+ # renaming rules (e.g. link/joint casing) participate in the
+ # signature. This ensures that changing naming strategy forces
+ # a rebuild.
+ component_info["__name_case__"] = dict(self._name_case)
+
assembly_signature = self.signature_manager.calculate_assembly_signature(
- self.component_registry.all(), output_path
+ component_info, output_path
)
self.logger.info(f"Current assembly signature: [{assembly_signature}]")
@@ -606,6 +816,46 @@ def merge_urdfs(
robot_name = os.path.splitext(os.path.basename(output_path))[0]
merged_urdf = ET.Element("robot", name=robot_name)
+ # Global definitions live directly under and are not part
+ # of links/joints. To avoid polluting the merged URDF, we only merge global
+ # materials that are actually referenced by merged links' visuals.
+ materials: list[ET.Element] = []
+ material_names: set[str] = set()
+ material_sources: list[tuple[ET.Element, str]] = []
+
+ def _register_material_source(root: ET.Element, source: str) -> None:
+ material_sources.append((root, source))
+
+ def _merge_material_if_defined(mat_name: str) -> bool:
+ """Merge a global definition from known sources.
+
+ Only merges if the material is referenced and if a source URDF actually
+ defines it at the root. This prevents bringing in unused
+ materials from component URDFs.
+ """
+ if not mat_name or mat_name in material_names:
+ return False
+
+ matches: list[tuple[ET.Element, str]] = []
+ for root, source in material_sources:
+ for mat in root.findall("material"):
+ if mat.get("name") == mat_name:
+ matches.append((mat, source))
+
+ if not matches:
+ return False
+
+ if len(matches) > 1:
+ self.logger.debug(
+ f"Material '{mat_name}' defined in multiple URDF sources; using the first: {matches[0][1]}"
+ )
+
+ mat, source = matches[0]
+ materials.append(copy.deepcopy(mat))
+ material_names.add(mat_name)
+ self.logger.debug(f"Merged referenced material '{mat_name}' from {source}")
+ return True
+
# 2. Create single base link for the entire robot
base_link = ET.Element("link", name=self.base_link_name)
# Store links and joints separately for proper ordering
@@ -622,8 +872,12 @@ def merge_urdfs(
ensure_directory_exists(output_dir, self.logger)
mesh_manager = URDFMeshManager(output_dir)
mesh_manager.ensure_dirs()
- component_manager = URDFComponentManager(mesh_manager)
- connection_manager = URDFConnectionManager(self.base_link_name)
+ component_manager = URDFComponentManager(
+ mesh_manager, name_case=self._name_case
+ )
+ connection_manager = URDFConnectionManager(
+ self.base_link_name, name_case=self._name_case
+ )
# Initialize sensor manager with mesh_manager
sensor_manager = URDFSensorManager(mesh_manager)
@@ -647,7 +901,7 @@ def merge_urdfs(
if comp_obj and comp_obj.transform is not None:
component_transforms[comp] = comp_obj.transform
- for comp, prefix in self.component_order:
+ for comp, prefix in self.component_order_and_prefix:
comp_obj = self.component_registry.get(comp)
if not comp_obj:
continue
@@ -658,6 +912,7 @@ def merge_urdfs(
# Parse component URDF to analyze its structure
urdf_root = ET.parse(comp_obj.urdf_path).getroot()
+ _register_material_source(urdf_root, str(comp_obj.urdf_path))
# Determine parent component and attachment point for current component
parent_component = None
@@ -747,16 +1002,32 @@ def merge_urdfs(
component_transforms,
)
- # Track existing names for sensor processing
+ # Track existing names for sensor processing. Use the same case policy
+ # as the rest of the assembly so that collision checks are consistent
+ # with how names are written.
existing_link_names = {
- link.get("name").lower() for link in links if link.get("name")
+ self._apply_case("link", link.get("name"))
+ for link in links
+ if link.get("name")
}
existing_joint_names = {
- joint.get("name").upper() for joint in joints if joint.get("name")
+ self._apply_case("joint", joint.get("name"))
+ for joint in joints
+ if joint.get("name")
}
# 5. Process sensor attachments using the new sensor manager
for sensor_name, sensor_attach in self.sensor_registry.all().items():
+ # Register sensor URDF as a material source (do not merge materials eagerly).
+ try:
+ sensor_root = ET.parse(sensor_attach.sensor_urdf).getroot()
+ except Exception as exc:
+ self.logger.debug(
+ f"Failed to parse sensor URDF for material sourcing ({sensor_attach.sensor_urdf}): {exc}"
+ )
+ else:
+ _register_material_source(sensor_root, str(sensor_attach.sensor_urdf))
+
sensor_manager.attach_sensor(
sensor_name=sensor_name,
sensor_source=sensor_attach.sensor_urdf,
@@ -769,9 +1040,40 @@ def merge_urdfs(
links, joints, base_points, existing_link_names, existing_joint_names
)
- # 6. Add all links and joints to merged URDF in proper order
+ # 6. Merge only the global materials that are actually referenced by merged links.
+ # If a link references but no source URDF defines a global
+ # under , we warn but do not inject guessed fallbacks.
+ referenced_materials: set[str] = set()
+ for link in links:
+ for mat in link.findall(".//visual/material"):
+ mat_name = mat.get("name")
+ if not mat_name:
+ continue
+ # A material with children is already defined inline.
+ if list(mat):
+ continue
+ referenced_materials.add(mat_name)
+
+ missing_materials: list[str] = []
+ for mat_name in sorted(referenced_materials):
+ if mat_name in material_names:
+ continue
+ if not _merge_material_if_defined(mat_name):
+ missing_materials.append(mat_name)
+
+ for mat_name in missing_materials:
+ self.logger.warning(
+ f"Material '{mat_name}' referenced but not defined in any source URDF"
+ )
+
+ # Add global materials, then links/joints to merged URDF in proper order
+ for mat in materials:
+ merged_urdf.append(mat)
+
+ self._log_names_once("links", links)
for link in links:
merged_urdf.append(link)
+ self._log_names_once("joints", joints)
for joint in joints:
merged_urdf.append(joint)
diff --git a/embodichain/utils/__init__.py b/embodichain/utils/__init__.py
index 6285965f..b77db093 100644
--- a/embodichain/utils/__init__.py
+++ b/embodichain/utils/__init__.py
@@ -16,7 +16,6 @@
from .configclass import configclass, is_configclass
-
GLOBAL_SEED = 1024
diff --git a/embodichain/utils/configclass.py b/embodichain/utils/configclass.py
index c9f22ca5..7ca2671a 100644
--- a/embodichain/utils/configclass.py
+++ b/embodichain/utils/configclass.py
@@ -20,7 +20,6 @@
from typing import Any, ClassVar
from .string import callable_to_string, string_to_callable
-
_CONFIGCLASS_METHODS = ["to_dict", "replace", "copy", "validate"]
"""List of class methods added at runtime to dataclass."""
diff --git a/embodichain/utils/warp/kinematics/opw_solver.py b/embodichain/utils/warp/kinematics/opw_solver.py
index 1f1cf459..877324d1 100644
--- a/embodichain/utils/warp/kinematics/opw_solver.py
+++ b/embodichain/utils/warp/kinematics/opw_solver.py
@@ -18,7 +18,6 @@
import numpy as np
from typing import Tuple
-
wp_vec48f = wp.types.vector(length=48, dtype=float)
wp_vec6f = wp.types.vector(length=6, dtype=float)
@@ -30,6 +29,23 @@ def normalize_to_pi(angle: float) -> float:
return wp.atan2(wp.sin(angle), wp.cos(angle))
+@wp.func
+def normalize_in_limit(angle: float, lower: float, upper: float) -> float:
+ two_pi = 2.0 * wp.pi
+ k = wp.ceil((lower - angle) / two_pi)
+ result = angle + k * two_pi
+ return result
+
+
+@wp.func
+def is_within_limit(
+ angle: float, lower: float, upper: float, safe_margin: float
+) -> bool:
+ if angle < lower + safe_margin or angle > upper - safe_margin:
+ return False
+ return True
+
+
@wp.func
def safe_acos(x: float) -> float:
return wp.acos(wp.clamp(x, -1.0, 1.0))
@@ -219,6 +235,9 @@ def opw_ik_kernel(
params: OPWparam,
offsets: wp.array(dtype=float),
sign_corrections: wp.array(dtype=float),
+ lower_limits: wp_vec6f,
+ upper_limits: wp_vec6f,
+ safe_margin: float,
qpos: wp.array(dtype=float),
ik_valid: wp.array(dtype=int),
):
@@ -433,8 +452,10 @@ def opw_ik_kernel(
for k in range(DOF):
idx = j * DOF + k
- qpos[qpos_start + k] = normalize_to_pi(
- (theta[idx] + offsets[k]) * sign_corrections[k]
+ qpos[qpos_start + k] = normalize_in_limit(
+ (theta[idx] + offsets[k]) * sign_corrections[k],
+ lower=lower_limits[k],
+ upper=upper_limits[k],
)
# filter invalid solutions
@@ -449,42 +470,46 @@ def opw_ik_kernel(
)
t_err, r_err = get_transform_err(check_ee_pose, ee_pose)
# mark invalid solutions (cannot pass ik check)
+ ik_valid[i * N_SOL + j] = 1
+ for k in range(DOF):
+ if not is_within_limit(
+ qpos[qpos_start + k],
+ lower_limits[k],
+ upper_limits[k],
+ safe_margin=safe_margin,
+ ):
+ ik_valid[i * N_SOL + j] = 0
+ break
if t_err > 1e-2 or r_err > 1e-1:
ik_valid[i * N_SOL + j] = 0
- else:
- ik_valid[i * N_SOL + j] = 1
@wp.kernel
-def opw_best_ik_kernel(
- full_ik_result: wp.array(dtype=float),
- full_ik_valid: wp.array(dtype=int),
- qpos_seed: wp.array(dtype=float),
+def opw_ik_select_kernel(
+ full_ik_result: wp.array(dtype=float, ndim=3), # [n_sample, N_SOL, DOF]
+ full_ik_valid: wp.array(dtype=int, ndim=2), # [n_sample, N_SOL]
+ qpos_seed: wp.array(dtype=float, ndim=2), # [n_sample, DOF]
joint_weights: wp_vec6f,
- best_ik_result: wp.array(dtype=float),
- best_ik_valid: wp.array(dtype=int),
+ best_ik_result: wp.array(dtype=float, ndim=2), # [n_sample, DOF]
+ best_ik_valid: wp.array(dtype=int, ndim=1), # [n_sample, ]
):
- i = wp.tid()
- DOF = 6
- N_SOL = 8
-
+ i = wp.tid() # index for sample
best_weighted_dis = float(1e10)
best_ids = int(-1)
+ DOF = 6
+ N_SOL = 8
for j in range(N_SOL):
- is_full_valid = full_ik_valid[i * N_SOL + j]
+ is_full_valid = full_ik_valid[i, j]
if is_full_valid == 0:
# invalid ik result
continue
weighted_dis = 0.0
for t in range(DOF):
weighted_dis += (
- (full_ik_result[i * N_SOL * DOF + j * DOF + t] - qpos_seed[i * DOF + t])
- * joint_weights[0]
- * (
- full_ik_result[i * N_SOL * DOF + j * DOF + t]
- - qpos_seed[i * DOF + t]
- )
- * joint_weights[0]
+ (full_ik_result[i, j, t] - qpos_seed[i, t])
+ * joint_weights[t]
+ * (full_ik_result[i, j, t] - qpos_seed[i, t])
+ * joint_weights[t]
)
if weighted_dis < best_weighted_dis:
best_weighted_dis = weighted_dis
@@ -493,9 +518,7 @@ def opw_best_ik_kernel(
# found best solution
best_ik_valid[i] = 1
for k in range(DOF):
- best_ik_result[i * DOF + k] = full_ik_result[
- i * N_SOL * DOF + best_ids * DOF + k
- ]
+ best_ik_result[i, k] = full_ik_result[i, best_ids, k]
else:
# no valid solution
best_ik_valid[i] = 0
diff --git a/examples/agents/datasets/online_dataset_demo.py b/examples/agents/datasets/online_dataset_demo.py
index 84429a24..3bd07d3b 100644
--- a/examples/agents/datasets/online_dataset_demo.py
+++ b/examples/agents/datasets/online_dataset_demo.py
@@ -28,7 +28,7 @@
Usage::
- python examples/agents/datasets/online_dataset_demo.py
+ python examples/agents/datasets/online_dataset_demo.py
"""
from __future__ import annotations
@@ -76,7 +76,7 @@ def _build_engine(args: argparse.Namespace) -> OnlineDataEngine:
gym_config = load_json(config_path)
gym_config["headless"] = True
- gym_config["enable_rt"] = True
+ gym_config.setdefault("renderer", True)
gym_config["gpu_id"] = 0
gym_config["device"] = args.device
cfg = OnlineDataEngineCfg(
diff --git a/examples/sim/demo/grasp_cup_to_caffe.py b/examples/sim/demo/grasp_cup_to_caffe.py
index c2c69ab6..c59526ed 100644
--- a/examples/sim/demo/grasp_cup_to_caffe.py
+++ b/examples/sim/demo/grasp_cup_to_caffe.py
@@ -28,6 +28,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot, RigidObject
from embodichain.lab.sim.cfg import (
+ RenderCfg,
LightCfg,
JointDrivePropertiesCfg,
RigidObjectCfg,
@@ -38,7 +39,7 @@
from embodichain.lab.sim.shapes import MeshCfg
from embodichain.data import get_data_path
from embodichain.utils import logger
-
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.robots.dexforce_w1.cfg import DexforceW1Cfg
@@ -52,19 +53,7 @@ def parse_arguments():
parser = argparse.ArgumentParser(
description="Create and simulate a robot in SimulationManager"
)
- parser.add_argument(
- "--num_envs", type=int, default=9, help="Number of parallel environments"
- )
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering"
- )
- parser.add_argument("--headless", action="store_true", help="Enable headless mode")
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- help="device to run the environment on, e.g., 'cpu' or 'cuda'",
- )
+ add_env_launcher_args_to_parser(parser)
return parser.parse_args()
@@ -81,23 +70,13 @@ def initialize_simulation(args) -> SimulationManager:
config = SimulationManagerCfg(
headless=True,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
physics_dt=1.0 / 100.0,
num_envs=args.num_envs,
arena_space=2.5,
)
sim = SimulationManager(config)
- if args.enable_rt:
- light = sim.add_light(
- cfg=LightCfg(
- uid="main_light",
- color=(0.6, 0.6, 0.6),
- intensity=30.0,
- init_pos=(1.0, 0, 3.0),
- )
- )
-
return sim
@@ -440,6 +419,7 @@ def main():
table = create_table(sim)
caffe = create_caffe(sim)
cup = create_cup(sim)
+ sim.update(step=1)
# apply random perturbation
apply_random_xy_perturbation(cup, max_perturbation=0.05)
diff --git a/examples/sim/demo/pick_up_cloth.py b/examples/sim/demo/pick_up_cloth.py
index 36d1c243..d6f8e3fa 100644
--- a/examples/sim/demo/pick_up_cloth.py
+++ b/examples/sim/demo/pick_up_cloth.py
@@ -35,6 +35,7 @@
from embodichain.data import get_data_path
from embodichain.utils import logger
from embodichain.lab.sim.cfg import (
+ RenderCfg,
JointDrivePropertiesCfg,
RobotCfg,
RigidObjectCfg,
@@ -47,51 +48,7 @@
import os
from embodichain.lab.sim.shapes import MeshCfg, CubeCfg
import tempfile
-
-
-def parse_arguments():
- """
- Parse command-line arguments to configure the simulation.
-
- Returns:
- argparse.Namespace: Parsed arguments including number of environments, device, and rendering options.
- """
- parser = argparse.ArgumentParser(
- description="Create and simulate a robot in SimulationManager"
- )
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering"
- )
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of parallel environments"
- )
- return parser.parse_args()
-
-
-def initialize_simulation(args):
- """
- Initialize the simulation environment based on the provided arguments.
-
- Args:
- args (argparse.Namespace): Parsed command-line arguments.
-
- Returns:
- SimulationManager: Configured simulation manager instance.
- """
- config = SimulationManagerCfg(
- headless=True,
- sim_device="cuda",
- enable_rt=args.enable_rt,
- physics_dt=1.0 / 100.0,
- num_envs=args.num_envs,
- )
- sim = SimulationManager(config)
-
- light = sim.add_light(
- cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0))
- )
-
- return sim
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
def create_robot(sim: SimulationManager, position=[0.0, 0.0, 0.0]):
@@ -148,18 +105,18 @@ def create_padding_box(sim: SimulationManager):
padding_box_cfg = RigidObjectCfg(
uid="padding_box",
shape=CubeCfg(
- size=[0.01, 0.04, 0.03],
+ size=[0.02, 0.07, 0.05],
),
attrs=RigidBodyAttributesCfg(
mass=1.0,
- static_friction=0.95,
- dynamic_friction=0.9,
+ static_friction=0.01,
+ dynamic_friction=0.00,
restitution=0.01,
min_position_iters=32,
min_velocity_iters=8,
),
body_type="kinematic",
- init_pos=[0.5, 0.0, 0.01],
+ init_pos=[0.5, 0.0, 0.026],
init_rot=[0.0, 0.0, 0.0],
)
padding_box = sim.add_rigid_object(cfg=padding_box_cfg)
@@ -219,7 +176,7 @@ def create_cloth(sim: SimulationManager):
mass=0.01,
youngs=1e10,
poissons=0.4,
- thickness=0.04,
+ thickness=0.06,
bending_stiffness=0.01,
bending_damping=0.1,
dynamic_friction=0.95,
@@ -283,8 +240,26 @@ def main():
This function initializes the simulation, creates the robot and other objects,
and performs the press softbody task.
"""
- args = parse_arguments()
- sim = initialize_simulation(args)
+ parser = argparse.ArgumentParser(
+ description="Create a simulation scene with SimulationManager"
+ )
+ add_env_launcher_args_to_parser(parser)
+ args = parser.parse_args()
+ # Configure the simulation
+ sim_cfg = SimulationManagerCfg(
+ width=1920,
+ height=1080,
+ num_envs=args.num_envs,
+ headless=True,
+ physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
+ sim_device="cuda",
+ render_cfg=RenderCfg(
+ renderer=args.renderer
+ ), # Enable ray tracing for better visuals
+ )
+
+ # Create the simulation instance
+ sim = SimulationManager(sim_cfg)
robot = create_robot(sim)
cloth = create_cloth(sim)
@@ -312,8 +287,7 @@ def main():
n_waypoint = grab_traj.shape[1]
for i in range(n_waypoint):
robot.set_qpos(grab_traj[:, i, :])
- sim.update(step=4)
- time.sleep(1e-2)
+ sim.update(step=3)
input("Press Enter to exit the simulation...")
diff --git a/examples/sim/demo/press_softbody.py b/examples/sim/demo/press_softbody.py
index 25e1640d..f5fada63 100644
--- a/examples/sim/demo/press_softbody.py
+++ b/examples/sim/demo/press_softbody.py
@@ -34,6 +34,7 @@
from embodichain.data import get_data_path
from embodichain.utils import logger
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RobotCfg,
LightCfg,
SoftObjectCfg,
@@ -41,6 +42,7 @@
SoftbodyPhysicalAttributesCfg,
URDFCfg,
)
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.shapes import MeshCfg
@@ -54,12 +56,7 @@ def parse_arguments():
parser = argparse.ArgumentParser(
description="Create and simulate a robot in SimulationManager"
)
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering"
- )
- parser.add_argument(
- "--num_envs", type=int, default=9, help="Number of parallel environments"
- )
+ add_env_launcher_args_to_parser(parser)
return parser.parse_args()
@@ -76,16 +73,12 @@ def initialize_simulation(args):
config = SimulationManagerCfg(
headless=True,
sim_device="cuda",
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
physics_dt=1.0 / 100.0,
num_envs=args.num_envs,
)
sim = SimulationManager(config)
- light = sim.add_light(
- cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0))
- )
-
return sim
diff --git a/examples/sim/demo/scoop_ice.py b/examples/sim/demo/scoop_ice.py
index 00e05d77..3f861d98 100644
--- a/examples/sim/demo/scoop_ice.py
+++ b/examples/sim/demo/scoop_ice.py
@@ -29,6 +29,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot, RigidObject, RigidObjectGroup
from embodichain.lab.sim.cfg import (
+ RenderCfg,
JointDrivePropertiesCfg,
RobotCfg,
URDFCfg,
@@ -44,9 +45,10 @@
from embodichain.lab.sim.solvers import PytorchSolverCfg
from embodichain.data import get_data_path
from embodichain.utils import logger
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
-def initialize_simulation():
+def initialize_simulation(args):
"""
Initialize the simulation environment based on the provided arguments.
@@ -58,14 +60,13 @@ def initialize_simulation():
"""
config = SimulationManagerCfg(
headless=True,
- sim_device="cpu",
- enable_rt=True,
+ render_cfg=RenderCfg(renderer=args.renderer),
physics_dt=1.0 / 100.0,
)
sim = SimulationManager(config)
light = sim.add_light(
- cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0))
+ cfg=LightCfg(uid="main_light", intensity=30.0, init_pos=(0, 0, 2.0))
)
return sim
@@ -308,7 +309,7 @@ def create_ice_cubes(sim: SimulationManager):
cfg=VisualMaterialCfg(
base_color=[1.0, 1.0, 1.0, 1.0],
ior=1.31,
- roughness=0.05,
+ roughness=0.2,
material_type="BSDF",
)
)
@@ -529,13 +530,17 @@ def scoop_ice(sim: SimulationManager, robot: Robot, scoop: RigidObject):
def main():
+ parser = argparse.ArgumentParser(description="Scoop ice task simulation")
+ add_env_launcher_args_to_parser(parser)
+ args = parser.parse_args()
+
"""
Main function to demonstrate robot simulation.
This function initializes the simulation, creates the robot and other objects,
and performs the scoop ice task.
"""
- sim = initialize_simulation()
+ sim = initialize_simulation(args)
# Create simulation objects
robot = create_robot(sim)
diff --git a/examples/sim/gizmo/gizmo_camera.py b/examples/sim/gizmo/gizmo_camera.py
index 4cb9071b..296c3be4 100644
--- a/examples/sim/gizmo/gizmo_camera.py
+++ b/examples/sim/gizmo/gizmo_camera.py
@@ -28,9 +28,10 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.sensors import Camera, CameraCfg
-from embodichain.lab.sim.cfg import RigidObjectCfg, RigidBodyAttributesCfg
+from embodichain.lab.sim.cfg import RigidObjectCfg, RigidBodyAttributesCfg, RenderCfg
from embodichain.lab.sim.shapes import CubeCfg
from embodichain.utils import logger
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
def main():
@@ -40,20 +41,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create and simulate a camera with gizmo in SimulationManager"
)
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- choices=["cpu", "cuda"],
- help="Device to run simulation on",
- )
- parser.add_argument("--headless", action="store_true", help="Run in headless mode")
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -62,7 +50,7 @@ def main():
height=1080,
physics_dt=1.0 / 100.0,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
# Create simulation context
diff --git a/examples/sim/gizmo/gizmo_object.py b/examples/sim/gizmo/gizmo_object.py
index 06066e06..b0931f24 100644
--- a/examples/sim/gizmo/gizmo_object.py
+++ b/examples/sim/gizmo/gizmo_object.py
@@ -23,9 +23,9 @@
import time
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
-from embodichain.lab.sim.cfg import RigidBodyAttributesCfg
+from embodichain.lab.sim.cfg import RigidBodyAttributesCfg, RenderCfg
from embodichain.lab.sim.shapes import CubeCfg
-
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg
from embodichain.utils import logger
@@ -37,22 +37,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--headless",
- action="store_true",
- default=False,
- help="Run simulation in headless mode",
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
-
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -62,7 +47,9 @@ def main():
headless=args.headless,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device=args.device,
- enable_rt=args.enable_rt, # Enable ray tracing for better visuals
+ render_cfg=RenderCfg(
+ renderer=args.renderer
+ ), # Enable ray tracing for better visuals
)
# Create the simulation instance
diff --git a/examples/sim/gizmo/gizmo_robot.py b/examples/sim/gizmo/gizmo_robot.py
index c6ccf473..40f0d0c1 100644
--- a/examples/sim/gizmo/gizmo_robot.py
+++ b/examples/sim/gizmo/gizmo_robot.py
@@ -24,11 +24,12 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RobotCfg,
URDFCfg,
JointDrivePropertiesCfg,
)
-
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.solvers import PinkSolverCfg
from embodichain.data import get_data_path
from embodichain.utils import logger
@@ -41,15 +42,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -58,7 +51,7 @@ def main():
height=1080,
physics_dt=1.0 / 100.0,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
sim = SimulationManager(sim_cfg)
diff --git a/examples/sim/gizmo/gizmo_scene.py b/examples/sim/gizmo/gizmo_scene.py
index 15144487..a37e6eb8 100644
--- a/examples/sim/gizmo/gizmo_scene.py
+++ b/examples/sim/gizmo/gizmo_scene.py
@@ -30,12 +30,14 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RobotCfg,
URDFCfg,
JointDrivePropertiesCfg,
RigidObjectCfg,
RigidBodyAttributesCfg,
)
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.shapes import CubeCfg
from embodichain.lab.sim.sensors import CameraCfg
from embodichain.lab.sim.solvers import PinkSolverCfg
@@ -49,24 +51,17 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
sim_cfg = SimulationManagerCfg(
width=1920,
height=1080,
+ headless=args.headless,
physics_dt=1.0 / 100.0,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
sim = SimulationManager(sim_cfg)
diff --git a/examples/sim/gizmo/gizmo_w1.py b/examples/sim/gizmo/gizmo_w1.py
index 7eacab29..09779c84 100644
--- a/examples/sim/gizmo/gizmo_w1.py
+++ b/examples/sim/gizmo/gizmo_w1.py
@@ -24,11 +24,12 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RobotCfg,
URDFCfg,
JointDrivePropertiesCfg,
)
-
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.solvers import PinkSolverCfg
from embodichain.data import get_data_path
from embodichain.utils import logger
@@ -41,24 +42,17 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
sim_cfg = SimulationManagerCfg(
width=1920,
height=1080,
+ headless=args.headless,
physics_dt=1.0 / 100.0,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
sim = SimulationManager(sim_cfg)
diff --git a/examples/sim/scene/scene_demo.py b/examples/sim/scene/scene_demo.py
index 711145c8..b119cdfb 100644
--- a/examples/sim/scene/scene_demo.py
+++ b/examples/sim/scene/scene_demo.py
@@ -24,11 +24,18 @@
import math
import embodichain.utils.logger as logger
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
-from embodichain.lab.sim.cfg import RigidBodyAttributesCfg, LightCfg, RobotCfg, URDFCfg
+from embodichain.lab.sim.cfg import (
+ RenderCfg,
+ RigidBodyAttributesCfg,
+ LightCfg,
+ RobotCfg,
+ URDFCfg,
+)
from embodichain.lab.sim.shapes import MeshCfg
from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg, Robot
from embodichain.data.assets.scene_assets import SceneData
from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATA_ROOT
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
def resolve_asset_path(scene_name: str) -> str:
@@ -91,18 +98,7 @@ def main():
choices=["kitchen", "factory", "office", "local"],
help="Choose which scene to load",
)
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of parallel environments"
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--disable_rt",
- action="store_true",
- default=False,
- help="Disable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
logger.log_info(f"Initializing scene '{args.scene}'")
@@ -121,7 +117,7 @@ def main():
headless=True,
physics_dt=1.0 / 100.0,
sim_device=args.device,
- enable_rt=not args.disable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
num_envs=args.num_envs,
arena_space=10.0,
)
diff --git a/examples/sim/sensors/batch_camera.py b/examples/sim/sensors/batch_camera.py
index 7e46b44d..f9c10cd4 100644
--- a/examples/sim/sensors/batch_camera.py
+++ b/examples/sim/sensors/batch_camera.py
@@ -19,7 +19,7 @@
import matplotlib.pyplot as plt
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
-from embodichain.lab.sim.cfg import RigidObjectCfg, LightCfg
+from embodichain.lab.sim.cfg import RenderCfg, RigidObjectCfg, LightCfg
from embodichain.lab.sim.shapes import MeshCfg
from embodichain.lab.sim.objects import RigidObject, Light
from embodichain.lab.sim.sensors import (
@@ -28,6 +28,7 @@
CameraCfg,
StereoCameraCfg,
)
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.data import get_data_path
@@ -37,7 +38,7 @@ def main(args):
sim_device=args.device,
num_envs=args.num_envs,
arena_space=2,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
sim = SimulationManager(config)
@@ -120,22 +121,7 @@ def main(args):
import argparse
parser = argparse.ArgumentParser(description="Run the batch robot simulation.")
- parser.add_argument(
- "--num_envs", type=int, default=4, help="Number of environments to simulate."
- )
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- choices=["cpu", "cuda"],
- help="Device to run the simulation on.",
- )
- parser.add_argument(
- "--headless", action="store_true", help="Run the simulation in headless mode."
- )
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering."
- )
+ add_env_launcher_args_to_parser(parser)
parser.add_argument(
"--sensor_type",
type=str,
diff --git a/examples/sim/sensors/create_contact_sensor.py b/examples/sim/sensors/create_contact_sensor.py
index 3a1c933a..17c26caf 100644
--- a/examples/sim/sensors/create_contact_sensor.py
+++ b/examples/sim/sensors/create_contact_sensor.py
@@ -25,6 +25,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RigidBodyAttributesCfg,
)
from embodichain.lab.sim.sensors import (
@@ -34,6 +35,7 @@
from embodichain.lab.sim.shapes import CubeCfg
from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg, Robot, RobotCfg
from embodichain.data import get_data_path
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
def create_cube(
@@ -177,24 +179,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--headless",
- action="store_true",
- default=False,
- help="Run simulation in headless mode",
- )
- parser.add_argument(
- "--num_envs", type=int, default=64, help="Number of parallel environments"
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -202,10 +187,12 @@ def main():
width=1920,
height=1080,
num_envs=args.num_envs,
- headless=args.headless,
+ headless=True,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device=args.device,
- enable_rt=args.enable_rt, # Enable ray tracing for better visuals
+ render_cfg=RenderCfg(
+ renderer=args.renderer
+ ), # Enable ray tracing for better visuals
)
# Create the simulation instance
diff --git a/examples/sim/utility/workspace_analyzer/analyze_cartesian_workspace.py b/examples/sim/utility/workspace_analyzer/analyze_cartesian_workspace.py
index 0871b6ad..8d2b5b9c 100644
--- a/examples/sim/utility/workspace_analyzer/analyze_cartesian_workspace.py
+++ b/examples/sim/utility/workspace_analyzer/analyze_cartesian_workspace.py
@@ -20,7 +20,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.robots import DexforceW1Cfg
-from embodichain.lab.sim.cfg import MarkerCfg
+from embodichain.lab.sim.cfg import MarkerCfg, RenderCfg
from embodichain.lab.sim.utility.workspace_analyzer.workspace_analyzer import (
WorkspaceAnalyzer,
WorkspaceAnalyzerConfig,
@@ -36,10 +36,12 @@
torch.set_printoptions(precision=5, sci_mode=False)
config = SimulationManagerCfg(
- headless=False, sim_device="cpu", width=1080, height=1080
+ headless=False,
+ sim_device="cuda",
+ width=1080,
+ height=1080,
)
sim = SimulationManager(config)
- sim.set_manual_update(False)
cfg = DexforceW1Cfg.from_dict(
{"uid": "dexforce_w1", "version": "v021", "arm_kind": "industrial"}
@@ -48,7 +50,11 @@
print("DexforceW1 robot added to the simulation.")
# Set left arm joint positions (mirrored)
- left_qpos = torch.tensor([0, -np.pi / 4, 0.0, -np.pi / 2, -np.pi / 4, 0.0, 0.0])
+ left_qpos = torch.tensor(
+ [0, -np.pi / 4, 0.0, -np.pi / 2, -np.pi / 4, 0.0, 0.0],
+ dtype=torch.float32,
+ device=robot.device,
+ )
right_qpos = -left_qpos
robot.set_qpos(
qpos=left_qpos,
@@ -87,7 +93,7 @@
wa_cartesian = WorkspaceAnalyzer(
robot=robot, config=cartesian_config, sim_manager=sim
)
- results_cartesian = wa_cartesian.analyze(num_samples=1000, visualize=True)
+ results_cartesian = wa_cartesian.analyze(num_samples=50000, visualize=True)
print(f"\nCartesian Space Results:")
print(
f" Reachable points: {results_cartesian['num_reachable']} / {results_cartesian['num_samples']}"
diff --git a/examples/sim/utility/workspace_analyzer/analyze_joint_workspace.py b/examples/sim/utility/workspace_analyzer/analyze_joint_workspace.py
index 6ba8ad4c..5c658fa9 100644
--- a/examples/sim/utility/workspace_analyzer/analyze_joint_workspace.py
+++ b/examples/sim/utility/workspace_analyzer/analyze_joint_workspace.py
@@ -20,7 +20,6 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.robots import DexforceW1Cfg
-
from embodichain.lab.sim.utility.workspace_analyzer.workspace_analyzer import (
WorkspaceAnalyzer,
)
@@ -43,7 +42,7 @@
print("Example: Joint Space Analysis")
wa_joint = WorkspaceAnalyzer(robot=robot, sim_manager=sim_manager)
- results_joint = wa_joint.analyze(num_samples=3000, visualize=True)
+ results_joint = wa_joint.analyze(num_samples=30000, visualize=True)
print(f"\nJoint Space Results:")
print(
diff --git a/examples/sim/utility/workspace_analyzer/analyze_plane_workspace.py b/examples/sim/utility/workspace_analyzer/analyze_plane_workspace.py
index 957b3535..8bd1b4ce 100644
--- a/examples/sim/utility/workspace_analyzer/analyze_plane_workspace.py
+++ b/examples/sim/utility/workspace_analyzer/analyze_plane_workspace.py
@@ -25,19 +25,21 @@
WorkspaceAnalyzerConfig,
AnalysisMode,
)
-from embodichain.lab.sim.cfg import MarkerCfg
+from embodichain.lab.sim.cfg import MarkerCfg, RenderCfg
from embodichain.lab.sim.utility.workspace_analyzer.configs.visualization_config import (
VisualizationConfig,
)
-
if __name__ == "__main__":
# Example usage
np.set_printoptions(precision=5, suppress=True)
torch.set_printoptions(precision=5, sci_mode=False)
config = SimulationManagerCfg(
- headless=False, sim_device="cpu", width=1080, height=1080
+ headless=False,
+ sim_device="cpu",
+ width=1080,
+ height=1080,
)
sim = SimulationManager(config)
sim.set_manual_update(False)
diff --git a/pyproject.toml b/pyproject.toml
index 25b15290..728190e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dynamic = ["version"]
# Core install dependencies (kept from requirements.txt). Some VCS links are
# specified using PEP 508 direct references where present.
dependencies = [
- "dexsim_engine==0.3.11",
+ "dexsim_engine==0.4.0",
"setuptools>=78.1.1",
"gymnasium>=0.29.1",
"langchain",
@@ -36,28 +36,22 @@ dependencies = [
"pin-pink",
"casadi",
"qpsolvers[osqp]==4.8.1",
- "pytorch_kinematics==0.7.6",
+ "pytorch_kinematics==0.10.0",
"polars==1.31.0",
"PyYAML>=6.0",
- "accelerate>=1.10.0",
"wandb>=0.21.0",
"tensorboard>=2.20.0",
- "transformers>=4.53.0",
- "diffusers>=0.32.1",
- "deepspeed>=0.16.2",
"ortools",
"prettytable",
- "black==24.3.0",
+ "black==26.3.1",
"fvcore",
"h5py",
"tensordict",
- "viser==1.0.21"
+ "viser==1.0.21",
+ "lerobot>=0.4.4"
]
[project.optional-dependencies]
-lerobot = [
- "lerobot==0.4.4"
-]
[tool.setuptools.dynamic]
version = { file = ["VERSION"] }
diff --git a/scripts/benchmark/__init__.py b/scripts/benchmark/__init__.py
new file mode 100644
index 00000000..dd650e90
--- /dev/null
+++ b/scripts/benchmark/__init__.py
@@ -0,0 +1,15 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
diff --git a/scripts/benchmark/__main__.py b/scripts/benchmark/__main__.py
new file mode 100644
index 00000000..ee9eac0a
--- /dev/null
+++ b/scripts/benchmark/__main__.py
@@ -0,0 +1,103 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Unified CLI entry point for ``python -m scripts.benchmark``.
+
+Usage examples::
+
+ python -m scripts.benchmark rl --tasks push_cube --algorithms ppo --suite default
+ python -m scripts.benchmark rl --rebuild-report-only
+ python -m scripts.benchmark robotics-kinematic-solver -s pytorch
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+
+
+def _run_robotics_kinematic_solver_cli(args: argparse.Namespace) -> None:
+ """Run robotics kinematic solver benchmark with forwarded CLI args."""
+ from scripts.benchmark.robotics.kinematic_solver.run_benchmark import (
+ run_all_benchmarks,
+ )
+
+ run_all_benchmarks(selected_solvers=args.solvers)
+
+
+def _run_rl_cli(_: argparse.Namespace) -> None:
+ """Run RL benchmark CLI entrypoint."""
+ from scripts.benchmark.rl.run_benchmark import main as rl_main
+
+ rl_main()
+
+
+def main() -> None:
+ """Dispatch to the appropriate benchmark sub-command CLI."""
+ parser = argparse.ArgumentParser(
+ prog="scripts.benchmark",
+ description="EmbodiChain benchmark command-line interface.",
+ )
+ subparsers = parser.add_subparsers(dest="command")
+
+ # -- rl ------------------------------------------------------------------
+ rl_parser = subparsers.add_parser(
+ "rl",
+ help="Run RL benchmark: train, evaluate, aggregate, and report results.",
+ )
+ rl_parser.set_defaults(func=_run_rl_cli)
+
+ # -- robotics-kinematic-solver -------------------------------------------
+ robotics_ks_parser = subparsers.add_parser(
+ "robotics-kinematic-solver",
+ help="Benchmark the OPW kinematic solver (FK/IK accuracy and speed).",
+ )
+ robotics_ks_parser.add_argument(
+ "--solvers",
+ "-s",
+ nargs="+",
+ choices=("opw", "pytorch", "all"),
+ default=["all"],
+ help="Solvers to benchmark. Use one or more of: opw, pytorch, all.",
+ )
+ robotics_ks_parser.set_defaults(func=_run_robotics_kinematic_solver_cli)
+
+ # -- Parse ---------------------------------------------------------------
+ # If no sub-command is given, print help and exit.
+ if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
+ parser.print_help()
+ sys.exit(0)
+
+ # Determine which sub-command was selected, then reconstruct argv so
+ # that each sub-command's entry point can call ``parse_args()`` normally.
+ known, _ = parser.parse_known_args()
+
+ if hasattr(known, "func"):
+ # Rewrite sys.argv so the sub-command's argparse sees only its own args.
+ subcommand_argv = [f"scripts.benchmark {sys.argv[1]}"] + sys.argv[2:]
+ original_argv = sys.argv
+ sys.argv = subcommand_argv
+ try:
+ known.func(known)
+ finally:
+ sys.argv = original_argv
+ else:
+ parser.print_help()
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/benchmark/opw_solver.py b/scripts/benchmark/opw_solver.py
deleted file mode 100644
index c248eaba..00000000
--- a/scripts/benchmark/opw_solver.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# ----------------------------------------------------------------------------
-# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ----------------------------------------------------------------------------
-
-import torch
-import numpy as np
-import warp as wp
-from scipy.spatial.transform import Rotation
-from embodichain.lab.sim.solvers.opw_solver import OPWSolver, OPWSolverCfg
-from typing import Tuple, List
-import time
-
-
-def get_pose_err(matrix_a: np.ndarray, matrix_b: np.ndarray) -> Tuple[float, float]:
- t_err = np.linalg.norm(matrix_a[:3, 3] - matrix_b[:3, 3])
- relative_rot = matrix_a[:3, :3].T @ matrix_b[:3, :3]
- cos_angle = (np.trace(relative_rot) - 1) / 2.0
- cos_angle = np.clip(cos_angle, -1.0, 1.0)
- r_err = np.arccos(cos_angle)
- return t_err, r_err
-
-
-def get_poses_err(
- matrix_a_list: List[np.ndarray], matrix_b_list: List[np.ndarray]
-) -> Tuple[float, float]:
- t_errs = []
- r_errs = []
- for mat_a, mat_b in zip(matrix_a_list, matrix_b_list):
- t_err, r_err = get_pose_err(mat_a, mat_b)
- t_errs.append(t_err)
- r_errs.append(r_err)
- return np.mean(t_errs), np.mean(r_errs)
-
-
-def check_opw_solver(solver_warp, solver_py_opw, n_samples=1000):
- DOF = 6
- qpos_np = np.random.uniform(low=-np.pi, high=np.pi, size=(n_samples, DOF)).astype(
- float
- )
- qpos = torch.tensor(qpos_np, device=torch.device("cuda"), dtype=torch.float32)
- xpos = solver_warp.get_fk(qpos)
- qpos_seed = torch.tensor(
- [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
- device=torch.device("cuda"),
- dtype=torch.float32,
- )
-
- warp_ik_start_time = time.time()
- warp_ik_success, warp_ik_qpos = solver_warp.get_ik(
- xpos,
- qpos_seed=qpos_seed,
- initial_guess=qpos,
- # return_all_solutions=True,
- )
- warp_cost_time = time.time() - warp_ik_start_time
-
- # TODO: debug code
- # warp_ik_success_np = warp_ik_success.cpu().numpy()
- # warp_ik_failure_indices = np.where(warp_ik_success_np == False)[0]
- # if len(warp_ik_failure_indices) > 0:
- # failure_qpos = qpos_np[warp_ik_failure_indices]
- # failure_xpos = xpos.cpu().numpy()[warp_ik_failure_indices]
- # print("=====warp_ik_failure_qpos:\n", repr(failure_qpos))
- # print("=====warp_ik_failure_xpos:\n", repr(failure_xpos))
-
- # print("=====xpos:\n", repr(xpos.cpu().numpy()))
- # print("=====warp_ik_qpos:\n", repr(warp_ik_qpos.cpu().numpy()))
- # print("=====warp_ik_success:\n", repr(warp_ik_success.cpu().numpy()))
-
- check_xpos = solver_warp.get_fk(warp_ik_qpos)
- warp_t_mean_err, warp_r_mean_err = get_poses_err(
- [x.cpu().numpy() for x in xpos],
- [x.cpu().numpy() for x in check_xpos],
- )
-
- py_opw_ik_start_time = time.time()
- py_opw_ik_success, py_opw_ik_qpos = solver_py_opw.get_ik(
- xpos, qpos_seed=qpos_seed, initial_guess=qpos
- )
- py_opw_cost_time = time.time() - py_opw_ik_start_time
-
- check_xpos = solver_warp.get_fk(py_opw_ik_qpos.to(torch.device("cuda")))
- py_opw_t_mean_err, py_opw_r_mean_err = get_poses_err(
- [x.cpu().numpy() for x in xpos],
- [x.cpu().numpy() for x in check_xpos],
- )
-
- return (
- warp_cost_time,
- warp_t_mean_err,
- warp_r_mean_err,
- py_opw_cost_time,
- py_opw_t_mean_err,
- py_opw_r_mean_err,
- )
-
-
-def benchmark_opw_solver():
- cfg = OPWSolverCfg()
- cfg.a1 = 400.333
- cfg.a2 = -251.449
- cfg.b = 0.0
- cfg.c1 = 830
- cfg.c2 = 1177.556
- cfg.c3 = 1443.593
- cfg.c4 = 230
- cfg.offsets = (
- 0.0,
- 82.21350356417211 * np.pi / 180.0,
- -167.21710113148163 * np.pi / 180.0,
- 0.0,
- 0.0,
- 0.0,
- )
- cfg.flip_axes = (True, False, True, True, False, True)
- cfg.has_parallelogram = False
-
- # TODO: ignore pk_serial_chain for OPW
- solver_warp = cfg.init_solver(device=torch.device("cuda"), pk_serial_chain="")
- solver_py_opw = cfg.init_solver(device=torch.device("cpu"), pk_serial_chain="")
- n_samples = [100, 1000, 10000, 100000]
- # n_samples = [100]
- for n_sample in n_samples:
- # check_opw_solver(solver_warp, solver_py_opw, device=device, n_samples=n_sample)
- (
- warp_cost_time,
- warp_t_mean_err,
- warp_r_mean_err,
- py_opw_cost_time,
- py_opw_t_mean_err,
- py_opw_r_mean_err,
- ) = check_opw_solver(solver_warp, solver_py_opw, n_samples=n_sample)
- print(f"===warp OPW Solver FK/IK test over {n_sample} samples:")
- print(f" Warp IK time: {warp_cost_time * 1000:.6f} ms")
- print(f"Translation mean error: {warp_t_mean_err*1000:.6f} mm")
- print(f"Rotation mean error: {warp_r_mean_err*180/np.pi:.6f} degrees")
- print(f"===Py OPW IK time: {py_opw_cost_time * 1000:.6f} ms")
- print(f"Translation mean error: {py_opw_t_mean_err*1000:.6f} mm")
- print(f"Rotation mean error: {py_opw_r_mean_err*180/np.pi:.6f} degrees")
-
-
-if __name__ == "__main__":
- benchmark_opw_solver()
diff --git a/scripts/benchmark/rl/__init__.py b/scripts/benchmark/rl/__init__.py
new file mode 100644
index 00000000..b142c88c
--- /dev/null
+++ b/scripts/benchmark/rl/__init__.py
@@ -0,0 +1,21 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from .runner import BenchmarkRunner
+
+__all__ = ["BenchmarkRunner"]
diff --git a/scripts/benchmark/rl/algorithms/__init__.py b/scripts/benchmark/rl/algorithms/__init__.py
new file mode 100644
index 00000000..dd650e90
--- /dev/null
+++ b/scripts/benchmark/rl/algorithms/__init__.py
@@ -0,0 +1,15 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
diff --git a/scripts/benchmark/rl/algorithms/grpo.yaml b/scripts/benchmark/rl/algorithms/grpo.yaml
new file mode 100644
index 00000000..e33c673b
--- /dev/null
+++ b/scripts/benchmark/rl/algorithms/grpo.yaml
@@ -0,0 +1,24 @@
+name: grpo
+config:
+ policy:
+ name: actor_only
+ actor:
+ type: mlp
+ network_cfg:
+ hidden_sizes: [256, 256]
+ activation: relu
+ algorithm:
+ name: grpo
+ cfg:
+ learning_rate: 0.0001
+ n_epochs: 10
+ batch_size: 8192
+ gamma: 0.99
+ clip_coef: 0.2
+ ent_coef: 0.01
+ kl_coef: 0.0
+ group_size: 4
+ eps: 1.0e-8
+ reset_every_rollout: true
+ truncate_at_first_done: true
+ max_grad_norm: 0.5
diff --git a/scripts/benchmark/rl/algorithms/ppo.yaml b/scripts/benchmark/rl/algorithms/ppo.yaml
new file mode 100644
index 00000000..361c9386
--- /dev/null
+++ b/scripts/benchmark/rl/algorithms/ppo.yaml
@@ -0,0 +1,26 @@
+name: ppo
+config:
+ policy:
+ name: actor_critic
+ actor:
+ type: mlp
+ network_cfg:
+ hidden_sizes: [256, 256]
+ activation: relu
+ critic:
+ type: mlp
+ network_cfg:
+ hidden_sizes: [256, 256]
+ activation: relu
+ algorithm:
+ name: ppo
+ cfg:
+ learning_rate: 0.0001
+ n_epochs: 10
+ batch_size: 8192
+ gamma: 0.99
+ gae_lambda: 0.95
+ clip_coef: 0.2
+ ent_coef: 0.01
+ vf_coef: 0.5
+ max_grad_norm: 0.5
diff --git a/scripts/benchmark/rl/config.py b/scripts/benchmark/rl/config.py
new file mode 100644
index 00000000..da5131d3
--- /dev/null
+++ b/scripts/benchmark/rl/config.py
@@ -0,0 +1,70 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from copy import deepcopy
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+BENCHMARK_ROOT = Path(__file__).resolve().parent
+
+
+def load_yaml(path: str | Path) -> dict[str, Any]:
+ """Load a YAML file into a dictionary."""
+ with Path(path).open("r", encoding="utf-8") as file:
+ data = yaml.safe_load(file) or {}
+ if not isinstance(data, dict):
+ raise TypeError(f"Expected mapping in YAML file {path}, got {type(data)!r}.")
+ return data
+
+
+def deep_update(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
+ """Recursively merge `override` into `base` and return a new mapping."""
+ merged = deepcopy(base)
+ for key, value in override.items():
+ if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
+ merged[key] = deep_update(merged[key], value)
+ else:
+ merged[key] = deepcopy(value)
+ return merged
+
+
+def load_task_spec(name: str) -> dict[str, Any]:
+ """Load a benchmark task specification by name."""
+ return load_yaml(BENCHMARK_ROOT / "tasks" / f"{name}.yaml")
+
+
+def load_algorithm_spec(name: str) -> dict[str, Any]:
+ """Load a benchmark algorithm specification by name."""
+ return load_yaml(BENCHMARK_ROOT / "algorithms" / f"{name}.yaml")
+
+
+def load_suite_spec(name: str = "default") -> dict[str, Any]:
+ """Load a benchmark suite specification by name."""
+ return load_yaml(BENCHMARK_ROOT / "suites" / f"{name}.yaml")
+
+
+__all__ = [
+ "BENCHMARK_ROOT",
+ "deep_update",
+ "load_algorithm_spec",
+ "load_suite_spec",
+ "load_task_spec",
+ "load_yaml",
+]
diff --git a/scripts/benchmark/rl/metrics.py b/scripts/benchmark/rl/metrics.py
new file mode 100644
index 00000000..f1ce9185
--- /dev/null
+++ b/scripts/benchmark/rl/metrics.py
@@ -0,0 +1,253 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from collections import defaultdict
+from math import isnan
+from statistics import mean, pstdev
+from typing import Any
+
+
+def _iter_valid_threshold_points(
+ eval_history: list[dict[str, float]],
+ metric_key: str,
+):
+ """Yield `(step, metric)` pairs with valid numeric values."""
+ for item in eval_history:
+ metric_value = item.get(metric_key)
+ step_value = item.get("global_step")
+ if metric_value is None or step_value is None:
+ continue
+ if not isinstance(metric_value, (int, float)) or not isinstance(
+ step_value, (int, float)
+ ):
+ continue
+ if isnan(metric_value):
+ continue
+ yield int(step_value), float(metric_value)
+
+
+def compute_final_metric_stable(
+ eval_history: list[dict[str, float]],
+ metric_key: str,
+ window_size: int = 3,
+) -> float | None:
+ """Return the mean of the last `window_size` valid metric values."""
+ valid_values = [
+ metric_value
+ for _, metric_value in _iter_valid_threshold_points(eval_history, metric_key)
+ ]
+ if not valid_values:
+ return None
+ effective_window = max(1, window_size)
+ return mean(valid_values[-effective_window:])
+
+
+def compute_steps_to_threshold_first_hit(
+ eval_history: list[dict[str, float]],
+ metric_key: str,
+ threshold: float,
+) -> int | None:
+ """Return the first step where `metric_key` reaches `threshold`."""
+ for step_value, metric_value in _iter_valid_threshold_points(
+ eval_history, metric_key
+ ):
+ if metric_value >= threshold:
+ return step_value
+ return None
+
+
+def compute_steps_to_threshold_sustained(
+ eval_history: list[dict[str, float]],
+ metric_key: str,
+ threshold: float,
+ sustain_count: int = 3,
+) -> int | None:
+ """Return the first step where the threshold is met for `sustain_count` evals."""
+ if sustain_count <= 1:
+ return compute_steps_to_threshold_first_hit(eval_history, metric_key, threshold)
+
+ consecutive_hits = 0
+ first_step_in_window: int | None = None
+ for step_value, metric_value in _iter_valid_threshold_points(
+ eval_history, metric_key
+ ):
+ if metric_value >= threshold:
+ consecutive_hits += 1
+ if first_step_in_window is None:
+ first_step_in_window = step_value
+ if consecutive_hits >= sustain_count:
+ return first_step_in_window
+ else:
+ consecutive_hits = 0
+ first_step_in_window = None
+ return None
+
+
+def aggregate_runs(run_results: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """Aggregate run results by task and algorithm."""
+ grouped: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list)
+ for result in run_results:
+ grouped[(result["task"], result["algorithm"])].append(result)
+
+ summaries: list[dict[str, Any]] = []
+ for (task, algorithm), runs in sorted(grouped.items()):
+ summary: dict[str, Any] = {
+ "task": task,
+ "algorithm": algorithm,
+ "num_runs": len(runs),
+ }
+ scalar_keys = {
+ "final_reward",
+ "final_success_rate",
+ "final_success_rate_stable",
+ "final_episode_length",
+ "training_fps",
+ "environment_fps",
+ "peak_gpu_memory_mb",
+ }
+ for key in scalar_keys:
+ values = [
+ float(run[key])
+ for run in runs
+ if isinstance(run.get(key), (int, float)) and not isnan(run[key])
+ ]
+ if values:
+ summary[f"{key}_mean"] = mean(values)
+ summary[f"{key}_std"] = pstdev(values) if len(values) > 1 else 0.0
+ step_keys = {
+ "steps_to_success_threshold",
+ "steps_to_success_threshold_first_hit",
+ }
+ for step_key in step_keys:
+ steps = [
+ int(run[step_key]) for run in runs if isinstance(run.get(step_key), int)
+ ]
+ if steps:
+ summary[f"{step_key}_mean"] = mean(steps)
+ summary[f"{step_key}_std"] = pstdev(steps) if len(steps) > 1 else 0.0
+ summaries.append(summary)
+
+ return summaries
+
+
+def _valid_float(value: Any) -> float | None:
+ if isinstance(value, (int, float)) and not isnan(float(value)):
+ return float(value)
+ return None
+
+
+def build_leaderboard(
+ aggregate_results: list[dict[str, Any]],
+ run_results: list[dict[str, Any]] | None = None,
+) -> list[dict[str, Any]]:
+ """Build leaderboard entries from aggregated benchmark summaries."""
+ grouped_summary: dict[str, list[dict[str, Any]]] = defaultdict(list)
+ for item in aggregate_results:
+ grouped_summary[item["algorithm"]].append(item)
+
+ grouped_runs: dict[str, list[dict[str, Any]]] = defaultdict(list)
+ for item in run_results or []:
+ grouped_runs[item["algorithm"]].append(item)
+
+ leaderboard: list[dict[str, Any]] = []
+ for algorithm, items in grouped_summary.items():
+ stable_success_values = [
+ float(item["final_success_rate_stable_mean"])
+ for item in items
+ if isinstance(item.get("final_success_rate_stable_mean"), (int, float))
+ and not isnan(item["final_success_rate_stable_mean"])
+ ]
+ success_values = [
+ float(item["final_success_rate_mean"])
+ for item in items
+ if isinstance(item.get("final_success_rate_mean"), (int, float))
+ and not isnan(item["final_success_rate_mean"])
+ ]
+ reward_values = [
+ float(item["final_reward_mean"])
+ for item in items
+ if isinstance(item.get("final_reward_mean"), (int, float))
+ and not isnan(item["final_reward_mean"])
+ ]
+ score = mean(stable_success_values) if stable_success_values else float("nan")
+ steps_values = [
+ float(item["steps_to_success_threshold_mean"])
+ for item in items
+ if isinstance(item.get("steps_to_success_threshold_mean"), (int, float))
+ and not isnan(item["steps_to_success_threshold_mean"])
+ ]
+ run_success_values = [
+ float(run["final_success_rate"])
+ for run in grouped_runs.get(algorithm, [])
+ if _valid_float(run.get("final_success_rate")) is not None
+ ]
+ task_scores = {
+ item["task"]: float(item["final_success_rate_stable_mean"])
+ for item in items
+ if _valid_float(item.get("final_success_rate_stable_mean")) is not None
+ }
+ raw_task_scores = {
+ item["task"]: float(item["final_success_rate_mean"])
+ for item in items
+ if _valid_float(item.get("final_success_rate_mean")) is not None
+ }
+ leaderboard.append(
+ {
+ "algorithm": algorithm,
+ "score": score,
+ "steps_to_success_threshold": (
+ mean(steps_values) if steps_values else float("nan")
+ ),
+ "success_rate_std": (
+ pstdev(run_success_values) if len(run_success_values) > 1 else 0.0
+ ),
+ "avg_success_rate": (
+ mean(success_values) if success_values else float("nan")
+ ),
+ "avg_success_rate_stable": score,
+ "avg_final_reward": (
+ mean(reward_values) if reward_values else float("nan")
+ ),
+ "tasks_covered": len(items),
+ "tasks": task_scores,
+ "tasks_raw": raw_task_scores,
+ }
+ )
+
+ leaderboard.sort(
+ key=lambda item: (
+ (
+ -(item["score"])
+ if isinstance(item["score"], float) and not isnan(item["score"])
+ else float("inf")
+ ),
+ item["algorithm"],
+ )
+ )
+ for index, item in enumerate(leaderboard, start=1):
+ item["rank"] = index
+ return leaderboard
+
+
+__all__ = [
+ "aggregate_runs",
+ "build_leaderboard",
+ "compute_final_metric_stable",
+ "compute_steps_to_threshold_first_hit",
+ "compute_steps_to_threshold_sustained",
+]
diff --git a/scripts/benchmark/rl/plots.py b/scripts/benchmark/rl/plots.py
new file mode 100644
index 00000000..e84f6964
--- /dev/null
+++ b/scripts/benchmark/rl/plots.py
@@ -0,0 +1,211 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from collections import defaultdict
+from math import isnan
+from pathlib import Path
+from statistics import mean
+from typing import Any
+
+COLORS = ["#1768ac", "#f26419", "#2a9134", "#c44536", "#6a4c93", "#1982c4"]
+
+
+def _svg_header(width: int, height: int) -> list[str]:
+ return [
+ f'")
+ return "\n".join(lines)
+
+
+def _bar_chart_svg(
+ title: str,
+ items: list[tuple[str, float]],
+ width: int = 900,
+ height: int = 420,
+) -> str:
+ margin_left = 80
+ margin_right = 20
+ margin_top = 40
+ margin_bottom = 80
+ plot_width = width - margin_left - margin_right
+ plot_height = height - margin_top - margin_bottom
+ values = [value for _, value in items if not isnan(value)] or [1.0]
+ value_max = max(values)
+ if value_max <= 0:
+ value_max = 1.0
+
+ lines = _svg_header(width, height)
+ lines.append(
+ f'{title}'
+ )
+ bar_width = plot_width / max(len(items), 1)
+ for idx, (label, value) in enumerate(items):
+ color = COLORS[idx % len(COLORS)]
+ bar_height = 0.0 if isnan(value) else (value / value_max) * plot_height
+ x = margin_left + idx * bar_width + 10
+ y = margin_top + plot_height - bar_height
+ lines.append(
+ f''
+ )
+ lines.append(
+ f'{label}'
+ )
+ lines.append(
+ f'{value:.3f}'
+ )
+ lines.append("")
+ return "\n".join(lines)
+
+
+def build_plot_artifacts(
+ run_results: list[dict[str, Any]],
+ leaderboard: list[dict[str, Any]],
+ output_dir: str | Path,
+) -> dict[str, str]:
+ """Generate SVG plot artifacts and return named paths."""
+ output = Path(output_dir)
+ output.mkdir(parents=True, exist_ok=True)
+ artifacts: dict[str, str] = {}
+
+ grouped_histories: dict[tuple[str, str], dict[float, list[float]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+ grouped_rewards: dict[tuple[str, str], dict[float, list[float]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+ for result in run_results:
+ key = (result["task"], result["algorithm"])
+ for item in result.get("eval_history", []):
+ step = item.get("global_step")
+ success = item.get("eval/success_rate")
+ reward = item.get("eval/avg_reward")
+ if isinstance(step, (int, float)) and isinstance(success, (int, float)):
+ grouped_histories[key][float(step)].append(float(success))
+ if isinstance(step, (int, float)) and isinstance(reward, (int, float)):
+ grouped_rewards[key][float(step)].append(float(reward))
+
+ tasks = sorted({result["task"] for result in run_results})
+ for task in tasks:
+ success_series = {}
+ reward_series = {}
+ for task_name, algorithm in sorted(grouped_histories.keys()):
+ if task_name != task:
+ continue
+ success_series[algorithm] = sorted(
+ (step, mean(values))
+ for step, values in grouped_histories[(task_name, algorithm)].items()
+ )
+ reward_series[algorithm] = sorted(
+ (step, mean(values))
+ for step, values in grouped_rewards[(task_name, algorithm)].items()
+ )
+ if success_series:
+ path = output / f"{task}_success_rate.svg"
+ path.write_text(
+ _line_chart_svg(f"{task} Success Rate", success_series),
+ encoding="utf-8",
+ )
+ artifacts[f"{task}_success_rate"] = str(path)
+ if reward_series:
+ path = output / f"{task}_reward.svg"
+ path.write_text(
+ _line_chart_svg(f"{task} Evaluation Reward", reward_series),
+ encoding="utf-8",
+ )
+ artifacts[f"{task}_reward"] = str(path)
+
+ leaderboard_path = output / "leaderboard_score.svg"
+ leaderboard_path.write_text(
+ _bar_chart_svg(
+ "Leaderboard Score",
+ [(item["algorithm"], float(item["score"])) for item in leaderboard],
+ ),
+ encoding="utf-8",
+ )
+ artifacts["leaderboard_score"] = str(leaderboard_path)
+ return artifacts
+
+
+__all__ = ["build_plot_artifacts"]
diff --git a/scripts/benchmark/rl/reporting.py b/scripts/benchmark/rl/reporting.py
new file mode 100644
index 00000000..635123df
--- /dev/null
+++ b/scripts/benchmark/rl/reporting.py
@@ -0,0 +1,292 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import math
+from collections import defaultdict
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+
+def _fmt(value: Any, digits: int = 3) -> str:
+ if isinstance(value, float):
+ return f"{value:.{digits}f}"
+ return str(value)
+
+
+def _safe_divide(numerator: float, denominator: float) -> float:
+ if denominator <= 0:
+ return float("nan")
+ return numerator / denominator
+
+
+def _sortable_success_rate(item: dict[str, Any]) -> float:
+ value = float(item.get("avg_success_rate", float("nan")))
+ if math.isnan(value):
+ return float("-inf")
+ return value
+
+
+def _build_report_leaderboard_rows(
+ leaderboard: list[dict[str, Any]],
+ aggregate_results: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+ """Build complete leaderboard rows and sort by overall success rate."""
+ by_algorithm: dict[str, dict[str, Any]] = {}
+ for item in leaderboard:
+ algorithm = str(item.get("algorithm", ""))
+ if not algorithm:
+ continue
+ by_algorithm[algorithm] = dict(item)
+
+ grouped_aggregate: dict[str, list[dict[str, Any]]] = defaultdict(list)
+ for item in aggregate_results:
+ algorithm = str(item.get("algorithm", ""))
+ if not algorithm:
+ continue
+ grouped_aggregate[algorithm].append(item)
+
+ for algorithm, items in grouped_aggregate.items():
+ if algorithm in by_algorithm:
+ continue
+
+ success_values = [
+ float(entry["final_success_rate_mean"])
+ for entry in items
+ if isinstance(entry.get("final_success_rate_mean"), (int, float))
+ and not math.isnan(float(entry["final_success_rate_mean"]))
+ ]
+ stable_success_values = [
+ float(entry["final_success_rate_stable_mean"])
+ for entry in items
+ if isinstance(entry.get("final_success_rate_stable_mean"), (int, float))
+ and not math.isnan(float(entry["final_success_rate_stable_mean"]))
+ ]
+ by_algorithm[algorithm] = {
+ "algorithm": algorithm,
+ "avg_success_rate": (
+ sum(success_values) / len(success_values)
+ if success_values
+ else float("nan")
+ ),
+ "avg_success_rate_stable": (
+ sum(stable_success_values) / len(stable_success_values)
+ if stable_success_values
+ else float("nan")
+ ),
+ "score": (
+ sum(stable_success_values) / len(stable_success_values)
+ if stable_success_values
+ else float("nan")
+ ),
+ "tasks_covered": len(items),
+ }
+
+ return sorted(
+ by_algorithm.values(),
+ key=lambda item: (
+ -_sortable_success_rate(item),
+ str(item.get("algorithm", "")),
+ ),
+ )
+
+
+def generate_markdown_report(
+ run_results: list[dict[str, Any]],
+ aggregate_results: list[dict[str, Any]],
+ leaderboard: list[dict[str, Any]],
+ plot_artifacts: dict[str, str],
+ protocol: dict[str, Any] | None,
+ output_path: str | Path,
+) -> Path:
+ """Write a benchmark markdown report with exactly three tables."""
+ output = Path(output_path)
+ output.parent.mkdir(parents=True, exist_ok=True)
+
+ ordered_runs = sorted(
+ run_results,
+ key=lambda item: (
+ str(item.get("task", "")),
+ str(item.get("algorithm", "")),
+ int(item.get("seed", 0)),
+ ),
+ )
+
+ lines = [
+ "# RL Benchmark Report",
+ "",
+ f"Generated at: {datetime.now().isoformat(timespec='seconds')}",
+ "",
+ "## Benchmark Overview",
+ "",
+ ]
+ if protocol:
+ lines.extend(
+ [
+ f"- device: `{protocol.get('device')}`",
+ f"- headless: `{protocol.get('headless')}`",
+ f"- iterations: `{protocol.get('iterations')}`",
+ f"- buffer_size: `{protocol.get('buffer_size')}`",
+ f"- num_envs: `{protocol.get('num_envs')}`",
+ f"- num_eval_envs: `{protocol.get('num_eval_envs')}`",
+ f"- evaluation_interval: `{protocol.get('evaluation_interval')}`",
+ f"- evaluation_episodes: `{protocol.get('evaluation_episodes')}`",
+ f"- threshold_sustain_count: `{protocol.get('threshold_sustain_count', 3)}`",
+ f"- final_eval_window: `{protocol.get('final_eval_window', 3)}`",
+ "",
+ ]
+ )
+ lines.extend(
+ [
+ "## Time & Memory",
+ "",
+ "| task | algorithm | seed | cost_time_ms | cpu_delta_mb | gpu_delta_mb | peak_gpu_mb | training_fps | env_fps |",
+ "| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |",
+ ]
+ )
+ for result in ordered_runs:
+ train_steps = float(result.get("train_steps", float("nan")))
+ training_fps = float(result.get("training_fps", float("nan")))
+ cost_time_ms = _safe_divide(train_steps, training_fps) * 1000.0
+ lines.append(
+ "| {task} | {algorithm} | {seed} | {cost_time_ms} | {cpu_delta} | {gpu_delta} | {peak_gpu} | {train_fps} | {env_fps} |".format(
+ task=result["task"],
+ algorithm=result["algorithm"],
+ seed=result["seed"],
+ cost_time_ms=_fmt(cost_time_ms),
+ cpu_delta=_fmt(result.get("cpu_delta_mb", "n/a")),
+ gpu_delta=_fmt(result.get("gpu_delta_mb", "n/a")),
+ peak_gpu=_fmt(result.get("peak_gpu_memory_mb", float("nan"))),
+ train_fps=_fmt(result.get("training_fps", float("nan"))),
+ env_fps=_fmt(result.get("environment_fps", float("nan")), digits=2),
+ )
+ )
+
+ lines.extend(
+ [
+ "",
+ "## Success & Other Metrics",
+ "",
+ "| task | algorithm | seed | success_rate | stable_success_rate | steps_to_threshold | first_hit | final_reward | final_episode_length |",
+ "| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |",
+ ]
+ )
+ for result in ordered_runs:
+ lines.append(
+ "| {task} | {algorithm} | {seed} | {success} | {stable_success} | {steps} | {first_hit} | {reward} | {episode_len} |".format(
+ task=result["task"],
+ algorithm=result["algorithm"],
+ seed=result["seed"],
+ success=_fmt(result.get("final_success_rate", float("nan"))),
+ stable_success=_fmt(
+ result.get("final_success_rate_stable", float("nan"))
+ ),
+ steps=_fmt(result.get("steps_to_success_threshold", float("nan"))),
+ first_hit=_fmt(
+ result.get("steps_to_success_threshold_first_hit", float("nan"))
+ ),
+ reward=_fmt(result.get("final_reward", float("nan"))),
+ episode_len=_fmt(result.get("final_episode_length", float("nan"))),
+ )
+ )
+
+ leaderboard_by_success = _build_report_leaderboard_rows(
+ leaderboard=leaderboard,
+ aggregate_results=aggregate_results,
+ )
+ lines.extend(
+ [
+ "",
+ "## Leaderboard",
+ "",
+ "| rank | algorithm | overall_success_rate | stable_success_rate | score | tasks_covered |",
+ "| ---: | --- | ---: | ---: | ---: | ---: |",
+ ]
+ )
+ for rank, item in enumerate(leaderboard_by_success, start=1):
+ lines.append(
+ "| {rank} | {algorithm} | {success} | {stable_success} | {score} | {tasks} |".format(
+ rank=rank,
+ algorithm=item.get("algorithm", "n/a"),
+ success=_fmt(item.get("avg_success_rate", float("nan"))),
+ stable_success=_fmt(item.get("avg_success_rate_stable", float("nan"))),
+ score=_fmt(item.get("score", float("nan"))),
+ tasks=item.get("tasks_covered", 0),
+ )
+ )
+
+ lines.extend(["", "## Notes", ""])
+ if leaderboard_by_success:
+ top = leaderboard_by_success[0]
+ lines.append(
+ "- Top algorithm by overall success rate: "
+ f"`{top.get('algorithm', 'n/a')}` "
+ f"(success_rate={_fmt(top.get('avg_success_rate', float('nan')))})."
+ )
+ if aggregate_results:
+ lines.append(f"- Aggregate summaries available: `{len(aggregate_results)}`.")
+
+ if plot_artifacts:
+ lines.extend(["", "## Plots", ""])
+ for plot_name, plot_path in sorted(plot_artifacts.items()):
+ relative = Path(plot_path).relative_to(output.parent)
+ lines.append(f"- {plot_name}: })")
+
+ output.write_text("\n".join(lines) + "\n", encoding="utf-8")
+ return output
+
+
+def generate_leaderboard_markdown(
+ leaderboard: list[dict[str, Any]],
+ output_path: str | Path,
+) -> Path:
+ """Write a dedicated leaderboard markdown artifact sorted by success rate."""
+ output = Path(output_path)
+ output.parent.mkdir(parents=True, exist_ok=True)
+ leaderboard_by_success = sorted(
+ leaderboard,
+ key=lambda item: (
+ -_sortable_success_rate(item),
+ str(item.get("algorithm", "")),
+ ),
+ )
+ lines = [
+ "# Benchmark Leaderboard",
+ "",
+ "| Rank | Algorithm | Score | Steps To Threshold (Sustained) | Success Rate Std | Avg Success Rate | Avg Stable Success Rate | Avg Final Reward | Tasks |",
+ "| ---: | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |",
+ ]
+ for rank, item in enumerate(leaderboard_by_success, start=1):
+ lines.append(
+ "| {rank} | {algorithm} | {score} | {steps} | {std} | {success} | {stable_success} | {reward} | {tasks} |".format(
+ rank=rank,
+ algorithm=item["algorithm"],
+ score=_fmt(item.get("score", float("nan"))),
+ steps=_fmt(item.get("steps_to_success_threshold", float("nan"))),
+ std=_fmt(item.get("success_rate_std", float("nan"))),
+ success=_fmt(item.get("avg_success_rate", float("nan"))),
+ stable_success=_fmt(item.get("avg_success_rate_stable", float("nan"))),
+ reward=_fmt(item.get("avg_final_reward", float("nan"))),
+ tasks=item.get("tasks_covered", 0),
+ )
+ )
+ output.write_text("\n".join(lines) + "\n", encoding="utf-8")
+ return output
+
+
+__all__ = ["generate_leaderboard_markdown", "generate_markdown_report"]
diff --git a/scripts/benchmark/rl/run_benchmark.py b/scripts/benchmark/rl/run_benchmark.py
new file mode 100644
index 00000000..bd85e12f
--- /dev/null
+++ b/scripts/benchmark/rl/run_benchmark.py
@@ -0,0 +1,106 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Run RL benchmark training/evaluation and generate one markdown report.
+
+Run: python -m scripts.benchmark.rl.run_benchmark
+"""
+
+from __future__ import annotations
+
+import argparse
+
+from .runner import BenchmarkRunner
+
+
+def parse_args() -> argparse.Namespace:
+ """Parse CLI arguments for full benchmark execution."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--tasks", nargs="*", default=None)
+ parser.add_argument("--algorithms", nargs="*", default=None)
+ parser.add_argument("--seeds", nargs="*", type=int, default=None)
+ parser.add_argument("--suite", type=str, default="default")
+ parser.add_argument(
+ "--output-root", type=str, default="scripts/benchmark/rl/reports"
+ )
+ parser.add_argument("--device", type=str, default=None)
+ parser.add_argument("--iterations", type=int, default=None)
+ parser.add_argument("--buffer-size", type=int, default=None)
+ parser.add_argument("--evaluation-interval", type=int, default=None)
+ parser.add_argument("--evaluation-episodes", type=int, default=None)
+ parser.add_argument("--num-envs", type=int, default=None)
+ parser.add_argument("--num-eval-envs", type=int, default=None)
+ parser.add_argument("--headless", action="store_true")
+ parser.add_argument("--skip-existing", action="store_true")
+ parser.add_argument("--rebuild-report-only", action="store_true")
+ return parser.parse_args()
+
+
+def main() -> None:
+ """Train, evaluate, aggregate, and report benchmark results."""
+ args = parse_args()
+ overrides = {
+ key: value
+ for key, value in {
+ "device": args.device,
+ "iterations": args.iterations,
+ "buffer_size": args.buffer_size,
+ "evaluation_interval": args.evaluation_interval,
+ "evaluation_episodes": args.evaluation_episodes,
+ "num_envs": args.num_envs,
+ "num_eval_envs": args.num_eval_envs,
+ "headless": args.headless if args.headless else None,
+ }.items()
+ if value is not None
+ }
+ runner = BenchmarkRunner(
+ tasks=args.tasks,
+ algorithms=args.algorithms,
+ seeds=args.seeds,
+ suite=args.suite,
+ output_root=args.output_root,
+ overrides=overrides,
+ )
+
+ if args.rebuild_report_only:
+ run_results = runner.collect_existing_run_results()
+ if not run_results:
+ training_runs = runner.collect_existing_training_runs()
+ if training_runs:
+ run_results = runner.run_evaluation(training_runs)
+ else:
+ raise SystemExit(
+ "No compatible existing benchmark results were found for the requested jobs under "
+ f"{runner.output_root / 'runs'}. "
+ "Run once without --rebuild-report-only to generate artifacts, "
+ "or pass --output-root to the directory containing existing runs."
+ )
+ else:
+ existing_results = (
+ runner.collect_existing_run_results() if args.skip_existing else []
+ )
+ training_runs = runner.run_training(skip_existing=args.skip_existing)
+ new_results = runner.run_evaluation(training_runs)
+ run_results = runner.merge_run_results(existing_results, new_results)
+
+ aggregate_result = runner.aggregate_results(run_results)
+ leaderboard = runner.update_leaderboard(aggregate_result, run_results)
+ report_path = runner.generate_report(run_results, aggregate_result, leaderboard)
+ print(f"Markdown report saved: {report_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/benchmark/rl/runner.py b/scripts/benchmark/rl/runner.py
new file mode 100644
index 00000000..84dcda87
--- /dev/null
+++ b/scripts/benchmark/rl/runner.py
@@ -0,0 +1,415 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import json
+from copy import deepcopy
+from pathlib import Path
+from typing import Any
+
+from .config import deep_update, load_algorithm_spec, load_suite_spec, load_task_spec
+from .metrics import (
+ aggregate_runs,
+ build_leaderboard,
+ compute_final_metric_stable,
+ compute_steps_to_threshold_first_hit,
+ compute_steps_to_threshold_sustained,
+)
+from .plots import build_plot_artifacts
+from .reporting import generate_leaderboard_markdown, generate_markdown_report
+from .runtime import dump_json, evaluate_checkpoint, train_with_config
+
+
+class BenchmarkRunner:
+ """Coordinate benchmark training, evaluation, aggregation, and reporting."""
+
+ def __init__(
+ self,
+ tasks: list[str] | None = None,
+ algorithms: list[str] | None = None,
+ seeds: list[int] | None = None,
+ suite: str = "default",
+ output_root: str | Path = "benchmark/reports",
+ overrides: dict[str, Any] | None = None,
+ ) -> None:
+ suite_spec = load_suite_spec(suite)
+ self.suite = suite
+ self.tasks = tasks or list(suite_spec["tasks"])
+ self.algorithms = algorithms or list(suite_spec["algorithms"])
+ self.seeds = seeds or list(suite_spec["seeds"])
+ self.protocol = deepcopy(suite_spec.get("protocol", {}))
+ if overrides:
+ self.protocol = deep_update(self.protocol, overrides)
+ self.output_root = Path(output_root)
+
+ def build_run_config(
+ self,
+ task_name: str,
+ algorithm_name: str,
+ seed: int,
+ ) -> dict[str, Any]:
+ task_spec = load_task_spec(task_name)
+ algorithm_spec = load_algorithm_spec(algorithm_name)
+
+ cfg = deep_update(task_spec["base_config"], algorithm_spec["config"])
+ cfg["trainer"]["exp_name"] = f"{task_name}_{algorithm_name}_seed{seed}"
+ cfg["trainer"]["seed"] = seed
+ train_eval_enabled = bool(task_spec.get("train_eval_enabled", True))
+ cfg["trainer"]["enable_eval"] = train_eval_enabled
+ if train_eval_enabled:
+ cfg["trainer"]["eval_freq"] = int(self.protocol["evaluation_interval"])
+ cfg["trainer"]["num_eval_episodes"] = int(
+ self.protocol["evaluation_episodes"]
+ )
+ cfg["trainer"]["iterations"] = int(self.protocol["iterations"])
+ cfg["trainer"]["buffer_size"] = int(self.protocol["buffer_size"])
+ cfg["trainer"]["num_envs"] = int(self.protocol["num_envs"])
+ cfg["trainer"]["num_eval_envs"] = int(self.protocol["num_eval_envs"])
+ cfg["trainer"]["device"] = str(self.protocol["device"])
+ cfg["trainer"]["headless"] = bool(self.protocol["headless"])
+ cfg["trainer"]["save_freq"] = int(self.protocol["save_interval"])
+ cfg["trainer"]["use_wandb"] = False
+ return cfg
+
+ def _iter_jobs(self) -> list[tuple[str, str, int]]:
+ jobs = []
+ for task_name in self.tasks:
+ for algorithm_name in self.algorithms:
+ for seed in self.seeds:
+ jobs.append((task_name, algorithm_name, seed))
+ return jobs
+
+ def _run_dir(self, task_name: str, algorithm_name: str, seed: int) -> Path:
+ return self.output_root / "runs" / task_name / algorithm_name / f"seed_{seed}"
+
+ @staticmethod
+ def _job_key(
+ task_name: str, algorithm_name: str, seed: int
+ ) -> tuple[str, str, int]:
+ return (task_name, algorithm_name, int(seed))
+
+ @staticmethod
+ def _load_json_artifact(path: str | Path) -> dict[str, Any] | None:
+ artifact_path = Path(path)
+ if not artifact_path.exists():
+ return None
+ data = json.loads(artifact_path.read_text(encoding="utf-8"))
+ if not isinstance(data, dict):
+ raise TypeError(
+ f"Expected JSON object at {artifact_path}, got {type(data)!r}."
+ )
+ return data
+
+ @staticmethod
+ def _record_matches_job(
+ record: dict[str, Any],
+ task_name: str,
+ algorithm_name: str,
+ seed: int,
+ ) -> bool:
+ return (
+ record.get("task") == task_name
+ and record.get("algorithm") == algorithm_name
+ and int(record.get("seed", -1)) == int(seed)
+ )
+
+ @staticmethod
+ def _protocol_from_run_config(run_config: dict[str, Any]) -> dict[str, Any]:
+ trainer = run_config.get("trainer", {})
+ return {
+ "device": trainer.get("device"),
+ "headless": trainer.get("headless"),
+ "iterations": trainer.get("iterations"),
+ "buffer_size": trainer.get("buffer_size"),
+ "num_envs": trainer.get("num_envs"),
+ "num_eval_envs": trainer.get("num_eval_envs"),
+ "evaluation_interval": trainer.get("eval_freq"),
+ "evaluation_episodes": trainer.get("num_eval_episodes"),
+ }
+
+ def _expected_protocol_for_job(
+ self,
+ task_name: str,
+ algorithm_name: str,
+ seed: int,
+ ) -> dict[str, Any]:
+ return self._protocol_from_run_config(
+ self.build_run_config(task_name, algorithm_name, seed)
+ )
+
+ def _artifact_is_compatible(
+ self,
+ artifact: dict[str, Any],
+ task_name: str,
+ algorithm_name: str,
+ seed: int,
+ run_dir: Path,
+ ) -> bool:
+ artifact_protocol = artifact.get("protocol")
+ if isinstance(artifact_protocol, dict):
+ return artifact_protocol == self.protocol
+ run_config = self._load_json_artifact(run_dir / "run_config.json")
+ if run_config is None:
+ return False
+ return self._protocol_from_run_config(
+ run_config
+ ) == self._expected_protocol_for_job(task_name, algorithm_name, seed)
+
+ def _load_existing_training_record(
+ self,
+ task_name: str,
+ algorithm_name: str,
+ seed: int,
+ ) -> dict[str, Any] | None:
+ run_dir = self._run_dir(task_name, algorithm_name, seed)
+ record = self._load_json_artifact(run_dir / "train_result.json")
+ if record is None:
+ return None
+ if not self._record_matches_job(record, task_name, algorithm_name, seed):
+ return None
+ if not self._artifact_is_compatible(
+ record, task_name, algorithm_name, seed, run_dir
+ ):
+ return None
+ checkpoint_path = record.get("checkpoint_path")
+ if not checkpoint_path or not Path(checkpoint_path).exists():
+ return None
+ return record
+
+ def collect_existing_run_results(self) -> list[dict[str, Any]]:
+ """Load compatible existing result artifacts for the requested jobs."""
+ results: list[dict[str, Any]] = []
+ for task_name, algorithm_name, seed in self._iter_jobs():
+ run_dir = self._run_dir(task_name, algorithm_name, seed)
+ record = self._load_json_artifact(run_dir / "result.json")
+ if record is None:
+ continue
+ if not self._record_matches_job(record, task_name, algorithm_name, seed):
+ continue
+ if not self._artifact_is_compatible(
+ record, task_name, algorithm_name, seed, run_dir
+ ):
+ continue
+ results.append(record)
+ return results
+
+ def collect_existing_training_runs(self) -> list[dict[str, Any]]:
+ """Load compatible existing training artifacts for the requested jobs."""
+ records: list[dict[str, Any]] = []
+ for task_name, algorithm_name, seed in self._iter_jobs():
+ record = self._load_existing_training_record(
+ task_name, algorithm_name, seed
+ )
+ if record is not None:
+ records.append(record)
+ return records
+
+ def merge_run_results(
+ self,
+ *result_sets: list[dict[str, Any]],
+ ) -> list[dict[str, Any]]:
+ """Merge multiple run result lists, preferring later duplicates."""
+ merged: dict[tuple[str, str, int], dict[str, Any]] = {}
+ for result_set in result_sets:
+ for record in result_set:
+ key = self._job_key(
+ str(record["task"]),
+ str(record["algorithm"]),
+ int(record["seed"]),
+ )
+ merged[key] = record
+ return [
+ merged[key]
+ for key in sorted(
+ merged.keys(), key=lambda item: (item[0], item[1], item[2])
+ )
+ ]
+
+ def run_training(self, skip_existing: bool = False) -> list[dict[str, Any]]:
+ """Run benchmark training and store per-run training artifacts."""
+ training_runs: list[dict[str, Any]] = []
+ existing_result_keys = set()
+ if skip_existing:
+ existing_result_keys = {
+ self._job_key(item["task"], item["algorithm"], item["seed"])
+ for item in self.collect_existing_run_results()
+ }
+ for task_name, algorithm_name, seed in self._iter_jobs():
+ run_dir = self._run_dir(task_name, algorithm_name, seed)
+ if (
+ skip_existing
+ and self._job_key(task_name, algorithm_name, seed)
+ in existing_result_keys
+ ):
+ continue
+ if skip_existing:
+ existing_training = self._load_existing_training_record(
+ task_name, algorithm_name, seed
+ )
+ if existing_training is not None:
+ training_runs.append(existing_training)
+ continue
+
+ task_spec = load_task_spec(task_name)
+ run_config = self.build_run_config(task_name, algorithm_name, seed)
+ dump_json(run_config, run_dir / "run_config.json")
+ train_summary = train_with_config(run_config, run_dir)
+ training_record = {
+ "task": task_name,
+ "env_id": task_spec["env_id"],
+ "algorithm": algorithm_name,
+ "seed": seed,
+ "suite": self.suite,
+ "protocol": deepcopy(self.protocol),
+ "train_steps": int(train_summary["global_step"]),
+ "training_fps": train_summary["training_fps"],
+ "peak_gpu_memory_mb": train_summary["peak_gpu_memory_mb"],
+ "checkpoint_path": train_summary["checkpoint_path"],
+ "output_dir": train_summary["output_dir"],
+ "eval_history": train_summary.get("eval_history", []),
+ "train_history": train_summary.get("train_history", []),
+ }
+ dump_json(training_record, run_dir / "train_result.json")
+ training_runs.append(training_record)
+ return training_runs
+
+ def run_evaluation(
+ self, training_runs: list[dict[str, Any]]
+ ) -> list[dict[str, Any]]:
+ """Evaluate trained checkpoints and write final per-run benchmark results."""
+ results: list[dict[str, Any]] = []
+ for training_record in training_runs:
+ task_name = training_record["task"]
+ algorithm_name = training_record["algorithm"]
+ seed = training_record["seed"]
+ task_spec = load_task_spec(task_name)
+ run_dir = Path(training_record["output_dir"])
+ run_config = self.build_run_config(task_name, algorithm_name, seed)
+ dump_json(run_config, run_dir / "run_config.json")
+ eval_summary = evaluate_checkpoint(
+ cfg_json=run_config,
+ checkpoint_path=training_record["checkpoint_path"],
+ num_episodes=int(self.protocol["evaluation_episodes"]),
+ num_envs=int(self.protocol["num_eval_envs"]),
+ )
+ result = {
+ "task": task_name,
+ "env_id": task_spec["env_id"],
+ "algorithm": algorithm_name,
+ "seed": seed,
+ "suite": self.suite,
+ "protocol": deepcopy(self.protocol),
+ "train_steps": training_record["train_steps"],
+ "final_reward": eval_summary["avg_reward"],
+ "final_success_rate": eval_summary["success_rate"],
+ "final_episode_length": eval_summary["avg_episode_length"],
+ "training_fps": training_record["training_fps"],
+ "environment_fps": eval_summary["environment_fps"],
+ "peak_gpu_memory_mb": training_record["peak_gpu_memory_mb"],
+ "checkpoint_path": training_record["checkpoint_path"],
+ "output_dir": training_record["output_dir"],
+ "eval_history": training_record.get("eval_history", []),
+ "train_history": training_record.get("train_history", []),
+ }
+ threshold = task_spec.get("success_threshold", 0.8)
+ sustain_count = int(self.protocol.get("threshold_sustain_count", 3))
+ stable_eval_window = int(self.protocol.get("final_eval_window", 3))
+ result["final_success_rate_stable"] = compute_final_metric_stable(
+ training_record.get("eval_history", []),
+ metric_key="eval/success_rate",
+ window_size=stable_eval_window,
+ )
+ result["steps_to_success_threshold_first_hit"] = (
+ compute_steps_to_threshold_first_hit(
+ training_record.get("eval_history", []),
+ metric_key="eval/success_rate",
+ threshold=float(threshold),
+ )
+ )
+ result["steps_to_success_threshold"] = compute_steps_to_threshold_sustained(
+ training_record.get("eval_history", []),
+ metric_key="eval/success_rate",
+ threshold=float(threshold),
+ sustain_count=sustain_count,
+ )
+ result["final_metrics"] = eval_summary["metrics"]
+ dump_json(result, run_dir / "result.json")
+ results.append(result)
+ return results
+
+ def aggregate_results(
+ self, run_results: list[dict[str, Any]]
+ ) -> list[dict[str, Any]]:
+ """Aggregate multiple seeds into task-algorithm summaries."""
+ return aggregate_runs(run_results)
+
+ def update_leaderboard(
+ self,
+ aggregate_result: list[dict[str, Any]],
+ run_results: list[dict[str, Any]],
+ ) -> list[dict[str, Any]]:
+ """Build and persist leaderboard artifacts."""
+ leaderboard = build_leaderboard(aggregate_result, run_results=run_results)
+ leaderboard_dir = self.output_root / "leaderboard"
+ dump_json({"leaderboard": leaderboard}, leaderboard_dir / "leaderboard.json")
+ generate_leaderboard_markdown(
+ leaderboard=leaderboard,
+ output_path=leaderboard_dir / "leaderboard.md",
+ )
+ return leaderboard
+
+ def generate_report(
+ self,
+ run_results: list[dict[str, Any]],
+ aggregate_result: list[dict[str, Any]],
+ leaderboard: list[dict[str, Any]] | None = None,
+ ) -> Path:
+ """Create a markdown benchmark report and result json files."""
+ leaderboard = leaderboard or self.update_leaderboard(
+ aggregate_result, run_results
+ )
+ plot_artifacts = build_plot_artifacts(
+ run_results=run_results,
+ leaderboard=leaderboard,
+ output_dir=self.output_root / "plots",
+ )
+ dump_json({"runs": run_results}, self.output_root / "benchmark_runs.json")
+ dump_json(
+ {"aggregate": aggregate_result},
+ self.output_root / "benchmark_summary.json",
+ )
+ dump_json(
+ {
+ "suite": self.suite,
+ "tasks": self.tasks,
+ "algorithms": self.algorithms,
+ "seeds": self.seeds,
+ "protocol": self.protocol,
+ },
+ self.output_root / "benchmark_protocol.json",
+ )
+ return generate_markdown_report(
+ run_results=run_results,
+ aggregate_results=aggregate_result,
+ leaderboard=leaderboard,
+ plot_artifacts=plot_artifacts,
+ protocol=self.protocol,
+ output_path=self.output_root / "benchmark_report.md",
+ )
+
+
+__all__ = ["BenchmarkRunner"]
diff --git a/scripts/benchmark/rl/runtime.py b/scripts/benchmark/rl/runtime.py
new file mode 100644
index 00000000..666880f9
--- /dev/null
+++ b/scripts/benchmark/rl/runtime.py
@@ -0,0 +1,441 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import json
+import time
+from copy import deepcopy
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from tensordict import TensorDict
+from torch.utils.tensorboard import SummaryWriter
+
+from embodichain.agents.rl.algo import build_algo
+from embodichain.agents.rl.models import build_mlp_from_cfg, build_policy
+from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation
+from embodichain.agents.rl.utils.trainer import Trainer
+from embodichain.lab.gym.envs.managers.cfg import EventCfg
+from embodichain.lab.gym.envs.tasks.rl import build_env
+from embodichain.lab.gym.utils.gym_utils import DEFAULT_MANAGER_MODULES, config_to_cfg
+from embodichain.lab.sim import SimulationManagerCfg
+from embodichain.utils.module_utils import find_function_from_modules
+from embodichain.utils.utility import load_json
+
+EVENT_MODULES = [
+ "embodichain.lab.gym.envs.managers.randomization",
+ "embodichain.lab.gym.envs.managers.record",
+ "embodichain.lab.gym.envs.managers.events",
+]
+
+
+def resolve_device(device_str: str) -> torch.device:
+ """Resolve a runtime device string into a validated torch device."""
+ device = torch.device(device_str)
+ if device.type == "cuda":
+ if not torch.cuda.is_available():
+ raise ValueError("CUDA requested but no CUDA device is available.")
+ index = (
+ device.index if device.index is not None else torch.cuda.current_device()
+ )
+ if index < 0 or index >= torch.cuda.device_count():
+ raise ValueError(f"CUDA device index {index} is out of range.")
+ torch.cuda.set_device(index)
+ return torch.device(f"cuda:{index}")
+ if device.type != "cpu":
+ raise ValueError(f"Unsupported device type: {device.type}")
+ return device
+
+
+def set_random_seed(seed: int, device: torch.device) -> None:
+ """Set deterministic random seeds for numpy and torch."""
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ if device.type == "cuda":
+ torch.cuda.manual_seed_all(seed)
+ torch.cuda.reset_peak_memory_stats(device)
+
+
+def _parse_event_cfg(events_dict: dict[str, Any]) -> dict[str, EventCfg]:
+ parsed: dict[str, EventCfg] = {}
+ for event_name, event_info in events_dict.items():
+ event_func = find_function_from_modules(
+ event_info["func"], EVENT_MODULES, raise_if_not_found=True
+ )
+ parsed[event_name] = EventCfg(
+ func=event_func,
+ mode=event_info.get("mode", "interval"),
+ params=event_info.get("params", {}),
+ interval_step=event_info.get("interval_step", 1),
+ )
+ return parsed
+
+
+def _build_env_cfg(
+ gym_config_path: str,
+ num_envs: int | None,
+ headless: bool,
+ device: torch.device,
+ gpu_id: int,
+):
+ gym_config_data = load_json(gym_config_path)
+ gym_env_cfg = config_to_cfg(
+ gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES
+ )
+ if num_envs is not None:
+ gym_env_cfg.num_envs = int(num_envs)
+ if gym_env_cfg.sim_cfg is None:
+ gym_env_cfg.sim_cfg = SimulationManagerCfg()
+ gym_env_cfg.seed = getattr(gym_env_cfg, "seed", None)
+ gym_env_cfg.sim_cfg.headless = headless
+ gym_env_cfg.sim_cfg.gpu_id = gpu_id
+ gym_env_cfg.sim_cfg.sim_device = device
+ return gym_config_data, gym_env_cfg
+
+
+def _allocate_eval_rollout_buffer(env, policy, device: torch.device) -> TensorDict:
+ """Allocate a small RL-style rollout buffer for evaluation-only environments."""
+ rollout_len = 2
+ return TensorDict(
+ {
+ "obs": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ policy.obs_dim,
+ dtype=torch.float32,
+ device=device,
+ ),
+ "action": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ policy.action_dim,
+ dtype=torch.float32,
+ device=device,
+ ),
+ "sample_log_prob": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ dtype=torch.float32,
+ device=device,
+ ),
+ "value": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ dtype=torch.float32,
+ device=device,
+ ),
+ "reward": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ dtype=torch.float32,
+ device=device,
+ ),
+ "done": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ dtype=torch.bool,
+ device=device,
+ ),
+ "terminated": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ dtype=torch.bool,
+ device=device,
+ ),
+ "truncated": torch.zeros(
+ env.num_envs,
+ rollout_len + 1,
+ dtype=torch.bool,
+ device=device,
+ ),
+ },
+ batch_size=[env.num_envs, rollout_len + 1],
+ device=device,
+ )
+
+
+def _compact_eval_rollout_buffer(env, rollout_buffer: TensorDict) -> None:
+ """Keep only the previous transition needed by rollout-dependent eval rewards."""
+ if getattr(env, "current_rollout_step", 0) < 2:
+ return
+ for key in ("action", "reward", "done", "terminated", "truncated"):
+ rollout_buffer[key][:, 0].copy_(rollout_buffer[key][:, 1])
+ rollout_buffer[key][:, 1:].zero_()
+ env.current_rollout_step = 1
+
+
+def build_policy_from_env(policy_block: dict[str, Any], env, device: torch.device):
+ """Build a policy using the current environment spaces."""
+ sample_obs, _ = env.reset()
+ sample_obs_td = dict_to_tensordict(sample_obs, device)
+ obs_dim = flatten_dict_observation(sample_obs_td).shape[-1]
+ flat_obs_space = env.flattened_observation_space
+ env_action_dim = env.action_space.shape[-1]
+
+ policy_name = policy_block["name"].lower()
+ if policy_name == "actor_critic":
+ actor = build_mlp_from_cfg(policy_block["actor"], obs_dim, env_action_dim)
+ critic = build_mlp_from_cfg(policy_block["critic"], obs_dim, 1)
+ return build_policy(
+ policy_block,
+ flat_obs_space,
+ env.action_space,
+ device,
+ actor=actor,
+ critic=critic,
+ )
+ if policy_name == "actor_only":
+ actor = build_mlp_from_cfg(policy_block["actor"], obs_dim, env_action_dim)
+ return build_policy(
+ policy_block,
+ flat_obs_space,
+ env.action_space,
+ device,
+ actor=actor,
+ )
+ return build_policy(policy_block, flat_obs_space, env.action_space, device)
+
+
+def train_with_config(
+ cfg_json: dict[str, Any],
+ output_dir: str | Path,
+) -> dict[str, Any]:
+ """Train an RL configuration and return a structured summary."""
+ trainer_cfg = deepcopy(cfg_json["trainer"])
+ policy_block = deepcopy(cfg_json["policy"])
+ algo_block = deepcopy(cfg_json["algorithm"])
+
+ device = resolve_device(trainer_cfg.get("device", "cpu"))
+ seed = int(trainer_cfg.get("seed", 1))
+ set_random_seed(seed, device)
+
+ output_root = Path(output_dir)
+ log_dir = output_root / "logs"
+ checkpoint_dir = output_root / "checkpoints"
+ log_dir.mkdir(parents=True, exist_ok=True)
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+ gym_config_data, gym_env_cfg = _build_env_cfg(
+ gym_config_path=trainer_cfg["gym_config"],
+ num_envs=trainer_cfg.get("num_envs"),
+ headless=bool(trainer_cfg.get("headless", True)),
+ device=device,
+ gpu_id=int(trainer_cfg.get("gpu_id", 0)),
+ )
+ env = None
+ eval_env = None
+ writer = SummaryWriter(str(log_dir))
+ try:
+ env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
+
+ enable_eval = bool(trainer_cfg.get("enable_eval", True))
+ if enable_eval:
+ eval_gym_env_cfg = deepcopy(gym_env_cfg)
+ eval_gym_env_cfg.num_envs = int(
+ trainer_cfg.get("num_eval_envs", min(4, gym_env_cfg.num_envs))
+ )
+ eval_gym_env_cfg.sim_cfg.headless = True
+ eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg)
+
+ policy = build_policy_from_env(policy_block, env, device)
+ algo = build_algo(algo_block["name"], algo_block["cfg"], policy, device)
+
+ events_dict = trainer_cfg.get("events", {})
+ trainer = Trainer(
+ policy=policy,
+ env=env,
+ algorithm=algo,
+ buffer_size=int(trainer_cfg.get("buffer_size", 2048)),
+ batch_size=int(algo_block["cfg"]["batch_size"]),
+ writer=writer,
+ eval_freq=int(trainer_cfg.get("eval_freq", 0)) if enable_eval else 0,
+ save_freq=int(trainer_cfg.get("save_freq", 0)) or 10**18,
+ checkpoint_dir=str(checkpoint_dir),
+ exp_name=str(trainer_cfg.get("exp_name", "benchmark_run")),
+ use_wandb=False,
+ eval_env=eval_env,
+ event_cfg=_parse_event_cfg(events_dict.get("train", {})),
+ eval_event_cfg=(
+ _parse_event_cfg(events_dict.get("eval", {})) if enable_eval else {}
+ ),
+ num_eval_episodes=int(trainer_cfg.get("num_eval_episodes", 5)),
+ )
+
+ total_steps = (
+ int(trainer_cfg.get("iterations", 1))
+ * int(trainer_cfg.get("buffer_size", 2048))
+ * int(env.num_envs)
+ )
+ start_time = time.perf_counter()
+ summary = trainer.train(total_steps)
+ wall_time = time.perf_counter() - start_time
+ checkpoint_path = trainer.save_checkpoint()
+ finally:
+ writer.close()
+ if eval_env is not None:
+ eval_env.close()
+ if env is not None:
+ env.close()
+
+ peak_gpu_memory_mb = 0.0
+ if device.type == "cuda":
+ peak_gpu_memory_mb = torch.cuda.max_memory_allocated(device=device) / (
+ 1024.0 * 1024.0
+ )
+
+ summary.update(
+ {
+ "checkpoint_path": checkpoint_path,
+ "output_dir": str(output_root),
+ "wall_time_sec": float(wall_time),
+ "training_fps": float(total_steps / max(wall_time, 1e-6)),
+ "peak_gpu_memory_mb": float(peak_gpu_memory_mb),
+ }
+ )
+ return summary
+
+
+def evaluate_checkpoint(
+ cfg_json: dict[str, Any],
+ checkpoint_path: str | Path,
+ num_episodes: int,
+ num_envs: int | None = None,
+) -> dict[str, Any]:
+ """Evaluate a checkpoint deterministically and collect task metrics."""
+ trainer_cfg = deepcopy(cfg_json["trainer"])
+ policy_block = deepcopy(cfg_json["policy"])
+
+ device = resolve_device(trainer_cfg.get("device", "cpu"))
+ gym_config_data, gym_env_cfg = _build_env_cfg(
+ gym_config_path=trainer_cfg["gym_config"],
+ num_envs=num_envs if num_envs is not None else trainer_cfg.get("num_eval_envs"),
+ headless=True,
+ device=device,
+ gpu_id=int(trainer_cfg.get("gpu_id", 0)),
+ )
+ env = None
+ try:
+ env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg)
+ policy = build_policy_from_env(policy_block, env, device)
+ eval_rollout_buffer = None
+ if hasattr(env, "set_rollout_buffer"):
+ eval_rollout_buffer = _allocate_eval_rollout_buffer(env, policy, device)
+
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ policy.load_state_dict(checkpoint["policy"])
+ policy.eval()
+
+ target_episodes = int(num_episodes)
+ completed = 0
+ cumulative_reward = torch.zeros(
+ env.num_envs, dtype=torch.float32, device=device
+ )
+ step_count = torch.zeros(env.num_envs, dtype=torch.int32, device=device)
+
+ returns: list[float] = []
+ lengths: list[int] = []
+ successes: list[float] = []
+ metric_values: dict[str, list[float]] = {}
+ env_step_count = 0
+ env_step_time = 0.0
+
+ if eval_rollout_buffer is not None:
+ env.set_rollout_buffer(eval_rollout_buffer)
+ obs, _ = env.reset()
+ while completed < target_episodes:
+ flat_obs = flatten_dict_observation(obs)
+ action_td = TensorDict(
+ {"obs": flat_obs},
+ batch_size=[env.num_envs],
+ device=device,
+ )
+ action_td = policy.get_action(action_td, deterministic=True)
+ action_manager = getattr(env, "action_manager", None)
+ if action_manager is None:
+ action_in = action_td["action"]
+ else:
+ action_in = action_manager.convert_policy_action_to_env_action(
+ action_td["action"]
+ )
+
+ if eval_rollout_buffer is not None:
+ _compact_eval_rollout_buffer(env, eval_rollout_buffer)
+ eval_rollout_buffer["action"][:, env.current_rollout_step].copy_(
+ action_td["action"]
+ )
+ step_start = time.perf_counter()
+ obs, reward, terminated, truncated, info = env.step(action_in)
+ env_step_time += time.perf_counter() - step_start
+ env_step_count += env.num_envs
+
+ done = terminated | truncated
+ cumulative_reward += reward.float()
+ step_count += 1
+
+ newly_done = done.nonzero(as_tuple=False).squeeze(-1)
+ for env_id in newly_done.tolist():
+ if completed >= target_episodes:
+ break
+ returns.append(float(cumulative_reward[env_id].item()))
+ lengths.append(int(step_count[env_id].item()))
+ if "success" in info:
+ successes.append(float(info["success"][env_id].item()))
+ if "metrics" in info:
+ for key, value in info["metrics"].items():
+ metric_values.setdefault(key, []).append(
+ float(value[env_id].item())
+ )
+ cumulative_reward[env_id] = 0.0
+ step_count[env_id] = 0
+ completed += 1
+ finally:
+ if env is not None:
+ env.close()
+
+ return {
+ "num_episodes": completed,
+ "avg_reward": float(np.mean(returns)) if returns else float("nan"),
+ "avg_episode_length": float(np.mean(lengths)) if lengths else float("nan"),
+ "success_rate": float(np.mean(successes)) if successes else float("nan"),
+ "environment_fps": float(env_step_count / max(env_step_time, 1e-6)),
+ "metrics": {
+ key: float(np.mean(values))
+ for key, values in metric_values.items()
+ if values
+ },
+ }
+
+
+def dump_json(data: dict[str, Any], path: str | Path) -> Path:
+ """Write a JSON artifact to disk."""
+ output = Path(path)
+ output.parent.mkdir(parents=True, exist_ok=True)
+ output.write_text(json.dumps(data, indent=2), encoding="utf-8")
+ return output
+
+
+__all__ = [
+ "build_policy_from_env",
+ "dump_json",
+ "evaluate_checkpoint",
+ "resolve_device",
+ "set_random_seed",
+ "train_with_config",
+]
diff --git a/scripts/benchmark/rl/suites/__init__.py b/scripts/benchmark/rl/suites/__init__.py
new file mode 100644
index 00000000..dd650e90
--- /dev/null
+++ b/scripts/benchmark/rl/suites/__init__.py
@@ -0,0 +1,15 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
diff --git a/scripts/benchmark/rl/suites/default.yaml b/scripts/benchmark/rl/suites/default.yaml
new file mode 100644
index 00000000..34476006
--- /dev/null
+++ b/scripts/benchmark/rl/suites/default.yaml
@@ -0,0 +1,21 @@
+tasks:
+ - cart_pole
+ - push_cube
+algorithms:
+ - ppo
+ - grpo
+seeds:
+ - 0
+ - 1
+protocol:
+ device: cuda:0
+ headless: true
+ iterations: 200
+ buffer_size: 1024
+ num_envs: 64
+ num_eval_envs: 16
+ evaluation_interval: 200
+ evaluation_episodes: 20
+ threshold_sustain_count: 3
+ final_eval_window: 3
+ save_interval: 200
diff --git a/scripts/benchmark/rl/suites/smoke.yaml b/scripts/benchmark/rl/suites/smoke.yaml
new file mode 100644
index 00000000..4bb1e67f
--- /dev/null
+++ b/scripts/benchmark/rl/suites/smoke.yaml
@@ -0,0 +1,20 @@
+tasks:
+ - cart_pole
+ - push_cube
+algorithms:
+ - ppo
+ - grpo
+seeds:
+ - 0
+protocol:
+ device: cpu
+ headless: true
+ iterations: 10
+ buffer_size: 128
+ num_envs: 32
+ num_eval_envs: 8
+ evaluation_interval: 2
+ evaluation_episodes: 10
+ threshold_sustain_count: 3
+ final_eval_window: 3
+ save_interval: 1000
diff --git a/scripts/benchmark/rl/tasks/__init__.py b/scripts/benchmark/rl/tasks/__init__.py
new file mode 100644
index 00000000..dd650e90
--- /dev/null
+++ b/scripts/benchmark/rl/tasks/__init__.py
@@ -0,0 +1,15 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
diff --git a/scripts/benchmark/rl/tasks/cart_pole.yaml b/scripts/benchmark/rl/tasks/cart_pole.yaml
new file mode 100644
index 00000000..8b90a61f
--- /dev/null
+++ b/scripts/benchmark/rl/tasks/cart_pole.yaml
@@ -0,0 +1,20 @@
+name: cart_pole
+env_id: CartPoleRL
+success_threshold: 0.8
+base_config:
+ trainer:
+ gym_config: configs/agents/rl/basic/cart_pole/gym_config.json
+ exp_name: cart_pole
+ device: cpu
+ headless: true
+ gpu_id: 0
+ num_envs: 64
+ iterations: 200
+ buffer_size: 1024
+ enable_eval: true
+ num_eval_envs: 8
+ num_eval_episodes: 10
+ eval_freq: 200
+ save_freq: 200
+ use_wandb: false
+ events: {}
diff --git a/scripts/benchmark/rl/tasks/push_cube.yaml b/scripts/benchmark/rl/tasks/push_cube.yaml
new file mode 100644
index 00000000..3f524685
--- /dev/null
+++ b/scripts/benchmark/rl/tasks/push_cube.yaml
@@ -0,0 +1,21 @@
+name: push_cube
+env_id: PushCubeRL
+success_threshold: 0.6
+train_eval_enabled: false
+base_config:
+ trainer:
+ gym_config: configs/agents/rl/push_cube/gym_config.json
+ exp_name: push_cube
+ device: cpu
+ headless: true
+ gpu_id: 0
+ num_envs: 64
+ iterations: 200
+ buffer_size: 1024
+ enable_eval: true
+ num_eval_envs: 8
+ num_eval_episodes: 10
+ eval_freq: 200
+ save_freq: 200
+ use_wandb: false
+ events: {}
diff --git a/scripts/benchmark/robotics/kinematic_solver/run_benchmark.py b/scripts/benchmark/robotics/kinematic_solver/run_benchmark.py
new file mode 100644
index 00000000..5f4451ae
--- /dev/null
+++ b/scripts/benchmark/robotics/kinematic_solver/run_benchmark.py
@@ -0,0 +1,722 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Unified benchmark for OPW and Pytorch kinematic solvers.
+
+Measures IK wall-clock latency, pose accuracy, success rate, and memory usage
+across OPW (Warp CUDA vs CPU) and Pytorch solver (CPU vs optional CUDA).
+Run: python -m scripts.benchmark.robotics.kinematic_solver.run_benchmark
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import time
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import psutil
+import torch
+
+from embodichain.data import get_data_path
+from embodichain.lab.sim.solvers.opw_solver import OPWSolverCfg
+from embodichain.lab.sim.solvers.pytorch_solver import PytorchSolver, PytorchSolverCfg
+
+OPW_LOWER_LIMITS = [-2.618, 0.0, -2.967, -1.745, -1.22, -2.0944]
+OPW_UPPER_LIMITS = [2.618, 3.14159, 0.0, 1.745, 1.22, 2.0944]
+
+# TODO: Easy to failed if use full joint range, consider adding a margin to avoid sampling near the joint limits.
+# PYTORCH_LOWER_LIMITS = [-6.2832, -6.2832, -3.1416, -6.2832, -6.2832, -6.2832]
+# PYTORCH_UPPER_LIMITS = [6.2832, 6.2832, 3.1416, 6.2832, 6.2832, 6.2832]
+PYTORCH_LOWER_LIMITS = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+PYTORCH_UPPER_LIMITS = [2.5, 2.5, 2.5, 2.5, 2.5, 2.5]
+
+SAMPLE_SIZES = [100, 1000, 10000]
+SUPPORTED_SOLVERS = ("opw", "pytorch")
+
+
+def _parse_args() -> argparse.Namespace:
+ """Parse command line arguments for selecting benchmark solvers."""
+ parser = argparse.ArgumentParser(
+ description="Run kinematic solver benchmarks for selected solver backends."
+ )
+ parser.add_argument(
+ "--solvers",
+ "-s",
+ nargs="+",
+ choices=(*SUPPORTED_SOLVERS, "all"),
+ default=["all"],
+ help=(
+ "Solvers to benchmark. Use one or more of: opw, pytorch, all. "
+ "Default: all"
+ ),
+ )
+ return parser.parse_args()
+
+
+def _normalize_selected_solvers(selected_solvers: list[str] | None) -> set[str]:
+ """Normalize selected solver names to a canonical set."""
+ if not selected_solvers or "all" in selected_solvers:
+ return set(SUPPORTED_SOLVERS)
+ return {solver for solver in selected_solvers if solver in SUPPORTED_SOLVERS}
+
+
+def _sync_cuda() -> None:
+ """Synchronize CUDA stream when available."""
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+
+def _reset_peak_gpu_memory() -> None:
+ """Reset PyTorch peak GPU memory stats when CUDA is available."""
+ if torch.cuda.is_available():
+ torch.cuda.reset_peak_memory_stats()
+
+
+def _peak_gpu_memory_mb() -> float:
+ """Return peak GPU memory allocated by PyTorch in MB."""
+ if not torch.cuda.is_available():
+ return 0.0
+ return torch.cuda.max_memory_allocated() / 1024**2
+
+
+def _memory_snapshot() -> dict[str, float]:
+ """Return current process memory usage snapshot in MB."""
+ process = psutil.Process(os.getpid())
+ cpu_mb = process.memory_info().rss / 1024**2
+ gpu_mb = (
+ torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0.0
+ )
+ return {"cpu_mb": cpu_mb, "gpu_mb": gpu_mb}
+
+
+def _format_markdown_table(rows: list[dict[str, object]]) -> list[str]:
+ """Format rows into a markdown table."""
+ if not rows:
+ return ["No data."]
+
+ headers = list(rows[0].keys())
+ lines = [
+ "| " + " | ".join(headers) + " |",
+ "| " + " | ".join(["---"] * len(headers)) + " |",
+ ]
+ for row in rows:
+ lines.append("| " + " | ".join(str(row[h]) for h in headers) + " |")
+ return lines
+
+
+def _build_leaderboard_rows(
+ metric_rows: list[dict[str, object]],
+) -> list[dict[str, object]]:
+ """Aggregate and rank algorithms by overall success rate."""
+ aggregate: dict[str, dict[str, float]] = {}
+ for row in metric_rows:
+ impl = str(row["impl"])
+ if impl not in aggregate:
+ aggregate[impl] = {
+ "success_sum": 0.0,
+ "t_err_sum": 0.0,
+ "r_err_sum": 0.0,
+ "count": 0.0,
+ }
+
+ aggregate[impl]["success_sum"] += float(row["success_rate"])
+ aggregate[impl]["t_err_sum"] += float(row["translation_err_mm"])
+ aggregate[impl]["r_err_sum"] += float(row["rotation_err_deg"])
+ aggregate[impl]["count"] += 1.0
+
+ ranked = sorted(
+ aggregate.items(),
+ key=lambda item: item[1]["success_sum"] / max(item[1]["count"], 1.0),
+ reverse=True,
+ )
+
+ leaderboard_rows: list[dict[str, object]] = []
+ for rank, (algorithm, stats) in enumerate(ranked, start=1):
+ count = max(stats["count"], 1.0)
+ leaderboard_rows.append(
+ {
+ "rank": rank,
+ "algorithm": algorithm,
+ "overall_success_rate": f"{stats['success_sum'] / count:.2%}",
+ "avg_translation_err_mm": f"{stats['t_err_sum'] / count:.6f}",
+ "avg_rotation_err_deg": f"{stats['r_err_sum'] / count:.6f}",
+ }
+ )
+ return leaderboard_rows
+
+
+def _write_markdown_report(
+ benchmark_name: str,
+ perf_rows: list[dict[str, object]],
+ metric_rows: list[dict[str, object]],
+ leaderboard_rows: list[dict[str, object]],
+ notes: list[str] | None = None,
+) -> Path:
+ """Write benchmark results to a markdown report with three tables."""
+ output_dir = Path("outputs/benchmarks")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ report_path = output_dir / f"{benchmark_name}_{timestamp}.md"
+
+ lines: list[str] = [
+ f"# {benchmark_name} Benchmark Report",
+ "",
+ f"Generated at: {datetime.now().isoformat(timespec='seconds')}",
+ "",
+ "## Time & Memory",
+ "",
+ ]
+ lines.extend(_format_markdown_table(perf_rows))
+ lines.extend(["", "## Success & Other Metrics", ""])
+ lines.extend(_format_markdown_table(metric_rows))
+
+ lines.extend(["", "## Leaderboard", ""])
+ lines.extend(_format_markdown_table(leaderboard_rows))
+
+ if notes:
+ lines.extend(["", "## Notes", ""])
+ lines.extend([f"- {note}" for note in notes])
+
+ report_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+ return report_path
+
+
+def get_pose_err(
+ matrix_a: np.ndarray | torch.Tensor,
+ matrix_b: np.ndarray | torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Return translation and rotation errors between paired poses.
+
+ Supports either a single 4x4 pose or a batch with shape (N, 4, 4).
+ """
+ tensor_a = torch.as_tensor(matrix_a, dtype=torch.float64)
+ tensor_b = torch.as_tensor(matrix_b, dtype=torch.float64, device=tensor_a.device)
+
+ if tensor_a.ndim == 2:
+ tensor_a = tensor_a.unsqueeze(0)
+ if tensor_b.ndim == 2:
+ tensor_b = tensor_b.unsqueeze(0)
+
+ t_err = torch.linalg.norm(tensor_a[:, :3, 3] - tensor_b[:, :3, 3], dim=-1)
+
+ relative_rot = torch.matmul(
+ tensor_a[:, :3, :3].transpose(-1, -2),
+ tensor_b[:, :3, :3],
+ )
+ trace = torch.diagonal(relative_rot, dim1=-2, dim2=-1).sum(dim=-1)
+ cos_angle = torch.clamp((trace - 1.0) / 2.0, min=-1.0, max=1.0)
+ r_err = torch.arccos(cos_angle)
+ return t_err, r_err
+
+
+def _timed_ik_call(
+ solver, xpos: torch.Tensor, qpos_seed: torch.Tensor, initial_guess: torch.Tensor
+) -> tuple[float, dict[str, float], float, torch.Tensor, torch.Tensor]:
+ """Run a timed IK call and return elapsed seconds, memory deltas, and outputs."""
+ _reset_peak_gpu_memory()
+ mem_before = _memory_snapshot()
+ _sync_cuda()
+
+ start = time.perf_counter()
+ ik_success, ik_qpos = solver.get_ik(
+ xpos,
+ qpos_seed=qpos_seed,
+ initial_guess=initial_guess,
+ )
+ _sync_cuda()
+ elapsed = time.perf_counter() - start
+
+ mem_after = _memory_snapshot()
+ deltas = {
+ "cpu_mb": mem_after["cpu_mb"] - mem_before["cpu_mb"],
+ "gpu_mb": mem_after["gpu_mb"] - mem_before["gpu_mb"],
+ }
+ return elapsed, deltas, _peak_gpu_memory_mb(), ik_success, ik_qpos
+
+
+def _init_pytorch_solver(device: torch.device) -> PytorchSolver:
+ """Initialize Pytorch kinematic solver on the target device."""
+ solver_cfg = PytorchSolverCfg(
+ urdf_path=get_data_path("UniversalRobots/UR10/UR10.urdf"),
+ end_link_name="ee_link",
+ root_link_name="base_link",
+ joint_names=["J1", "J2", "J3", "J4", "J5", "J6"],
+ user_qpos_limits=[PYTORCH_LOWER_LIMITS, PYTORCH_UPPER_LIMITS],
+ )
+ return PytorchSolver(solver_cfg, device=device)
+
+
+def _sample_qpos(
+ n_samples: int,
+ lower_limits: list[float],
+ upper_limits: list[float],
+ margin: float,
+ device: torch.device,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ """Sample joint positions with margin from lower/upper limits."""
+ qpos_np = np.random.uniform(
+ low=np.array(lower_limits) + margin,
+ high=np.array(upper_limits) - margin,
+ size=(n_samples, 6),
+ ).astype(float)
+ return torch.tensor(qpos_np, device=device, dtype=dtype)
+
+
+def _timed_pytorch_ik_call(
+ solver: PytorchSolver,
+ fk_xpos: torch.Tensor,
+ qpos_seed: torch.Tensor,
+) -> tuple[float, dict[str, float], float, torch.Tensor, torch.Tensor]:
+ """Run a timed Pytorch IK call and return elapsed/memory/outputs."""
+ _reset_peak_gpu_memory()
+ mem_before = _memory_snapshot()
+ _sync_cuda()
+
+ start = time.perf_counter()
+ for i in range(3):
+ if i == 1: # skip first run to avoid initialization overhead
+ start = time.perf_counter()
+ ik_success, ik_qpos = solver.get_ik(
+ fk_xpos,
+ joint_seed=qpos_seed,
+ return_all_solutions=False,
+ )
+ _sync_cuda()
+ elapsed = time.perf_counter() - start
+ elapsed /= 2.0
+
+ mem_after = _memory_snapshot()
+ deltas = {
+ "cpu_mb": mem_after["cpu_mb"] - mem_before["cpu_mb"],
+ "gpu_mb": mem_after["gpu_mb"] - mem_before["gpu_mb"],
+ }
+ return elapsed, deltas, _peak_gpu_memory_mb(), ik_success, ik_qpos[:, 0, :]
+
+
+def check_opw_solver(
+ solver_warp, solver_py_opw, n_samples: int = 1000
+) -> dict[str, float]:
+ """Run Warp and CPU OPW IK/FK checks and return timing, memory, and accuracy."""
+ dof = 6
+ qpos_np = np.random.uniform(
+ low=np.array(OPW_LOWER_LIMITS)
+ + 5.1 / 180.0 * np.pi, # add a margin to avoid sampling near the joint limits
+ high=np.array(OPW_UPPER_LIMITS) + -5.1 / 180.0 * np.pi,
+ size=(n_samples, dof),
+ ).astype(float)
+
+ qpos_cuda = torch.tensor(qpos_np, device=torch.device("cuda"), dtype=torch.float32)
+ xpos_cuda = solver_warp.get_fk(qpos_cuda)
+ qpos_seed = torch.tensor(
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ device=torch.device("cuda"),
+ dtype=torch.float32,
+ )
+
+ (
+ warp_elapsed,
+ warp_mem,
+ warp_peak_gpu,
+ warp_ik_success,
+ warp_ik_qpos,
+ ) = _timed_ik_call(
+ solver=solver_warp,
+ xpos=xpos_cuda,
+ qpos_seed=qpos_seed,
+ initial_guess=qpos_cuda,
+ )
+
+ check_xpos = solver_warp.get_fk(warp_ik_qpos)
+ warp_t_err, warp_r_err = get_pose_err(xpos_cuda, check_xpos)
+ warp_t_mean_err, warp_r_mean_err = (
+ warp_t_err.mean().item(),
+ warp_r_err.mean().item(),
+ )
+
+ xpos_cpu = xpos_cuda.to(torch.device("cpu"))
+ qpos_seed_cpu = qpos_seed.to(torch.device("cpu"))
+ qpos_cpu = qpos_cuda.to(torch.device("cpu"))
+
+ (
+ cpu_elapsed,
+ cpu_mem,
+ cpu_peak_gpu,
+ py_opw_ik_success,
+ py_opw_ik_qpos,
+ ) = _timed_ik_call(
+ solver=solver_py_opw,
+ xpos=xpos_cpu,
+ qpos_seed=qpos_seed_cpu,
+ initial_guess=qpos_cpu,
+ )
+
+ check_xpos = solver_warp.get_fk(py_opw_ik_qpos.to(torch.device("cuda")))
+ py_opw_t_err, py_opw_r_err = get_pose_err(xpos_cpu, check_xpos)
+ py_opw_t_mean_err, py_opw_r_mean_err = (
+ py_opw_t_err.mean().item(),
+ py_opw_r_err.mean().item(),
+ )
+
+ warp_success_rate = float(warp_ik_success.float().mean().item())
+ cpu_success_rate = float(py_opw_ik_success.float().mean().item())
+
+ return {
+ "warp_ms": warp_elapsed * 1000.0,
+ "warp_t_err_mm": warp_t_mean_err * 1000.0,
+ "warp_r_err_deg": warp_r_mean_err * 180.0 / np.pi,
+ "warp_success_rate": warp_success_rate,
+ "warp_cpu_delta_mb": warp_mem["cpu_mb"],
+ "warp_gpu_delta_mb": warp_mem["gpu_mb"],
+ "warp_peak_gpu_mb": warp_peak_gpu,
+ "cpu_ms": cpu_elapsed * 1000.0,
+ "cpu_t_err_mm": py_opw_t_mean_err * 1000.0,
+ "cpu_r_err_deg": py_opw_r_mean_err * 180.0 / np.pi,
+ "cpu_success_rate": cpu_success_rate,
+ "cpu_cpu_delta_mb": cpu_mem["cpu_mb"],
+ "cpu_gpu_delta_mb": cpu_mem["gpu_mb"],
+ "cpu_peak_gpu_mb": cpu_peak_gpu,
+ }
+
+
+def benchmark_pytorch_solver() -> (
+ tuple[list[dict[str, object]], list[dict[str, object]]]
+):
+ """Benchmark Pytorch solver for CPU and optional CUDA implementations."""
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ cpu_solver = _init_pytorch_solver(device=torch.device("cpu"))
+ has_cuda = torch.cuda.is_available()
+ cuda_solver = (
+ _init_pytorch_solver(device=torch.device("cuda")) if has_cuda else None
+ )
+
+ print("\n=== Pytorch Kinematic Benchmark ===")
+ if not has_cuda:
+ print(" CUDA unavailable; CUDA benchmark is skipped.")
+
+ for n_sample in SAMPLE_SIZES:
+ print(f"**** Test over {n_sample} samples:")
+
+ qpos_cpu = _sample_qpos(
+ n_samples=n_sample,
+ lower_limits=PYTORCH_LOWER_LIMITS,
+ upper_limits=PYTORCH_UPPER_LIMITS,
+ margin=1e-1,
+ device=torch.device("cpu"),
+ dtype=torch.float64,
+ )
+ fk_xpos_cpu = cpu_solver.get_fk(qpos_cpu)
+ (
+ cpu_elapsed,
+ cpu_mem,
+ cpu_peak_gpu,
+ cpu_success,
+ cpu_ik_qpos,
+ ) = _timed_pytorch_ik_call(cpu_solver, fk_xpos_cpu, qpos_cpu)
+ check_xpos_cpu = cpu_solver.get_fk(cpu_ik_qpos)
+ cpu_t_err, cpu_r_err = get_pose_err(fk_xpos_cpu, check_xpos_cpu)
+
+ cpu_result = {
+ "cost_time_ms": cpu_elapsed * 1000.0,
+ "cpu_delta_mb": cpu_mem["cpu_mb"],
+ "gpu_delta_mb": cpu_mem["gpu_mb"],
+ "peak_gpu_mb": cpu_peak_gpu,
+ "success_rate": float(cpu_success.float().mean().item()),
+ "translation_err_mm": cpu_t_err.mean().item() * 1000.0,
+ "rotation_err_deg": cpu_r_err.mean().item() * 180.0 / np.pi,
+ }
+
+ perf_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "pytorch_cpu",
+ "component": "pytorch_ik",
+ "cost_time_ms": f"{cpu_result['cost_time_ms']:.6f}",
+ "cpu_delta_mb": f"{cpu_result['cpu_delta_mb']:.6f}",
+ "gpu_delta_mb": f"{cpu_result['gpu_delta_mb']:.6f}",
+ "peak_gpu_mb": f"{cpu_result['peak_gpu_mb']:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "pytorch_cpu",
+ "component": "pytorch_ik",
+ "success_rate": f"{cpu_result['success_rate']:.6f}",
+ "translation_err_mm": f"{cpu_result['translation_err_mm']:.6f}",
+ "rotation_err_deg": f"{cpu_result['rotation_err_deg']:.6f}",
+ }
+ )
+
+ print(f"===Pytorch CPU IK time: {cpu_result['cost_time_ms']:.6f} ms")
+ print(f" Translation mean error: {cpu_result['translation_err_mm']:.6f} mm")
+ print(
+ f" Rotation mean error: {cpu_result['rotation_err_deg']:.6f} degrees"
+ )
+ print(f" Success rate: {cpu_result['success_rate'] * 100.0:.2f}%")
+ print(
+ " "
+ f"CPU Δ={cpu_result['cpu_delta_mb']:+.1f} MB "
+ f"GPU Δ={cpu_result['gpu_delta_mb']:+.1f} MB "
+ f"peak GPU={cpu_result['peak_gpu_mb']:.1f} MB"
+ )
+
+ if has_cuda and cuda_solver is not None:
+ qpos_cuda = qpos_cpu.to(torch.device("cuda"))
+ fk_xpos_cuda = cuda_solver.get_fk(qpos_cuda)
+ (
+ cuda_elapsed,
+ cuda_mem,
+ cuda_peak_gpu,
+ cuda_success,
+ cuda_ik_qpos,
+ ) = _timed_pytorch_ik_call(cuda_solver, fk_xpos_cuda, qpos_cuda)
+ check_xpos_cuda = cuda_solver.get_fk(cuda_ik_qpos)
+ cuda_t_err, cuda_r_err = get_pose_err(fk_xpos_cuda, check_xpos_cuda)
+
+ cuda_result = {
+ "cost_time_ms": cuda_elapsed * 1000.0,
+ "cpu_delta_mb": cuda_mem["cpu_mb"],
+ "gpu_delta_mb": cuda_mem["gpu_mb"],
+ "peak_gpu_mb": cuda_peak_gpu,
+ "success_rate": float(cuda_success.float().mean().item()),
+ "translation_err_mm": cuda_t_err.mean().item() * 1000.0,
+ "rotation_err_deg": cuda_r_err.mean().item() * 180.0 / np.pi,
+ }
+
+ perf_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "pytorch_cuda",
+ "component": "pytorch_ik",
+ "cost_time_ms": f"{cuda_result['cost_time_ms']:.6f}",
+ "cpu_delta_mb": f"{cuda_result['cpu_delta_mb']:.6f}",
+ "gpu_delta_mb": f"{cuda_result['gpu_delta_mb']:.6f}",
+ "peak_gpu_mb": f"{cuda_result['peak_gpu_mb']:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "pytorch_cuda",
+ "component": "pytorch_ik",
+ "success_rate": f"{cuda_result['success_rate']:.6f}",
+ "translation_err_mm": f"{cuda_result['translation_err_mm']:.6f}",
+ "rotation_err_deg": f"{cuda_result['rotation_err_deg']:.6f}",
+ }
+ )
+
+ print(f"===Pytorch CUDA IK time: {cuda_result['cost_time_ms']:.6f} ms")
+ print(
+ f" Translation mean error: {cuda_result['translation_err_mm']:.6f} mm"
+ )
+ print(
+ f" Rotation mean error: {cuda_result['rotation_err_deg']:.6f} degrees"
+ )
+ print(
+ f" Success rate: {cuda_result['success_rate'] * 100.0:.2f}%"
+ )
+ print(
+ " "
+ f"CPU Δ={cuda_result['cpu_delta_mb']:+.1f} MB "
+ f"GPU Δ={cuda_result['gpu_delta_mb']:+.1f} MB "
+ f"peak GPU={cuda_result['peak_gpu_mb']:.1f} MB"
+ )
+
+ return perf_rows, metric_rows
+
+
+def benchmark_opw_solver() -> tuple[list[dict[str, object]], list[dict[str, object]]]:
+ """Benchmark OPW solver for multiple sample sizes."""
+ if not torch.cuda.is_available():
+ print("\n=== OPW Solver Benchmark ===")
+ print(" Skipped -- requires CUDA for Warp implementation comparison.")
+ return [], [
+ {
+ "sample_size": "N/A",
+ "impl": "opw_solver",
+ "component": "opw_ik",
+ "success_rate": "N/A",
+ "other_metrics": "skipped: requires CUDA for Warp comparison",
+ }
+ ]
+
+ cfg = OPWSolverCfg(
+ joint_names=("J1", "J2", "J3", "J4", "J5", "J6"),
+ user_qpos_limits=(OPW_LOWER_LIMITS, OPW_UPPER_LIMITS),
+ )
+ cfg.a1 = 400.333
+ cfg.a2 = -251.449
+ cfg.b = 0.0
+ cfg.c1 = 830
+ cfg.c2 = 1177.556
+ cfg.c3 = 1443.593
+ cfg.c4 = 230
+ cfg.offsets = (
+ 0.0,
+ 82.21350356417211 * np.pi / 180.0,
+ -167.21710113148163 * np.pi / 180.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ )
+ cfg.flip_axes = (True, False, True, True, False, True)
+ cfg.has_parallelogram = False
+
+ solver_warp = cfg.init_solver(device=torch.device("cuda"), pk_serial_chain="")
+ solver_py_opw = cfg.init_solver(device=torch.device("cpu"), pk_serial_chain="")
+
+ print("\n=== OPW Solver Benchmark ===")
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ for n_sample in SAMPLE_SIZES:
+ result = check_opw_solver(solver_warp, solver_py_opw, n_samples=n_sample)
+ print(f"**** Test over {n_sample} samples:")
+ print(f"===Warp CUDA IK time: {result['warp_ms']:.6f} ms")
+ print(f" Translation mean error: {result['warp_t_err_mm']:.6f} mm")
+ print(f" Rotation mean error: {result['warp_r_err_deg']:.6f} degrees")
+ print(f" Success rate: {result['warp_success_rate'] * 100.0:.2f}%")
+ print(
+ " "
+ f"CPU Δ={result['warp_cpu_delta_mb']:+.1f} MB "
+ f"GPU Δ={result['warp_gpu_delta_mb']:+.1f} MB "
+ f"peak GPU={result['warp_peak_gpu_mb']:.1f} MB"
+ )
+ print(f"===CPU OPW IK time: {result['cpu_ms']:.6f} ms")
+ print(f" Translation mean error: {result['cpu_t_err_mm']:.6f} mm")
+ print(f" Rotation mean error: {result['cpu_r_err_deg']:.6f} degrees")
+ print(f" Success rate: {result['cpu_success_rate'] * 100.0:.2f}%")
+ print(
+ " "
+ f"CPU Δ={result['cpu_cpu_delta_mb']:+.1f} MB "
+ f"GPU Δ={result['cpu_gpu_delta_mb']:+.1f} MB "
+ f"peak GPU={result['cpu_peak_gpu_mb']:.1f} MB"
+ )
+
+ perf_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "opw_cuda",
+ "component": "opw_ik",
+ "cost_time_ms": f"{result['warp_ms']:.6f}",
+ "cpu_delta_mb": f"{result['warp_cpu_delta_mb']:.6f}",
+ "gpu_delta_mb": f"{result['warp_gpu_delta_mb']:.6f}",
+ "peak_gpu_mb": f"{result['warp_peak_gpu_mb']:.6f}",
+ }
+ )
+ perf_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "opw_cpu",
+ "component": "opw_ik",
+ "cost_time_ms": f"{result['cpu_ms']:.6f}",
+ "cpu_delta_mb": f"{result['cpu_cpu_delta_mb']:.6f}",
+ "gpu_delta_mb": f"{result['cpu_gpu_delta_mb']:.6f}",
+ "peak_gpu_mb": f"{result['cpu_peak_gpu_mb']:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "opw_cuda",
+ "component": "opw_ik",
+ "success_rate": f"{result['warp_success_rate']:.6f}",
+ "translation_err_mm": f"{result['warp_t_err_mm']:.6f}",
+ "rotation_err_deg": f"{result['warp_r_err_deg']:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n_sample,
+ "impl": "opw_cpu",
+ "component": "opw_ik",
+ "success_rate": f"{result['cpu_success_rate']:.6f}",
+ "translation_err_mm": f"{result['cpu_t_err_mm']:.6f}",
+ "rotation_err_deg": f"{result['cpu_r_err_deg']:.6f}",
+ }
+ )
+
+ return perf_rows, metric_rows
+
+
+def run_all_benchmarks(selected_solvers: list[str] | None = None) -> None:
+ """Run unified OPW + Pytorch kinematic solver benchmarks."""
+ solvers_to_run = _normalize_selected_solvers(selected_solvers)
+
+ print("=" * 60)
+ print("Kinematic Solver Performance Benchmarks")
+ print("=" * 60)
+
+ print("\nSelected solvers:", ", ".join(sorted(solvers_to_run)))
+
+ print("\nConfiguration differences:")
+ print(
+ "- OPW solver: analytic OPW parameters via OPWSolverCfg with "
+ "opw-specific joint limits."
+ )
+ print("- Pytorch solver: UR10 URDF-based PytorchSolver with " "UR10 joint limits.")
+
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ if "opw" in solvers_to_run:
+ opw_perf_rows, opw_metric_rows = benchmark_opw_solver()
+ perf_rows.extend(opw_perf_rows)
+ metric_rows.extend(opw_metric_rows)
+
+ if "pytorch" in solvers_to_run:
+ pytorch_perf_rows, pytorch_metric_rows = benchmark_pytorch_solver()
+ perf_rows.extend(pytorch_perf_rows)
+ metric_rows.extend(pytorch_metric_rows)
+
+ leaderboard_rows = _build_leaderboard_rows(metric_rows)
+
+ benchmark_name = "kinematic_solver"
+
+ print("\n" + "=" * 60)
+ print("Benchmarks complete.")
+ print("=" * 60)
+
+ report_path = _write_markdown_report(
+ benchmark_name=benchmark_name,
+ perf_rows=perf_rows,
+ metric_rows=metric_rows,
+ leaderboard_rows=leaderboard_rows,
+ notes=[
+ "CPU/GPU memory fields are deltas measured around timed calls.",
+ "This report contains exactly three tables: Time & Memory, Success & Other Metrics, and Leaderboard.",
+ ]
+ + (
+ [
+ "OPW and Pytorch solvers use different initialization paths and different lower/upper joint limits."
+ ]
+ if solvers_to_run == set(SUPPORTED_SOLVERS)
+ else []
+ ),
+ )
+ print(f"Markdown report saved: {report_path}")
+
+
+if __name__ == "__main__":
+ args = _parse_args()
+ run_all_benchmarks(selected_solvers=args.solvers)
diff --git a/scripts/benchmark/workspace_analyzer/benchmark_workspace_analyzer.py b/scripts/benchmark/workspace_analyzer/benchmark_workspace_analyzer.py
new file mode 100644
index 00000000..67185059
--- /dev/null
+++ b/scripts/benchmark/workspace_analyzer/benchmark_workspace_analyzer.py
@@ -0,0 +1,488 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+"""Benchmark script for workspace analyzer performance optimizations.
+
+Measures each optimization independently across multiple sample sizes.
+Run: python -m scripts.benchmark.workspace_analyzer.benchmark_workspace_analyzer
+"""
+
+import os
+import time
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import psutil
+import torch
+
+SAMPLE_SIZES_SMALL = [100, 1000, 10000, 50000]
+SAMPLE_SIZES_MEDIUM = [1000, 10000, 100000, 500000]
+
+
+def _sync_cuda() -> None:
+ """Synchronize CUDA stream when available."""
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+
+def _reset_peak_gpu_memory() -> None:
+ """Reset PyTorch peak GPU memory stats when CUDA is available."""
+ if torch.cuda.is_available():
+ torch.cuda.reset_peak_memory_stats()
+
+
+def _peak_gpu_memory_mb() -> float:
+ """Return peak GPU memory allocated by PyTorch in MB."""
+ if not torch.cuda.is_available():
+ return 0.0
+ return torch.cuda.max_memory_allocated() / 1024**2
+
+
+def _memory_snapshot() -> dict[str, float]:
+ """Return current process memory usage snapshot in MB."""
+ process = psutil.Process(os.getpid())
+ cpu_mb = process.memory_info().rss / 1024**2
+ gpu_mb = (
+ torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0.0
+ )
+ return {"cpu_mb": cpu_mb, "gpu_mb": gpu_mb}
+
+
+def _time_call(callable_fn) -> tuple[float, dict[str, float], float, object]:
+ """Time a callable and return elapsed seconds, memory deltas, and result."""
+ _reset_peak_gpu_memory()
+ before = _memory_snapshot()
+ _sync_cuda()
+
+ start = time.perf_counter()
+ result = callable_fn()
+ _sync_cuda()
+ elapsed = time.perf_counter() - start
+
+ after = _memory_snapshot()
+ deltas = {
+ "cpu_mb": after["cpu_mb"] - before["cpu_mb"],
+ "gpu_mb": after["gpu_mb"] - before["gpu_mb"],
+ }
+ return elapsed, deltas, _peak_gpu_memory_mb(), result
+
+
+def _format_perf_line(
+ n: int,
+ elapsed_s: float,
+ memory_delta: dict[str, float],
+ peak_gpu_mb: float,
+ extra_info: str,
+) -> str:
+ """Format one benchmark output line with aligned fields."""
+ return (
+ f" n={n:>7d}: {elapsed_s * 1000:>10.2f} ms | "
+ f"CPU Δ={memory_delta['cpu_mb']:+.1f} MB "
+ f"GPU Δ={memory_delta['gpu_mb']:+.1f} MB "
+ f"peak GPU={peak_gpu_mb:.1f} MB" + (f" | {extra_info}" if extra_info else "")
+ )
+
+
+def _format_markdown_table(rows: list[dict[str, object]]) -> list[str]:
+ """Format rows into a markdown table."""
+ if not rows:
+ return ["No data."]
+
+ headers = list(rows[0].keys())
+ lines = [
+ "| " + " | ".join(headers) + " |",
+ "| " + " | ".join(["---"] * len(headers)) + " |",
+ ]
+ for row in rows:
+ lines.append("| " + " | ".join(str(row[h]) for h in headers) + " |")
+ return lines
+
+
+def _write_markdown_report(
+ benchmark_name: str,
+ perf_rows: list[dict[str, object]],
+ metric_rows: list[dict[str, object]],
+ notes: list[str] | None = None,
+) -> Path:
+ """Write benchmark results to a markdown report with two tables."""
+ output_dir = Path("outputs/benchmarks")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ report_path = output_dir / f"{benchmark_name}_{timestamp}.md"
+
+ lines: list[str] = [
+ f"# {benchmark_name} Benchmark Report",
+ "",
+ f"Generated at: {datetime.now().isoformat(timespec='seconds')}",
+ "",
+ "## Time & Memory",
+ "",
+ ]
+ lines.extend(_format_markdown_table(perf_rows))
+ lines.extend(["", "## Success & Other Metrics", ""])
+ lines.extend(_format_markdown_table(metric_rows))
+
+ if notes:
+ lines.extend(["", "## Notes", ""])
+ lines.extend([f"- {note}" for note in notes])
+
+ report_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+ return report_path
+
+
+def benchmark_halton_sampler() -> (
+ tuple[list[dict[str, object]], list[dict[str, object]]]
+):
+ """Benchmark Halton sampler: vectorized vs loop-based."""
+ from embodichain.lab.sim.utility.workspace_analyzer.samplers.halton_sampler import (
+ HaltonSampler,
+ )
+
+ sampler = HaltonSampler(seed=42)
+ bounds = torch.tensor(
+ [
+ [-3.14, 3.14],
+ [-3.14, 3.14],
+ [-3.14, 3.14],
+ [-3.14, 3.14],
+ [-3.14, 3.14],
+ [-3.14, 3.14],
+ ],
+ dtype=torch.float32,
+ )
+
+ print("\n=== Halton Sampler Benchmark ===")
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ for n in [100, 1000, 10000, 100000]:
+ elapsed, mem_delta, peak_gpu, samples = _time_call(
+ lambda: sampler.sample(num_samples=n, bounds=bounds)
+ )
+ elapsed_ms = elapsed * 1000.0
+ print(
+ _format_perf_line(
+ n=n,
+ elapsed_s=elapsed,
+ memory_delta=mem_delta,
+ peak_gpu_mb=peak_gpu,
+ extra_info=f"shape={tuple(samples.shape)}",
+ )
+ )
+
+ perf_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "halton_sampler",
+ "cost_time_ms": f"{elapsed_ms:.6f}",
+ "cpu_delta_mb": f"{mem_delta['cpu_mb']:.6f}",
+ "gpu_delta_mb": f"{mem_delta['gpu_mb']:.6f}",
+ "peak_gpu_mb": f"{peak_gpu:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "halton_sampler",
+ "success_rate": "N/A",
+ "other_metrics": f"shape={tuple(samples.shape)}",
+ }
+ )
+
+ return perf_rows, metric_rows
+
+
+def benchmark_density_metric() -> (
+ tuple[list[dict[str, object]], list[dict[str, object]]]
+):
+ """Benchmark density metric: KDTree vs brute-force."""
+ from embodichain.lab.sim.utility.workspace_analyzer.metrics.density_metric import (
+ DensityMetric,
+ )
+ from embodichain.lab.sim.utility.workspace_analyzer.configs.metric_config import (
+ DensityConfig,
+ )
+
+ config = DensityConfig(radius=0.05, compute_distribution=False)
+ metric = DensityMetric(config)
+
+ print("\n=== Density Metric Benchmark ===")
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ for n in SAMPLE_SIZES_SMALL:
+ points = np.random.randn(n, 3).astype(np.float32) * 0.5
+
+ elapsed, mem_delta, peak_gpu, result = _time_call(
+ lambda: metric.compute(points)
+ )
+ elapsed_ms = elapsed * 1000.0
+ print(
+ _format_perf_line(
+ n=n,
+ elapsed_s=elapsed,
+ memory_delta=mem_delta,
+ peak_gpu_mb=peak_gpu,
+ extra_info=f"mean_density={result['mean_density']:.2f}",
+ )
+ )
+
+ perf_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "density_metric",
+ "cost_time_ms": f"{elapsed_ms:.6f}",
+ "cpu_delta_mb": f"{mem_delta['cpu_mb']:.6f}",
+ "gpu_delta_mb": f"{mem_delta['gpu_mb']:.6f}",
+ "peak_gpu_mb": f"{peak_gpu:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "density_metric",
+ "success_rate": "N/A",
+ "other_metrics": f"mean_density={result['mean_density']:.6f}",
+ }
+ )
+
+ return perf_rows, metric_rows
+
+
+def benchmark_voxelization() -> tuple[list[dict[str, object]], list[dict[str, object]]]:
+ """Benchmark voxelization: np.unique vs dict-based."""
+ from embodichain.lab.sim.utility.workspace_analyzer.metrics.reachability_metric import (
+ ReachabilityMetric,
+ )
+ from embodichain.lab.sim.utility.workspace_analyzer.configs.metric_config import (
+ ReachabilityConfig,
+ )
+
+ config = ReachabilityConfig(voxel_size=0.01, compute_coverage=True)
+ metric = ReachabilityMetric(config)
+
+ print("\n=== Voxelization Benchmark ===")
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ for n in SAMPLE_SIZES_MEDIUM:
+ points = np.random.randn(n, 3).astype(np.float32) * 0.5
+
+ elapsed, mem_delta, peak_gpu, result = _time_call(
+ lambda: metric.compute(points)
+ )
+ elapsed_ms = elapsed * 1000.0
+ print(
+ _format_perf_line(
+ n=n,
+ elapsed_s=elapsed,
+ memory_delta=mem_delta,
+ peak_gpu_mb=peak_gpu,
+ extra_info=(
+ f"volume={result['volume']:.4f}, " f"voxels={result['num_voxels']}"
+ ),
+ )
+ )
+
+ perf_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "voxelization",
+ "cost_time_ms": f"{elapsed_ms:.6f}",
+ "cpu_delta_mb": f"{mem_delta['cpu_mb']:.6f}",
+ "gpu_delta_mb": f"{mem_delta['gpu_mb']:.6f}",
+ "peak_gpu_mb": f"{peak_gpu:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "voxelization",
+ "success_rate": "N/A",
+ "other_metrics": (
+ f"volume={result['volume']:.6f}, num_voxels={result['num_voxels']}"
+ ),
+ }
+ )
+
+ return perf_rows, metric_rows
+
+
+def benchmark_manipulability() -> (
+ tuple[list[dict[str, object]], list[dict[str, object]]]
+):
+ """Benchmark manipulability: batch vs per-sample."""
+ from embodichain.lab.sim.utility.workspace_analyzer.metrics.manipulability_metric import (
+ ManipulabilityMetric,
+ )
+ from embodichain.lab.sim.utility.workspace_analyzer.configs.metric_config import (
+ ManipulabilityConfig,
+ )
+
+ config = ManipulabilityConfig(compute_isotropy=True)
+ metric = ManipulabilityMetric(config)
+
+ print("\n=== Manipulability Metric Benchmark ===")
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ for n in SAMPLE_SIZES_SMALL:
+ points = np.random.randn(n, 3).astype(np.float32) * 0.5
+ jacobians = np.random.randn(n, 6, 6).astype(np.float32) * 0.1
+
+ elapsed, mem_delta, peak_gpu, result = _time_call(
+ lambda: metric.compute(points, jacobians=jacobians)
+ )
+ elapsed_ms = elapsed * 1000.0
+ print(
+ _format_perf_line(
+ n=n,
+ elapsed_s=elapsed,
+ memory_delta=mem_delta,
+ peak_gpu_mb=peak_gpu,
+ extra_info=f"mean_manip={result['mean_manipulability']:.6f}",
+ )
+ )
+
+ perf_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "manipulability_metric",
+ "cost_time_ms": f"{elapsed_ms:.6f}",
+ "cpu_delta_mb": f"{mem_delta['cpu_mb']:.6f}",
+ "gpu_delta_mb": f"{mem_delta['gpu_mb']:.6f}",
+ "peak_gpu_mb": f"{peak_gpu:.6f}",
+ }
+ )
+ metric_rows.append(
+ {
+ "sample_size": n,
+ "impl": "workspace_analyzer",
+ "component": "manipulability_metric",
+ "success_rate": "N/A",
+ "other_metrics": (
+ f"mean_manipulability={result['mean_manipulability']:.6f}"
+ ),
+ }
+ )
+
+ return perf_rows, metric_rows
+
+
+def benchmark_batch_fk() -> tuple[list[dict[str, object]], list[dict[str, object]]]:
+ """Benchmark batch FK vs sequential FK (requires GPU robot setup).
+
+ This benchmark requires a running simulation with a robot.
+ It is skipped if no simulation is available.
+ """
+ print("\n=== Batch FK Benchmark (requires robot/simulation) ===")
+ print(" Skipped -- requires live SimulationManager and Robot.")
+ print(" To run manually, integrate with your robot setup:")
+ print(" analyzer.compute_workspace_points(joint_configs, batch_size=512)")
+ return [], [
+ {
+ "sample_size": "N/A",
+ "impl": "workspace_analyzer",
+ "component": "batch_fk",
+ "success_rate": "N/A",
+ "other_metrics": "skipped: requires live SimulationManager and Robot",
+ }
+ ]
+
+
+def benchmark_batch_ik() -> tuple[list[dict[str, object]], list[dict[str, object]]]:
+ """Benchmark batch IK vs sequential IK (requires GPU robot setup).
+
+ This benchmark requires a running simulation with a robot.
+ It is skipped if no simulation is available.
+ """
+ print("\n=== Batch IK Benchmark (requires robot/simulation) ===")
+ print(" Skipped -- requires live SimulationManager and Robot.")
+ print(" To run manually, integrate with your robot setup:")
+ print(" analyzer.compute_reachability(cartesian_points, batch_size=512)")
+ return [], [
+ {
+ "sample_size": "N/A",
+ "impl": "workspace_analyzer",
+ "component": "batch_ik",
+ "success_rate": "N/A",
+ "other_metrics": "skipped: requires live SimulationManager and Robot",
+ }
+ ]
+
+
+def run_all_benchmarks() -> None:
+ """Run all benchmarks and print summary."""
+ print("=" * 60)
+ print("Workspace Analyzer Performance Benchmarks")
+ print("=" * 60)
+
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ perf_part, metric_part = benchmark_halton_sampler()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+
+ perf_part, metric_part = benchmark_density_metric()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+
+ perf_part, metric_part = benchmark_voxelization()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+
+ perf_part, metric_part = benchmark_manipulability()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+
+ perf_part, metric_part = benchmark_batch_fk()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+
+ perf_part, metric_part = benchmark_batch_ik()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+
+ print("\n" + "=" * 60)
+ print("Benchmarks complete.")
+ print("=" * 60)
+
+ report_path = _write_markdown_report(
+ benchmark_name="workspace_analyzer",
+ perf_rows=perf_rows,
+ metric_rows=metric_rows,
+ notes=[
+ "CPU/GPU memory fields are deltas measured around timed calls.",
+ "This report contains exactly two tables: Time & Memory, and Success & Other Metrics.",
+ ],
+ )
+ print(f"Markdown report saved: {report_path}")
+
+
+if __name__ == "__main__":
+ run_all_benchmarks()
diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py
index bab09c03..db4a79ac 100644
--- a/scripts/tutorials/grasp/grasp_generator.py
+++ b/scripts/tutorials/grasp/grasp_generator.py
@@ -30,8 +30,10 @@
from embodichain.lab.sim.shapes import MeshCfg
from embodichain.lab.sim.solvers import PytorchSolverCfg
from embodichain.data import get_data_path
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.utils import logger
from embodichain.lab.sim.cfg import (
+ RenderCfg,
JointDrivePropertiesCfg,
RobotCfg,
LightCfg,
@@ -59,19 +61,7 @@ def parse_arguments():
parser = argparse.ArgumentParser(
description="Create and simulate a robot in SimulationManager"
)
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of parallel environments"
- )
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering"
- )
- parser.add_argument("--headless", action="store_true", help="Enable headless mode")
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- help="device to run the environment on, e.g., 'cpu' or 'cuda'",
- )
+ add_env_launcher_args_to_parser(parser)
return parser.parse_args()
@@ -88,21 +78,20 @@ def initialize_simulation(args) -> SimulationManager:
config = SimulationManagerCfg(
headless=True,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
physics_dt=1.0 / 100.0,
arena_space=2.5,
)
sim = SimulationManager(config)
- if args.enable_rt:
- light = sim.add_light(
- cfg=LightCfg(
- uid="main_light",
- color=(0.6, 0.6, 0.6),
- intensity=30.0,
- init_pos=(1.0, 0, 3.0),
- )
+ light = sim.add_light(
+ cfg=LightCfg(
+ uid="main_light",
+ color=(0.6, 0.6, 0.6),
+ intensity=30.0,
+ init_pos=(1.0, 0, 3.0),
)
+ )
return sim
@@ -271,11 +260,20 @@ def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tenso
)
obj_poses = mug.get_local_pose(to_matrix=True)
grasp_xpos_list = []
- for obj_pose in obj_poses:
- grasp_pose, _ = grasp_generator.get_grasp_poses(
+
+ rest_xpos = robot.compute_fk(
+ qpos=robot.get_qpos("arm"), name="arm", to_matrix=True
+ )[0]
+ for i, obj_pose in enumerate(obj_poses):
+ is_success, grasp_pose, open_length = grasp_generator.get_grasp_poses(
obj_pose, approach_direction, visualize_pose=False
)
- grasp_xpos_list.append(grasp_pose.unsqueeze(0))
+ if is_success:
+ grasp_xpos_list.append(grasp_pose.unsqueeze(0))
+ else:
+ logger.log_warning(f"No valid grasp pose found for {i}-th object.")
+ grasp_xpos_list.append(rest_xpos.unsqueeze(0))
+
grasp_xpos = torch.cat(grasp_xpos_list, dim=0)
cost_time = time.time() - start_time
logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds")
diff --git a/scripts/tutorials/gym/modular_env.py b/scripts/tutorials/gym/modular_env.py
index 9c8bfd66..17b14fb8 100644
--- a/scripts/tutorials/gym/modular_env.py
+++ b/scripts/tutorials/gym/modular_env.py
@@ -33,6 +33,7 @@
from embodichain.lab.sim.sensors import StereoCameraCfg, SensorCfg
from embodichain.lab.sim.shapes import MeshCfg
from embodichain.lab.sim.cfg import (
+ RenderCfg,
LightCfg,
ArticulationCfg,
RobotCfg,
@@ -209,12 +210,20 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs):
import argparse
from embodichain.lab.sim import SimulationManagerCfg
+ from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
parser = argparse.ArgumentParser()
- parser.add_argument("--enable_rt", action="store_true", help="Enable ray tracing")
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
- env_cfg = ExampleCfg(sim_cfg=SimulationManagerCfg(enable_rt=args.enable_rt))
+ env_cfg = ExampleCfg(
+ sim_cfg=SimulationManagerCfg(
+ render_cfg=RenderCfg(renderer=args.renderer),
+ headless=args.headless,
+ sim_device=args.device,
+ num_envs=args.num_envs,
+ )
+ )
# Create the Gym environment
env = gym.make("ModularEnv-v1", cfg=env_cfg)
diff --git a/scripts/tutorials/gym/random_reach.py b/scripts/tutorials/gym/random_reach.py
index 4aca9ab3..b55a7a8e 100644
--- a/scripts/tutorials/gym/random_reach.py
+++ b/scripts/tutorials/gym/random_reach.py
@@ -24,6 +24,7 @@
from embodichain.lab.sim.shapes import CubeCfg
from embodichain.lab.sim.objects import RigidObject, Robot
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RobotCfg,
RigidObjectCfg,
RigidBodyAttributesCfg,
@@ -43,11 +44,15 @@ def __init__(
num_envs=1,
headless=False,
device="cpu",
+ renderer="hybrid",
**kwargs,
):
env_cfg = EnvCfg(
sim_cfg=SimulationManagerCfg(
- headless=headless, arena_space=2.0, sim_device=device
+ headless=headless,
+ arena_space=2.0,
+ sim_device=device,
+ render_cfg=RenderCfg(renderer=renderer),
),
num_envs=num_envs,
)
@@ -112,19 +117,12 @@ def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs:
import argparse
import time
+ from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
+
parser = argparse.ArgumentParser(
description="Demo for running a random reach environment."
)
- parser.add_argument(
- "--num_envs", type=int, default=1, help="number of environments to run"
- )
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- help="device to run the environment on, e.g., 'cpu' or 'cuda'",
- )
- parser.add_argument("--headless", action="store_true", help="run in headless mode")
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
env = gym.make(
@@ -132,6 +130,7 @@ def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs:
num_envs=args.num_envs,
headless=args.headless,
device=args.device,
+ renderer=args.renderer,
)
for episode in range(10):
diff --git a/scripts/tutorials/sim/atomic_actions.py b/scripts/tutorials/sim/atomic_actions.py
new file mode 100644
index 00000000..1f4de8d5
--- /dev/null
+++ b/scripts/tutorials/sim/atomic_actions.py
@@ -0,0 +1,348 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""
+Tutorial: Atomic Actions for Robot Motion Generation
+=====================================================
+
+This script shows how to use the atomic action system to plan and execute
+a pick-and-place task with a robot arm.
+
+Key concepts covered:
+ 1. Setting up a MotionGenerator and AtomicActionEngine
+ 2. Describing what to pick using ObjectSemantics and AntipodalAffordance
+ 3. Running a pick → place → move sequence with execute_static()
+
+Run with:
+ python atomic_actions.py [--num_envs N] [--enable_rt]
+"""
+
+import argparse
+import numpy as np
+import time
+import torch
+
+from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.sim.objects import Robot, RigidObject
+from embodichain.lab.sim.shapes import MeshCfg
+from embodichain.lab.sim.solvers import PytorchSolverCfg
+from embodichain.data import get_data_path
+from embodichain.lab.sim.cfg import (
+ JointDrivePropertiesCfg,
+ RobotCfg,
+ RigidObjectCfg,
+ RigidBodyAttributesCfg,
+ LightCfg,
+ URDFCfg,
+)
+from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg, ToppraPlannerCfg
+from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import (
+ GripperCollisionCfg,
+)
+from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import (
+ GraspGenerator,
+ GraspGeneratorCfg,
+ AntipodalSamplerCfg,
+)
+
+# Import everything from the public atomic_actions API
+from embodichain.lab.sim.atomic_actions import (
+ AtomicActionEngine,
+ ObjectSemantics,
+ AntipodalAffordance,
+ PickUpActionCfg,
+ PlaceActionCfg,
+ MoveActionCfg,
+)
+
+
+def parse_arguments():
+ """
+ Parse command-line arguments to configure the simulation.
+
+ Returns:
+ argparse.Namespace: Parsed arguments including number of environments, device, and rendering options.
+ """
+ parser = argparse.ArgumentParser(
+ description="Create and simulate a robot in SimulationManager"
+ )
+ parser.add_argument(
+ "--enable_rt", action="store_true", help="Enable ray tracing rendering"
+ )
+ parser.add_argument(
+ "--num_envs", type=int, default=1, help="Number of parallel environments"
+ )
+ return parser.parse_args()
+
+
+def initialize_simulation(args):
+ """
+ Initialize the simulation environment based on the provided arguments.
+
+ Args:
+ args (argparse.Namespace): Parsed command-line arguments.
+
+ Returns:
+ SimulationManager: Configured simulation manager instance.
+ """
+ config = SimulationManagerCfg(
+ headless=True,
+ sim_device="cuda",
+ enable_rt=args.enable_rt,
+ physics_dt=1.0 / 100.0,
+ num_envs=args.num_envs,
+ )
+ sim = SimulationManager(config)
+
+ light = sim.add_light(
+ cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0))
+ )
+
+ return sim
+
+
+def create_robot(sim: SimulationManager, position=[0.0, 0.0, 0.0]):
+ """
+ Create and configure a robot with an arm and a dexterous hand in the simulation.
+
+ Args:
+ sim (SimulationManager): The simulation manager instance.
+
+ Returns:
+ Robot: The configured robot instance added to the simulation.
+ """
+ # Retrieve URDF paths for the robot arm and hand
+ ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf")
+ gripper_urdf_path = get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf")
+ # Configure the robot with its components and control properties
+ cfg = RobotCfg(
+ uid="UR10",
+ urdf_cfg=URDFCfg(
+ components=[
+ {"component_type": "arm", "urdf_path": ur10_urdf_path},
+ {"component_type": "hand", "urdf_path": gripper_urdf_path},
+ ]
+ ),
+ drive_pros=JointDrivePropertiesCfg(
+ stiffness={"JOINT[0-9]": 1e4, "FINGER[1-2]": 1e2},
+ damping={"JOINT[0-9]": 1e3, "FINGER[1-2]": 1e1},
+ max_effort={"JOINT[0-9]": 1e5, "FINGER[1-2]": 1e3},
+ drive_type="force",
+ ),
+ control_parts={
+ "arm": ["JOINT[0-9]"],
+ "hand": ["FINGER[1-2]"],
+ },
+ solver_cfg={
+ "arm": PytorchSolverCfg(
+ end_link_name="ee_link",
+ root_link_name="base_link",
+ tcp=[
+ [0.0, 1.0, 0.0, 0.0],
+ [-1.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 1.0, 0.12],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ )
+ },
+ init_qpos=[0.0, -np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 0.0, 0.0, 0.0],
+ init_pos=position,
+ )
+ return sim.add_robot(cfg=cfg)
+
+
+def create_mug(sim: SimulationManager) -> RigidObject:
+ mug_cfg = RigidObjectCfg(
+ uid="mug",
+ shape=MeshCfg(
+ fpath=get_data_path("CoffeeCup/cup.ply"),
+ ),
+ attrs=RigidBodyAttributesCfg(
+ mass=0.01,
+ dynamic_friction=0.97,
+ static_friction=0.99,
+ ),
+ max_convex_hull_num=16,
+ init_pos=[0.55, 0.0, 0.01],
+ init_rot=[0.0, 0.0, -90],
+ body_scale=(4, 4, 4),
+ )
+ mug = sim.add_rigid_object(cfg=mug_cfg)
+ return mug
+
+
+def main():
+ """Pick up a mug and place it at a new location using atomic actions."""
+ args = parse_arguments()
+
+ # ------------------------------------------------------------------ #
+ # Step 1: Set up simulation, robot, and object #
+ # ------------------------------------------------------------------ #
+ sim: SimulationManager = initialize_simulation(args)
+ robot = create_robot(sim)
+ mug = create_mug(sim)
+
+ # ------------------------------------------------------------------ #
+ # Step 2: Create a MotionGenerator for the robot #
+ # MotionGenerator handles trajectory planning (IK + TOPPRA smoothing) #
+ # ------------------------------------------------------------------ #
+ motion_gen = MotionGenerator(
+ cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=robot.uid))
+ )
+
+ # ------------------------------------------------------------------ #
+ # Step 3: Configure the three atomic actions #
+ # #
+ # PickUpAction — approach → close gripper → lift #
+ # PlaceAction — lower → open gripper → retract #
+ # MoveAction — free-space move to a target EEF pose #
+ # ------------------------------------------------------------------ #
+ # Gripper joint values for this robot (DH_PGC_140):
+ # open = [0.00, 0.00] (fully open)
+ # close = [0.025, 0.025] (grasping width)
+ hand_open = torch.tensor([0.00, 0.00], dtype=torch.float32, device=sim.device)
+ hand_close = torch.tensor([0.025, 0.025], dtype=torch.float32, device=sim.device)
+
+ pickup_cfg = PickUpActionCfg(
+ control_part="arm",
+ hand_control_part="hand",
+ hand_open_qpos=hand_open,
+ hand_close_qpos=hand_close,
+ # Approach the object from directly above (negative world-Z)
+ approach_direction=torch.tensor(
+ [0.0, 0.0, -1.0], dtype=torch.float32, device=sim.device
+ ),
+ pre_grasp_distance=0.15, # hover 15 cm above before descending
+ lift_height=0.15, # lift 15 cm after grasping
+ )
+
+ place_cfg = PlaceActionCfg(
+ control_part="arm",
+ hand_control_part="hand",
+ hand_open_qpos=hand_open,
+ hand_close_qpos=hand_close,
+ lift_height=0.15,
+ )
+
+ move_cfg = MoveActionCfg(
+ control_part="arm",
+ )
+
+ # ------------------------------------------------------------------ #
+ # Step 4: Build the AtomicActionEngine #
+ # #
+ # actions_cfg_list defines the ORDER of actions that execute_static() #
+ # will run. Each entry is matched positionally to target_list. #
+ # ------------------------------------------------------------------ #
+ atomic_engine = AtomicActionEngine(
+ motion_generator=motion_gen,
+ actions_cfg_list=[pickup_cfg, place_cfg, move_cfg],
+ )
+
+ sim.init_gpu_physics()
+ sim.open_window()
+
+ # ------------------------------------------------------------------ #
+ # Step 5: Describe the mug with ObjectSemantics #
+ # #
+ # ObjectSemantics bundles together: #
+ # - geometry (mesh vertices/triangles for grasp annotation) #
+ # - affordance (how to grasp the object — here antipodal grasps) #
+ # - entity reference (so the action can read the live object pose) #
+ # ------------------------------------------------------------------ #
+ mug_grasp_affordance = AntipodalAffordance(
+ object_label="mug",
+ force_reannotate=False,
+ custom_config={
+ "gripper_collision_cfg": GripperCollisionCfg(
+ max_open_length=0.088, finger_length=0.078, point_sample_dense=0.012
+ ),
+ "generator_cfg": GraspGeneratorCfg(
+ viser_port=11801,
+ antipodal_sampler_cfg=AntipodalSamplerCfg(
+ n_sample=20000, max_length=0.088, min_length=0.003
+ ),
+ ),
+ },
+ )
+ mug_semantics = ObjectSemantics(
+ label="mug",
+ geometry={
+ "mesh_vertices": mug.get_vertices(env_ids=[0], scale=True)[0],
+ "mesh_triangles": mug.get_triangles(env_ids=[0])[0],
+ },
+ affordance=mug_grasp_affordance,
+ entity=mug, # needed so PickUpAction can read the mug's live pose
+ )
+
+ # ------------------------------------------------------------------ #
+ # Step 6: Define target poses for place and final rest #
+ # #
+ # Poses are 4×4 homogeneous transforms (rotation | translation). #
+ # For PickUpAction the target is mug_semantics — the action computes #
+ # the grasp pose automatically from the affordance. #
+ # ------------------------------------------------------------------ #
+ # Place the mug 20 cm to the left and 40 cm forward from its pickup pose
+ place_xpos = torch.tensor(
+ [
+ [-0.0539, -0.9985, -0.0022, 0.2489],
+ [-0.9977, 0.0540, -0.0401, 0.3970],
+ [0.0401, 0.0000, -0.9992, 0.2400],
+ [0.0000, 0.0000, 0.0000, 1.0000],
+ ],
+ dtype=torch.float32,
+ device=sim.device,
+ )
+
+ # Move the arm to a safe resting pose after placing
+ rest_xpos = torch.tensor(
+ [
+ [-0.0539, -0.9985, -0.0022, 0.5000],
+ [-0.9977, 0.0540, -0.0401, 0.0000],
+ [0.0401, 0.0000, -0.9992, 0.5000],
+ [0.0000, 0.0000, 0.0000, 1.0000],
+ ],
+ dtype=torch.float32,
+ device=sim.device,
+ )
+
+ # ------------------------------------------------------------------ #
+ # Step 7: Plan and execute the full sequence #
+ # #
+ # execute_static() plans all three actions in order and returns a #
+ # single concatenated joint trajectory (n_envs, n_waypoints, dof). #
+ # We then replay it frame-by-frame in the simulator. #
+ # ------------------------------------------------------------------ #
+ print("Planning pick → place → move trajectory...")
+ is_success, traj = atomic_engine.execute_static(
+ target_list=[mug_semantics, place_xpos, rest_xpos]
+ )
+
+ if not is_success:
+ print("Planning failed. Check that the target poses are reachable.")
+ return
+
+ print(f"Success! Replaying {traj.shape[1]} waypoints...")
+ for i in range(traj.shape[1]):
+ robot.set_qpos(traj[:, i])
+ sim.update(step=4)
+ time.sleep(1e-2)
+
+ input("Press Enter to exit...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/tutorials/sim/create_cloth.py b/scripts/tutorials/sim/create_cloth.py
index b81f2bf6..1f0d883c 100644
--- a/scripts/tutorials/sim/create_cloth.py
+++ b/scripts/tutorials/sim/create_cloth.py
@@ -27,7 +27,9 @@
import open3d as o3d
from dexsim.utility.path import get_resources_data_path
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RigidObjectCfg,
RigidBodyAttributesCfg,
ClothObjectCfg,
@@ -78,21 +80,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--headless",
- action="store_true",
- default=False,
- help="Run simulation in headless mode",
- )
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of parallel environments"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -100,11 +88,10 @@ def main():
width=1920,
height=1080,
headless=True,
+ num_envs=args.num_envs,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device="cuda", # soft simulation only supports cuda device
- enable_rt=args.enable_rt, # Enable ray tracing for better visuals
- num_envs=args.num_envs, # Number of parallel environments
- arena_space=2.0,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
# Create the simulation instance
@@ -128,7 +115,7 @@ def main():
init_rot=[0, 0, 0],
physical_attr=ClothPhysicalAttributesCfg(
mass=0.01,
- youngs=1e10,
+ youngs=1e9,
poissons=0.4,
thickness=0.04,
bending_stiffness=0.01,
diff --git a/scripts/tutorials/sim/create_rigid_object_group.py b/scripts/tutorials/sim/create_rigid_object_group.py
index 1b734015..d681dc91 100644
--- a/scripts/tutorials/sim/create_rigid_object_group.py
+++ b/scripts/tutorials/sim/create_rigid_object_group.py
@@ -22,7 +22,8 @@
import time
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
-from embodichain.lab.sim.cfg import RigidBodyAttributesCfg
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
+from embodichain.lab.sim.cfg import RigidBodyAttributesCfg, RenderCfg
from embodichain.lab.sim.shapes import CubeCfg
from embodichain.lab.sim.objects import (
RigidObjectGroup,
@@ -38,24 +39,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--headless",
- action="store_true",
- default=False,
- help="Run simulation in headless mode",
- )
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of parallel environments"
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -65,7 +49,9 @@ def main():
headless=True,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device=args.device,
- enable_rt=args.enable_rt, # Enable ray tracing for better visuals
+ render_cfg=RenderCfg(
+ renderer=args.renderer
+ ), # Enable ray tracing for better visuals
num_envs=args.num_envs,
arena_space=3.0,
)
diff --git a/scripts/tutorials/sim/create_robot.py b/scripts/tutorials/sim/create_robot.py
index 614abb7b..3fe3f9fd 100644
--- a/scripts/tutorials/sim/create_robot.py
+++ b/scripts/tutorials/sim/create_robot.py
@@ -31,11 +31,13 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot
from embodichain.lab.sim.cfg import (
+ RenderCfg,
JointDrivePropertiesCfg,
RobotCfg,
URDFCfg,
)
from embodichain.data import get_data_path
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
def main():
@@ -45,20 +47,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create and simulate a robot in SimulationManager"
)
- parser.add_argument(
- "--num_envs", type=int, default=4, help="Number of environments to simulate"
- )
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- choices=["cpu", "cuda"],
- help="Device to run simulation on",
- )
- parser.add_argument("--headless", action="store_true", help="Run in headless mode")
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering"
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Initialize simulation
@@ -67,7 +56,7 @@ def main():
headless=True,
sim_device=args.device,
arena_space=3.0,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
physics_dt=1.0 / 100.0,
num_envs=args.num_envs,
)
diff --git a/scripts/tutorials/sim/create_scene.py b/scripts/tutorials/sim/create_scene.py
index 4f440ca1..b8f6c727 100644
--- a/scripts/tutorials/sim/create_scene.py
+++ b/scripts/tutorials/sim/create_scene.py
@@ -23,10 +23,11 @@
import time
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
-from embodichain.lab.sim.cfg import RigidBodyAttributesCfg
+from embodichain.lab.sim.cfg import RigidBodyAttributesCfg, RenderCfg
from embodichain.lab.sim.shapes import CubeCfg, MeshCfg
from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg
-from dexsim.utility.path import get_resources_data_path
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
+from embodichain.data import get_data_path
def main():
@@ -36,24 +37,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--headless",
- action="store_true",
- default=False,
- help="Run simulation in headless mode",
- )
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of parallel environments"
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -63,7 +47,9 @@ def main():
headless=True,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device=args.device,
- enable_rt=args.enable_rt, # Enable ray tracing for better visuals
+ render_cfg=RenderCfg(
+ renderer=args.renderer,
+ ),
num_envs=args.num_envs,
arena_space=3.0,
)
@@ -71,7 +57,7 @@ def main():
# Create the simulation instance
sim = SimulationManager(sim_cfg)
- # Add objects to the scene
+ # Add cube object to the scene
cube: RigidObject = sim.add_rigid_object(
cfg=RigidObjectCfg(
uid="cube",
@@ -83,7 +69,23 @@ def main():
static_friction=0.5,
restitution=0.1,
),
- init_pos=[0.0, 0.0, 1.0],
+ init_pos=[0, 0.0, 1.0],
+ )
+ )
+
+ # Add chair object to the scene
+ path = get_data_path("Chair/chair.glb")
+ chair: RigidObject = sim.add_rigid_object(
+ cfg=RigidObjectCfg(
+ uid="chair",
+ shape=MeshCfg(fpath=path),
+ body_type="dynamic",
+ attrs=RigidBodyAttributesCfg(
+ mass=3.0,
+ ),
+ body_scale=[0.5, 0.5, 0.5],
+ init_pos=[0.0, 0.0, 0.2],
+ init_rot=[90.0, 0.0, 0.0],
)
)
diff --git a/scripts/tutorials/sim/create_sensor.py b/scripts/tutorials/sim/create_sensor.py
index f4279090..39534d32 100644
--- a/scripts/tutorials/sim/create_sensor.py
+++ b/scripts/tutorials/sim/create_sensor.py
@@ -29,9 +29,11 @@
from scipy.spatial.transform import Rotation as R
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.sensors import Camera, CameraCfg
from embodichain.lab.sim.objects import Robot
from embodichain.lab.sim.cfg import (
+ RenderCfg,
JointDrivePropertiesCfg,
RobotCfg,
URDFCfg,
@@ -73,20 +75,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create and simulate a robot in SimulationManager"
)
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of environments to simulate"
- )
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- choices=["cpu", "cuda"],
- help="Device to run simulation on",
- )
- parser.add_argument("--headless", action="store_true", help="Run in headless mode")
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering"
- )
+ add_env_launcher_args_to_parser(parser)
parser.add_argument(
"--attach_sensor",
action="store_true",
@@ -100,7 +89,7 @@ def main():
headless=True,
sim_device=args.device,
arena_space=3.0,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
physics_dt=1.0 / 100.0,
num_envs=args.num_envs,
)
diff --git a/scripts/tutorials/sim/create_softbody.py b/scripts/tutorials/sim/create_softbody.py
index 087f35ec..3b8973ef 100644
--- a/scripts/tutorials/sim/create_softbody.py
+++ b/scripts/tutorials/sim/create_softbody.py
@@ -23,7 +23,9 @@
import time
from dexsim.utility.path import get_resources_data_path
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.cfg import (
+ RenderCfg,
SoftbodyVoxelAttributesCfg,
SoftbodyPhysicalAttributesCfg,
)
@@ -41,21 +43,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--headless",
- action="store_true",
- default=False,
- help="Run simulation in headless mode",
- )
- parser.add_argument(
- "--num_envs", type=int, default=4, help="Number of parallel environments"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -63,9 +51,12 @@ def main():
width=1920,
height=1080,
headless=True,
+ num_envs=args.num_envs,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device="cuda", # soft simulation only supports cuda device
- enable_rt=args.enable_rt, # Enable ray tracing for better visuals
+ render_cfg=RenderCfg(
+ renderer=args.renderer
+ ), # Enable ray tracing for better visuals
)
# Create the simulation instance
diff --git a/scripts/tutorials/sim/export_usd.py b/scripts/tutorials/sim/export_usd.py
index 90e81691..c6cb91c7 100644
--- a/scripts/tutorials/sim/export_usd.py
+++ b/scripts/tutorials/sim/export_usd.py
@@ -15,14 +15,16 @@
# ----------------------------------------------------------------------------
"""
-This script demonstrates how to export a simulation scene to a usd file using the SimulationManager.
+This script demonstrates how to export a simulation scene to a usd file using the SimulationManager.
"""
import argparse
import numpy as np
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.objects import Robot, RigidObject
from embodichain.lab.sim.cfg import (
+ RenderCfg,
LightCfg,
JointDrivePropertiesCfg,
RigidObjectCfg,
@@ -46,17 +48,7 @@ def parse_arguments():
parser = argparse.ArgumentParser(
description="Create and simulate a robot in SimulationManager"
)
-
- parser.add_argument(
- "--enable_rt", action="store_true", help="Enable ray tracing rendering"
- )
- parser.add_argument("--headless", action="store_true", help="Enable headless mode")
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- help="device to run the environment on, e.g., 'cpu' or 'cuda'",
- )
+ add_env_launcher_args_to_parser(parser)
return parser.parse_args()
@@ -73,22 +65,21 @@ def initialize_simulation(args) -> SimulationManager:
config = SimulationManagerCfg(
headless=True,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
physics_dt=1.0 / 100.0,
num_envs=1,
arena_space=2.5,
)
sim = SimulationManager(config)
- if args.enable_rt:
- light = sim.add_light(
- cfg=LightCfg(
- uid="main_light",
- color=(0.6, 0.6, 0.6),
- intensity=30.0,
- init_pos=(1.0, 0, 3.0),
- )
+ light = sim.add_light(
+ cfg=LightCfg(
+ uid="main_light",
+ color=(0.6, 0.6, 0.6),
+ intensity=30.0,
+ init_pos=(1.0, 0, 3.0),
)
+ )
return sim
diff --git a/scripts/tutorials/sim/gizmo_robot.py b/scripts/tutorials/sim/gizmo_robot.py
index 1f314549..6d6613f9 100644
--- a/scripts/tutorials/sim/gizmo_robot.py
+++ b/scripts/tutorials/sim/gizmo_robot.py
@@ -23,7 +23,9 @@
import argparse
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RobotCfg,
URDFCfg,
JointDrivePropertiesCfg,
@@ -41,18 +43,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--num_envs", type=int, default=1, help="Number of parallel environments"
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
- parser.add_argument(
- "--enable_rt",
- action="store_true",
- default=False,
- help="Enable ray tracing for better visuals",
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -61,7 +52,7 @@ def main():
height=1080,
physics_dt=1.0 / 100.0,
sim_device=args.device,
- enable_rt=args.enable_rt,
+ render_cfg=RenderCfg(renderer=args.renderer),
)
sim = SimulationManager(sim_cfg)
diff --git a/scripts/tutorials/sim/import_usd.py b/scripts/tutorials/sim/import_usd.py
index 59dfac62..ada74edf 100644
--- a/scripts/tutorials/sim/import_usd.py
+++ b/scripts/tutorials/sim/import_usd.py
@@ -24,13 +24,14 @@
import time
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
-from embodichain.lab.sim.cfg import RigidBodyAttributesCfg
+from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser
+from embodichain.lab.sim.cfg import RigidBodyAttributesCfg, RenderCfg
from embodichain.lab.sim.shapes import CubeCfg, MeshCfg
from embodichain.lab.sim.objects import (
RigidObject,
RigidObjectCfg,
- ArticulationCfg,
- Articulation,
+ RobotCfg,
+ Robot,
)
from embodichain.data import get_data_path
@@ -42,15 +43,7 @@ def main():
parser = argparse.ArgumentParser(
description="Create a simulation scene with SimulationManager"
)
- parser.add_argument(
- "--headless",
- action="store_true",
- default=False,
- help="Run simulation in headless mode",
- )
- parser.add_argument(
- "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)"
- )
+ add_env_launcher_args_to_parser(parser)
args = parser.parse_args()
# Configure the simulation
@@ -60,7 +53,9 @@ def main():
headless=True,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device=args.device,
- enable_rt=True, # Enable ray tracing for better visuals
+ render_cfg=RenderCfg(
+ renderer=args.renderer,
+ ), # Enable ray tracing for better visuals
num_envs=1,
arena_space=3.0,
)
@@ -98,12 +93,12 @@ def main():
# Add objects to the scene
h1_path = get_data_path("UnitreeH1Usd/H1_usd/h1.usd")
print(f"Loading USD file from: {h1_path}")
- h1: Articulation = sim.add_articulation(
- cfg=ArticulationCfg(
+ h1: Robot = sim.add_robot(
+ cfg=RobotCfg(
uid="h1",
fpath=h1_path,
build_pk_chain=False,
- init_pos=[-0.2, -0.2, 1.0],
+ init_pos=[-0.2, -0.2, 1.05],
use_usd_properties=False,
)
)
diff --git a/skills/add-atomic-action/SKILL.md b/skills/add-atomic-action/SKILL.md
new file mode 100644
index 00000000..9ae574a5
--- /dev/null
+++ b/skills/add-atomic-action/SKILL.md
@@ -0,0 +1,197 @@
+---
+name: add-atomic-action
+description: Use when adding a new observation, event, reward, action, dataset, or randomization functor to an EmbodiChain environment
+---
+
+# Add Atomic Action
+
+Scaffold a new atomic action following EmbodiChain's `ActionCfg` / `AtomicAction` pattern.
+
+## When to Use
+
+- User asks to add a new motion primitive (push, wipe, insert, hand-over, …)
+- User says "add a new atomic action", "create a custom action", "implement a push action"
+- User wants to extend `AtomicActionEngine` with a behaviour not covered by the built-ins
+
+## Key Files
+
+| Purpose | Path |
+|---------|------|
+| Base classes (`ActionCfg`, `AtomicAction`, `ObjectSemantics`) | `embodichain/lab/sim/atomic_actions/core.py` |
+| Built-in actions (reference implementations) | `embodichain/lab/sim/atomic_actions/actions.py` |
+| Engine + global registry (`register_action`) | `embodichain/lab/sim/atomic_actions/engine.py` |
+| Public API exports | `embodichain/lab/sim/atomic_actions/__init__.py` |
+| Reference docs | `docs/source/overview/sim/atomic_actions.md` |
+
+## Steps
+
+### 1. Define the config
+
+Add a `@configclass`-decorated class that extends `ActionCfg` (or `MoveActionCfg` /
+`GraspActionCfg` if the new action reuses arm/gripper fields).
+
+Place it in `embodichain/lab/sim/atomic_actions/actions.py` alongside the existing configs,
+or in a new file if the action is large.
+
+```python
+from embodichain.utils import configclass
+from embodichain.lab.sim.atomic_actions.core import ActionCfg # or MoveActionCfg
+
+@configclass
+class PushActionCfg(ActionCfg):
+ name: str = "push" # must match the registry key
+ push_distance: float = 0.05 # metres to push forward
+ push_speed: int = 30 # waypoints for the push phase
+ control_part: str = "arm" # robot segment to control
+```
+
+**Rules:**
+- `name` must be unique and match the string passed to `register_action`.
+- Inherit from `GraspActionCfg` when the action needs hand open/close fields.
+- All fields must have defaults — configs are instantiated without arguments in tests.
+
+### 2. Implement the action class
+
+Subclass `AtomicAction` and implement the two abstract methods.
+
+```python
+import torch
+from typing import Optional, Union
+from embodichain.lab.sim.atomic_actions.core import AtomicAction, ObjectSemantics
+
+class PushAction(AtomicAction):
+ """Push an object forward by a fixed distance."""
+
+ def __init__(self, motion_generator, cfg: PushActionCfg | None = None):
+ super().__init__(motion_generator, cfg=cfg or PushActionCfg())
+ self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part)
+
+ # ------------------------------------------------------------------
+ def execute(
+ self,
+ target: Union[torch.Tensor, ObjectSemantics],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[bool, torch.Tensor, list]:
+ """Plan the push motion and return a joint trajectory.
+
+ Args:
+ target: EEF pose tensor (n_envs, 4, 4) or ObjectSemantics.
+ start_qpos: Starting joint positions (n_envs, dof). Uses current
+ robot state when None.
+
+ Returns:
+ Tuple of (is_success, trajectory, joint_ids) where
+ trajectory has shape (n_envs, n_waypoints, len(joint_ids)).
+ """
+ # 1. Resolve target pose
+ # 2. Plan trajectory with self.motion_generator
+ # 3. Return result
+ return is_success, trajectory, self.arm_joint_ids
+
+ # ------------------------------------------------------------------
+ def validate(
+ self,
+ target: Union[torch.Tensor, ObjectSemantics],
+ start_qpos: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> bool:
+ """Fast feasibility check — no trajectory generated.
+
+ Returns:
+ True if the action can be attempted.
+ """
+ return True # add IK reachability check here if needed
+```
+
+**Rules:**
+- `execute()` must always return `(is_success: bool, trajectory: Tensor, joint_ids: list)`.
+- `trajectory` shape: `(n_envs, n_waypoints, len(joint_ids))`.
+- `joint_ids` tells the engine which DOF columns the trajectory covers.
+- `validate()` must be cheap — no motion planning allowed.
+- Call `super().__init__()` — it sets `self.robot`, `self.motion_generator`, and `self.cfg`.
+
+### 3. Register the action
+
+Register the new class so `AtomicActionEngine` can discover it by name.
+
+**Option A — register at module load (built-ins style)**
+
+In `embodichain/lab/sim/atomic_actions/engine.py`, add to the `_builtin_action_map` dict:
+
+```python
+_builtin_action_map: dict[str, type[AtomicAction]] = {
+ "move": MoveAction,
+ "pickup": PickUpAction,
+ "place": PlaceAction,
+ "push": PushAction, # ← add here
+}
+```
+
+**Option B — register at runtime (custom / plugin style)**
+
+```python
+from embodichain.lab.sim.atomic_actions import register_action
+register_action("push", PushAction)
+```
+
+### 4. Export from the public API
+
+Add config and action class to `embodichain/lab/sim/atomic_actions/__init__.py`:
+
+```python
+from .actions import PushAction, PushActionCfg
+
+__all__ = [
+ ...,
+ "PushAction",
+ "PushActionCfg",
+]
+```
+
+### 5. Update the supported actions table
+
+Add a row to the table in `docs/source/overview/sim/atomic_actions.md` under
+"Supported Actions":
+
+```markdown
+| `PushAction` | `PushActionCfg` | `Tensor (4,4)` — contact pose | Approach → push forward |
+```
+
+### 6. Write a test
+
+Add a test in `tests/sim/atomic_actions/` (append to an existing file or create a new one):
+
+```python
+def test_push_action_cfg_defaults():
+ cfg = PushActionCfg()
+ assert cfg.name == "push"
+ assert cfg.push_distance == 0.05
+
+def test_push_action_validate(mock_motion_generator):
+ action = PushAction(mock_motion_generator)
+ assert action.validate(target=torch.eye(4)) is True
+```
+
+## Common Mistakes
+
+| Mistake | Fix |
+|---------|-----|
+| `name` in config doesn't match registry key | Keep `cfg.name` identical to the string in `register_action("push", ...)` |
+| Returning `trajectory` without `joint_ids` | Always return the 3-tuple `(bool, Tensor, list)` |
+| `trajectory` shape `(n_envs, dof, n_waypoints)` | Correct shape is `(n_envs, n_waypoints, dof)` |
+| Doing motion planning inside `validate()` | `validate()` must be fast — IK check only |
+| Not calling `super().__init__()` | Required to set `self.robot`, `self.motion_generator`, `self.cfg` |
+| Inheriting `MoveActionCfg` instead of `ActionCfg` | Use `MoveActionCfg` only when the action reuses arm-control fields; otherwise use `ActionCfg` |
+| Forgetting to export from `__init__.py` | Users import from the public API — missing exports cause `ImportError` |
+
+## Quick Reference
+
+| Step | Action |
+|------|--------|
+| 1 | Define `@configclass` config extending `ActionCfg` with `name` field |
+| 2 | Subclass `AtomicAction`, implement `execute()` and `validate()` |
+| 3 | Register: add to `_builtin_action_map` or call `register_action()` |
+| 4 | Export from `__init__.py` |
+| 5 | Add row to supported-actions table in overview docs |
+| 6 | Write tests for config defaults and `validate()` |
diff --git a/skills/add-functor/SKILL.md b/skills/add-functor/SKILL.md
new file mode 100644
index 00000000..6133d435
--- /dev/null
+++ b/skills/add-functor/SKILL.md
@@ -0,0 +1,156 @@
+---
+name: add-functor
+description: Use when adding a new observation, event, reward, action, dataset, or randomization functor to an EmbodiChain environment
+---
+
+# Add Functor
+
+Scaffold a new functor following EmbodiChain's Functor/FunctorCfg pattern.
+
+## When to Use
+
+- User asks to add an observation term, reward function, event handler, action term, dataset functor, or randomizer
+- User says "add a reward", "new observation", "create a randomizer", "add event functor"
+- Any new function needs to be registered in a manager config
+
+## Determine Functor Type
+
+| Functor Type | Config Class | Module File | Manager | Signature |
+|-------------|-------------|-------------|---------|-----------|
+| Observation | `ObservationCfg` (extends `FunctorCfg`) | `managers/observations.py` | `ObservationManager` | `(env, obs, entity_cfg, ...) -> Tensor` |
+| Reward | `RewardCfg` (extends `FunctorCfg`) | `managers/rewards.py` | `RewardManager` | `(env, obs, action, info, ...) -> Tensor` |
+| Event | `EventCfg` (extends `FunctorCfg`) | `managers/events.py` | `EventManager` | `(env, env_ids, ...) -> None` |
+| Action | `ActionTermCfg` (extends `FunctorCfg`) | `managers/actions.py` | `ActionManager` | Varies |
+| Dataset | `DatasetFunctorCfg` (extends `FunctorCfg`) | `managers/datasets.py` | `DatasetManager` | `(env, ...) -> dict` |
+| Randomization | `EventCfg` (randomizations ARE events) | `managers/randomization/.py` | `EventManager` | `(env, env_ids, entity_cfg, ...) -> None` |
+
+## Two Functor Styles
+
+### Function-style (Preferred for Simple Functors)
+
+A plain function with the right signature. Registered via `FunctorCfg(func=my_function, params={...})`.
+
+```python
+def my_reward(
+ env: EmbodiedEnv,
+ obs: dict,
+ action: EnvAction,
+ info: dict,
+ my_param: float = 1.0, # params become keyword args
+) -> torch.Tensor:
+ """Short one-line summary.
+
+ Longer description if needed.
+
+ Args:
+ env: The environment instance.
+ obs: The observation dictionary.
+ action: The action taken.
+ info: The info dictionary.
+ my_param: Description of this parameter.
+
+ Returns:
+ Reward tensor of shape (num_envs,).
+ """
+ # implementation
+ return result
+```
+
+### Class-style (Required When Functor Has State)
+
+A class inheriting `Functor`, with `__init__(cfg, env)` and `__call__(env, ...)`. Registered via `FunctorCfg(func=MyClass, params={...})`.
+
+```python
+class my_randomizer(Functor):
+ """One-line summary."""
+
+ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv):
+ super().__init__(cfg, env)
+ # Extract params and initialize state
+ self.entity_cfg: SceneEntityCfg = cfg.params["entity_cfg"]
+
+ def __call__(self, env: EmbodiedEnv, env_ids: torch.Tensor, **kwargs):
+ """Apply the randomization.
+
+ Args:
+ env: The environment instance.
+ env_ids: Target environment IDs.
+ """
+ # implementation
+```
+
+## Steps
+
+### 1. Identify Functor Type and Style
+
+Ask the user:
+1. **Which manager?** (observation / reward / event / action / dataset / randomization)
+2. **Function or class style?** (function for stateless, class for stateful)
+3. **What does it do?** (brief description for naming + docstring)
+
+### 2. Choose the Right Module File
+
+Place the functor in the existing module for its type:
+
+| Type | File |
+|------|------|
+| Observation | `embodichain/lab/gym/envs/managers/observations.py` |
+| Reward | `embodichain/lab/gym/envs/managers/rewards.py` |
+| Event | `embodichain/lab/gym/envs/managers/events.py` |
+| Action | `embodichain/lab/gym/envs/managers/actions.py` |
+| Dataset | `embodichain/lab/gym/envs/managers/datasets.py` |
+| Physics randomization | `embodichain/lab/gym/envs/managers/randomization/physics.py` |
+| Visual randomization | `embodichain/lab/gym/envs/managers/randomization/visual.py` |
+| Spatial randomization | `embodichain/lab/gym/envs/managers/randomization/spatial.py` |
+| Geometry randomization | `embodichain/lab/gym/envs/managers/randomization/geometry.py` |
+
+### 3. Write the Functor
+
+Follow the template for function-style or class-style (see above).
+
+Key rules:
+- First argument is always `env: EmbodiedEnv` (use `TYPE_CHECKING` guard for the import)
+- Use `from __future__ import annotations` at the top
+- Use `SceneEntityCfg` for entity references, not raw strings
+- For observation functors: add `shape` key to `FunctorCfg.extra` dict
+- For randomization functors: second arg is `env_ids: torch.Tensor | list[int]`
+- For reward functors: return shape must be `(num_envs,)`
+
+### 4. Update `__all__`
+
+Add the new functor to the module's `__all__` list. If no `__all__` exists, create one.
+
+### 5. Write a Test
+
+Place at `tests/gym/envs/managers/test_.py` (append to existing file if present).
+
+For functors that don't need a live simulation, use mock objects (`MockEnv`, `MockSim`, etc.) following the pattern in `tests/gym/envs/managers/test_reward_functors.py`.
+
+### 6. Run `black`
+
+```bash
+black embodichain/lab/gym/envs/managers/.py
+black tests/gym/envs/managers/test_.py
+```
+
+## Common Mistakes
+
+| Mistake | Fix |
+|---------|-----|
+| Wrong first argument signature | Observation: `(env, obs, ...)`, Reward: `(env, obs, action, info, ...)`, Event/Randomization: `(env, env_ids, ...)` |
+| Importing `EmbodiedEnv` at module level | Use `TYPE_CHECKING` guard to avoid circular imports |
+| Forgetting `SceneEntityCfg` for entity refs | Always use `SceneEntityCfg(uid="...")` not bare strings |
+| Returning wrong tensor shape | Rewards must return `(num_envs,)`, observations must match declared shape |
+| Missing `from __future__ import annotations` | Required in every file |
+| Class-style functor not calling `super().__init__` | Always call `super().__init__(cfg, env)` |
+| Adding randomizer as standalone | Randomizations ARE events — they go in `randomization/` but use `EventCfg` |
+
+## Quick Reference
+
+| Step | Action |
+|------|--------|
+| 1 | Identify manager type + function vs class style |
+| 2 | Write functor in the correct module file |
+| 3 | Update `__all__` in that module |
+| 4 | Write test with mocks (no sim needed for most) |
+| 5 | Run `black` on changed files |
diff --git a/skills/add-task-env/SKILL.md b/skills/add-task-env/SKILL.md
new file mode 100644
index 00000000..b6092cfc
--- /dev/null
+++ b/skills/add-task-env/SKILL.md
@@ -0,0 +1,107 @@
+---
+name: add-task-env
+description: Use when creating a new task environment for EmbodiChain, including expert demonstration tasks, RL tasks or any EmbodiedEnv subclass
+---
+
+# Add Task Environment
+
+Scaffold a new task environment following EmbodiChain's conventions and patterns.
+
+## When to Use
+
+- User asks to create a new task or environment
+- User says "add a task", "new env", "create environment for X"
+
+## Steps
+
+### 1. Determine Task Category
+
+Ask the user:
+
+- **Category**: `tableware`, `rl`, or `special` (maps to `embodichain/lab/gym/envs/tasks//`)
+- **Task name** (snake_case, e.g. `pick_place`)
+- **Gym ID** (e.g. `PickPlace-v1`)
+- **Task type**: RL task (needs reward functors) or expert demonstration task (needs `create_demo_action_list`)
+
+### 2. Create the Task File
+
+Place at `embodichain/lab/gym/envs/tasks//.py`.
+
+Template:
+
+```python
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import torch
+from typing import Dict, Any, Tuple
+
+from embodichain.lab.gym.utils.registration import register_env
+from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg
+from embodichain.lab.sim.types import EnvObs
+
+__all__ = ["Env"]
+
+
+@register_env("")
+class Env(EmbodiedEnv):
+ """.
+
+
+ """
+
+ def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs):
+ if cfg is None:
+ cfg = EmbodiedEnvCfg()
+ super().__init__(cfg, **kwargs)
+
+ # Expert demo tasks: implement `create_demo_action_list`.
+ # RL tasks: implement `check_truncated`, `get_reward`, `compute_task_state`.
+```
+
+### 3. Update Exports
+
+Add to `embodichain/lab/gym/envs/tasks/__init__.py`:
+
+```python
+from embodichain.lab.gym.envs.tasks.. import Env
+```
+
+Add `"Env"` to the `__all__` list.
+
+### 4. Create Test Stub
+
+Place at `tests/gym/envs/tasks/test_.py`.
+
+### 5. Format
+
+```bash
+black embodichain/lab/gym/envs/tasks//.py
+black tests/gym/envs/tasks/test_.py
+```
+
+## Checklist
+
+- [ ] File has Apache 2.0 header
+- [ ] Uses `from __future__ import annotations`
+- [ ] `@register_env` decorator with unique gym ID
+- [ ] `__all__` defined in the task module
+- [ ] Default `cfg = EmbodiedEnvCfg()` in `__init__`
+- [ ] Import and `__all__` added to `tasks/__init__.py`
+- [ ] Test stub created
+- [ ] `black` run on both files
diff --git a/skills/add-test/SKILL.md b/skills/add-test/SKILL.md
new file mode 100644
index 00000000..d780154c
--- /dev/null
+++ b/skills/add-test/SKILL.md
@@ -0,0 +1,246 @@
+---
+name: add-test
+description: Use when writing tests for EmbodiChain modules, including observation functors, reward functors, solvers, sensors, environments, or any Python module
+---
+
+# Add Test
+
+Write tests following EmbodiChain's conventions and patterns.
+
+## When to Use
+
+- User asks to "add a test", "write tests for X", "test this module"
+- A new public module or function needs test coverage
+- PR checklist requires tests
+
+## Test File Location
+
+Tests mirror the source tree under `tests/`:
+
+```
+embodichain/lab/sim/solvers/pytorch_solver.py → tests/sim/solvers/test_pytorch_solver.py
+embodichain/lab/gym/envs/managers/rewards.py → tests/gym/envs/managers/test_reward_functors.py
+embodichain/toolkits/graspkit/pg_grasp/foo.py → tests/toolkits/test_pg_grasp.py
+embodichain/lab/gym/envs/tasks/rl/push_cube.py → tests/gym/envs/tasks/test_push_cube.py
+```
+
+Rules:
+- File name: `test_.py`
+- Directory path mirrors `embodichain/` structure under `tests/`
+- Create `__init__.py` files in new `tests/` subdirectories if needed
+
+## Two Test Styles
+
+### pytest Style — For Pure-Python Logic (No Sim)
+
+Use when: testing functors, utility functions, pure math, config validation — anything that doesn't need a `SimulationManager`.
+
+```python
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# ...
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+from embodichain.my_module import my_function
+
+
+def test_expected_output():
+ result = my_function(input_value)
+ assert result == expected_value
+
+
+def test_edge_case():
+ result = my_function(edge_input)
+ assert result is not None
+```
+
+### Class Style — For Sim-Dependent or Ordered Tests
+
+Use when: tests need `SimulationManager`, GPU setup, or must run in a specific order. Share state via `setup_method`/`teardown_method`.
+
+```python
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# ...
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+
+
+class TestMySimComponent:
+ def setup_method(self):
+ config = SimulationManagerCfg(headless=True, sim_device="cpu")
+ self.sim = SimulationManager(config)
+ # ... setup ...
+
+ def teardown_method(self):
+ self.sim.destroy()
+
+ def test_basic_behavior(self):
+ result = self.sim.do_something()
+ assert result == expected_result
+
+ def test_raises_on_bad_input(self):
+ with pytest.raises(ValueError):
+ self.sim.do_something(bad_input)
+```
+
+## Mocking Patterns for Functor Tests
+
+Most functor tests don't need a live simulation. Use mock objects following the pattern in `tests/gym/envs/managers/test_reward_functors.py`:
+
+```python
+from unittest.mock import MagicMock, Mock
+
+
+class MockSim:
+ """Mock simulation for functor tests."""
+
+ def __init__(self, num_envs: int = 4):
+ self.num_envs = num_envs
+ self.device = torch.device("cpu")
+ self._rigid_objects: dict = {}
+
+ def get_rigid_object(self, uid: str):
+ return self._rigid_objects.get(uid)
+
+ def add_rigid_object(self, obj):
+ self._rigid_objects[obj.uid] = obj
+
+
+class MockEnv:
+ """Mock environment for functor tests."""
+
+ def __init__(self, num_envs: int = 4):
+ self.num_envs = num_envs
+ self.device = torch.device("cpu")
+ self.sim = MockSim(num_envs)
+```
+
+Key points for mock objects:
+- Set `num_envs` and `device` attributes (functors use these)
+- Mock only the sim methods the functor actually calls
+- Use `MagicMock(uid="...")` for `SceneEntityCfg` parameters
+
+## Steps
+
+### 1. Identify What to Test
+
+Ask the user:
+1. **Which module/function?** — determines file path
+2. **Does it need a live simulation?** — determines test style
+3. **Key behaviors to verify** — happy path, edge cases, error cases
+
+### 2. Determine Test File Path
+
+Map the source path to test path:
+
+```
+embodichain//.py → tests//test_.py
+```
+
+Check if the test file already exists — append new test classes/functions if so.
+
+### 3. Choose Test Style
+
+```dot
+digraph test_style {
+ rankdir=LR;
+ "Needs SimulationManager?" -> "Class style" [label="yes"];
+ "Needs SimulationManager?" -> "pytest style" [label="no"];
+ "Tests share state/order?" -> "Class style" [label="yes"];
+ "Tests share state/order?" -> "pytest style" [label="no"];
+}
+```
+
+### 4. Write the Test
+
+Use the appropriate template (pytest or class style above).
+
+Rules:
+- **Apache 2.0 header** — required on every test file
+- **`from __future__ import annotations`** — after header, before imports
+- **No magic numbers** — define expected values as named constants or comment their origin
+- **Test function names** — `test_` (descriptive, not just `test_foo`)
+- **One assertion concept per test** — don't bundle unrelated checks
+
+### 5. Add `if __name__ == "__main__"` Block
+
+Include this for tests that support optional visual/interactive debugging:
+
+```python
+if __name__ == "__main__":
+ # For visual debugging: set is_visual=True when calling env methods
+ test_obj = TestMyComponent()
+ test_obj.setup_method()
+ # ... manually run test logic ...
+```
+
+### 6. Run the Test
+
+```bash
+# Single file
+pytest tests//test_.py -v
+
+# Single test function
+pytest tests//test_.py::test_expected_output -v
+
+# Single test class method
+pytest tests//test_.py::TestMyClass::test_basic_behavior -v
+```
+
+### 7. Run `black`
+
+```bash
+black tests//test_.py
+```
+
+## Conventions Summary
+
+| Convention | Rule |
+|-----------|------|
+| File header | Apache 2.0 copyright block (same 15 lines as source) |
+| File naming | `test_.py` |
+| Function naming | `test_` |
+| `from __future__` | Required after header |
+| Magic numbers | Define as named constants with explanatory comments |
+| Simulation tests | Initialize/teardown in `setup_method`/`teardown_method` |
+| Pure-logic tests | Use mock objects, no real sim |
+| `SceneEntityCfg` | Use `MagicMock(uid="...")` in tests |
+| Assertions | `assert`, `pytest.approx`, `torch.allclose`, `pytest.raises` |
+| Entry block | `if __name__ == "__main__"` for visual debugging support |
+
+## Common Mistakes
+
+| Mistake | Fix |
+|---------|-----|
+| Missing Apache header on test file | Copy the 15-line copyright block |
+| Using real `SimulationManager` for functor tests | Use `MockEnv`/`MockSim` — much faster, no GPU needed |
+| Hardcoded numbers without explanation | Define as `EXPECTED_DISTANCE = 0.5 # cube at origin, target at (0.5, 0, 0)` |
+| Testing multiple concepts in one function | Split into separate `test_` functions |
+| Forgetting `teardown_method` | Always call `self.sim.destroy()` in teardown |
+| Not running `black` on test file | CI checks all files including tests |
+
+## Quick Reference
+
+| Action | Command |
+|--------|---------|
+| Run all tests | `pytest tests/` |
+| Run single file | `pytest tests//test_.py -v` |
+| Run single test | `pytest tests/::test_ -v` |
+| Run with print output | `pytest -s tests//test_.py` |
+| Format | `black tests//test_.py` |
diff --git a/skills/benchmark/SKILL.md b/skills/benchmark/SKILL.md
new file mode 100644
index 00000000..e95ffe05
--- /dev/null
+++ b/skills/benchmark/SKILL.md
@@ -0,0 +1,479 @@
+---
+name: benchmark
+description: Write benchmark scripts for EmbodiChain modules following project conventions
+---
+
+# EmbodiChain Benchmark Script Writer
+
+This skill guides you through writing well-structured benchmark scripts for EmbodiChain modules, covering performance measurement of solvers, samplers, metrics, and other computationally intensive components.
+
+## Usage
+
+Invoke this skill when:
+- A user asks to write or extend a benchmark script for any EmbodiChain module
+- Comparing CPU vs GPU implementations (e.g., Warp CUDA vs pure-Python)
+- Measuring throughput of samplers, metrics, FK/IK solvers, or data pipelines
+- The file path contains `scripts/benchmark/` or the word "benchmark" appears in the request
+
+## Key Conventions
+
+### File Location
+
+Place benchmark scripts under:
+
+```
+scripts/benchmark//.py
+```
+
+Examples:
+- `scripts/benchmark/robotics/kinematic_solver/opw_solver.py`
+- `scripts/benchmark/workspace_analyzer/benchmark_workspace_analyzer.py`
+
+### File Header
+
+Every benchmark file **must** begin with the Apache 2.0 copyright header followed by a module-level docstring:
+
+```python
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""One-line summary of what this benchmark measures.
+
+Longer description of the optimizations or comparisons being evaluated.
+Run: python -m scripts.benchmark..
+"""
+```
+
+---
+
+## Steps
+
+### 1. Identify What to Benchmark
+
+Ask yourself:
+- **What implementations are being compared?** (e.g., Warp CUDA vs. CPU, vectorized vs. loop-based)
+- **What is the primary metric?** (wall-clock time, mean error, throughput)
+- **What sample sizes cover realistic usage?** Typically: `[100, 1000, 10000, 100000]`
+
+### 2. Structure the Script
+
+Use one helper function per concern, then a single orchestrator:
+
+```
+benchmark_() # e.g., benchmark_halton_sampler()
+benchmark_() # e.g., benchmark_density_metric()
+...
+run_all_benchmarks() # calls all of the above + prints header/footer
+```
+
+### 3. Write Individual Benchmark Functions
+
+Each benchmark function follows this pattern:
+
+```python
+def benchmark_():
+ """One-line description of what is being measured."""
+ from embodichain. import SomeClass, SomeCfg
+
+ # --- Setup (not timed) ---
+ cfg = SomeCfg(...)
+ obj = cfg.init_solver(...) # or SomeClass(cfg)
+
+ print("\n=== Benchmark ===")
+ for n in [100, 1000, 10000, 100000]:
+ # Prepare inputs (not timed)
+ inputs = ...
+
+ # --- Timed block ---
+ start = time.perf_counter()
+ result = obj.compute(inputs) # or obj.get_ik(...) etc.
+ elapsed = time.perf_counter() - start
+
+ print(f" n={n:>7d}: {elapsed*1000:>10.2f} ms (...)")
+```
+
+Key rules:
+- Use `time.perf_counter()` for high-resolution wall-clock timing, **not** `time.time()`.
+- Only time the core computation — exclude setup, data preparation, and print statements.
+- Print results in milliseconds (`elapsed * 1000`) with consistent column alignment using `>` format specs.
+
+> **Exception**: When benchmarking GPU (Warp/CUDA) code alongside a CPU baseline, it is acceptable to use `time.time()` for coarser comparison timing, as seen in `opw_solver.py`. Prefer `time.perf_counter()` for CPU-only benchmarks.
+
+### 4. Comparing Two Implementations
+
+When the benchmark compares two backends (e.g., Warp CUDA vs. Python OPW):
+
+```python
+def check_(solver_a, solver_b, n_samples=1000):
+ """Run both solvers and return timing + accuracy metrics."""
+ # shared input generation
+ qpos = ...
+
+ # --- Solver A (e.g., Warp CUDA) ---
+ start = time.time()
+ success_a, result_a = solver_a.get_ik(xpos, ...)
+ time_a = time.time() - start
+ t_err_a, r_err_a = get_poses_err(...)
+
+ # --- Solver B (e.g., CPU) ---
+ start = time.time()
+ success_b, result_b = solver_b.get_ik(xpos, ...)
+ time_b = time.time() - start
+ t_err_b, r_err_b = get_poses_err(...)
+
+ return time_a, t_err_a, r_err_a, time_b, t_err_b, r_err_b
+
+
+def benchmark_():
+ cfg = ...
+ solver_a = cfg.init_solver(device=torch.device("cuda"), ...)
+ solver_b = cfg.init_solver(device=torch.device("cpu"), ...)
+
+ for n in [100, 1000, 10000, 100000]:
+ time_a, t_err_a, r_err_a, time_b, t_err_b, r_err_b = check_(
+ solver_a, solver_b, n_samples=n
+ )
+ print(f"**** Test over {n} samples:")
+ print(f"===Impl A time: {time_a * 1000:.6f} ms")
+ print(f" Translation mean error: {t_err_a * 1000:.6f} mm")
+ print(f" Rotation mean error: {r_err_a * 180 / np.pi:.6f} degrees")
+ print(f"===Impl B time: {time_b * 1000:.6f} ms")
+ ...
+```
+
+### 5. Report Accuracy Alongside Speed
+
+For FK/IK solvers, always verify correctness by running FK on the IK output and measuring pose error:
+
+```python
+def get_pose_err(matrix_a: np.ndarray, matrix_b: np.ndarray) -> tuple[float, float]:
+ """Return (translation_error_m, rotation_error_rad)."""
+ t_err = np.linalg.norm(matrix_a[:3, 3] - matrix_b[:3, 3])
+ relative_rot = matrix_a[:3, :3].T @ matrix_b[:3, :3]
+ cos_angle = np.clip((np.trace(relative_rot) - 1) / 2.0, -1.0, 1.0)
+ r_err = np.arccos(cos_angle)
+ return t_err, r_err
+
+
+def get_poses_err(
+ matrix_a_list: list[np.ndarray], matrix_b_list: list[np.ndarray]
+) -> tuple[float, float]:
+ t_errs, r_errs = [], []
+ for a, b in zip(matrix_a_list, matrix_b_list):
+ t, r = get_pose_err(a, b)
+ t_errs.append(t)
+ r_errs.append(r)
+ return float(np.mean(t_errs)), float(np.mean(r_errs))
+```
+
+### 6. Handle Benchmarks That Require External Resources
+
+If a benchmark requires a live simulation, robot, or GPU device that may not be available, **skip gracefully** rather than raising an error:
+
+```python
+def benchmark_batch_fk():
+ """Benchmark batch FK (requires GPU robot setup)."""
+ print("\n=== Batch FK Benchmark (requires robot/simulation) ===")
+ print(" Skipped -- requires live SimulationManager and Robot.")
+ print(" To run manually, integrate with your robot setup:")
+ print(" analyzer.compute_workspace_points(joint_configs, batch_size=512)")
+```
+
+### 7. Write the Orchestrator
+
+```python
+def run_all_benchmarks():
+ """Run all benchmarks and print summary."""
+ print("=" * 60)
+ print(" Performance Benchmarks")
+ print("=" * 60)
+
+ benchmark_component_a()
+ benchmark_component_b()
+ # ...
+
+ print("\n" + "=" * 60)
+ print("Benchmarks complete.")
+ print("=" * 60)
+
+
+if __name__ == "__main__":
+ run_all_benchmarks()
+```
+
+### 8. Save Results to One Markdown Report (Required)
+
+Every benchmark script must write its final results to **one Markdown file** after execution.
+
+- Output directory recommendation: `outputs/benchmarks/`
+- File naming recommendation: `_.md`
+- Requirement: output **exactly three Markdown tables** in the report
+ 1. `Time & Memory` table (cost time + memory columns)
+ 2. `Success & Other Metrics` table (success rate + quality/accuracy/extra metrics)
+ 3. `Leaderboard` table (algorithm ranking by overall success rate, descending)
+- `Leaderboard` coverage rule: include **all algorithms evaluated in the current benchmark scope**. If a provided leaderboard artifact is incomplete, backfill missing algorithms from aggregate summaries before rendering.
+
+Use this pattern:
+
+```python
+from datetime import datetime
+from pathlib import Path
+
+
+def write_markdown_report(
+ benchmark_name: str,
+ perf_rows: list[dict[str, object]],
+ metric_rows: list[dict[str, object]],
+ leaderboard_rows: list[dict[str, object]],
+ notes: list[str] | None = None,
+) -> Path:
+ """Write benchmark results into a single markdown report file."""
+ output_dir = Path("outputs/benchmarks")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+ report_path = output_dir / f"{benchmark_name}_{ts}.md"
+
+ lines: list[str] = [
+ f"# {benchmark_name} Benchmark Report",
+ "",
+ f"Generated at: {datetime.now().isoformat(timespec='seconds')}",
+ "",
+ "## Time & Memory",
+ "",
+ ]
+
+ if perf_rows:
+ perf_headers = list(perf_rows[0].keys())
+ lines.append("| " + " | ".join(perf_headers) + " |")
+ lines.append("| " + " | ".join(["---"] * len(perf_headers)) + " |")
+ for row in perf_rows:
+ lines.append("| " + " | ".join(str(row[h]) for h in perf_headers) + " |")
+ else:
+ lines.append("No time/memory rows were produced.")
+
+ lines.extend(["", "## Success & Other Metrics", ""])
+
+ if metric_rows:
+ metric_headers = list(metric_rows[0].keys())
+ lines.append("| " + " | ".join(metric_headers) + " |")
+ lines.append("| " + " | ".join(["---"] * len(metric_headers)) + " |")
+ for row in metric_rows:
+ lines.append(
+ "| " + " | ".join(str(row[h]) for h in metric_headers) + " |"
+ )
+ else:
+ lines.append("No success/metric rows were produced.")
+
+ lines.extend(["", "## Leaderboard", ""])
+
+ if leaderboard_rows:
+ leaderboard_headers = list(leaderboard_rows[0].keys())
+ lines.append("| " + " | ".join(leaderboard_headers) + " |")
+ lines.append("| " + " | ".join(["---"] * len(leaderboard_headers)) + " |")
+ for row in leaderboard_rows:
+ lines.append(
+ "| " + " | ".join(str(row[h]) for h in leaderboard_headers) + " |"
+ )
+ else:
+ lines.append("No leaderboard rows were produced.")
+
+ if notes:
+ lines.extend(["", "## Notes", ""])
+ lines.extend([f"- {note}" for note in notes])
+
+ report_path.write_text("\\n".join(lines) + "\\n", encoding="utf-8")
+ return report_path
+```
+
+And call it at the end of `run_all_benchmarks()`:
+
+```python
+def run_all_benchmarks() -> None:
+ perf_rows: list[dict[str, object]] = []
+ metric_rows: list[dict[str, object]] = []
+
+ perf_part, metric_part = benchmark_halton_sampler()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+ perf_part, metric_part = benchmark_density_metric()
+ perf_rows.extend(perf_part)
+ metric_rows.extend(metric_part)
+ # ...
+
+ leaderboard_rows = build_leaderboard_rows(metric_rows)
+ # `build_leaderboard_rows` should aggregate per algorithm and sort by
+ # overall success rate in descending order.
+
+ report_path = write_markdown_report(
+ benchmark_name="workspace_analyzer",
+ perf_rows=perf_rows,
+ metric_rows=metric_rows,
+ leaderboard_rows=leaderboard_rows,
+ notes=["CPU/GPU memory fields are deltas measured around timed calls."],
+ )
+ print(f"Markdown report saved: {report_path}")
+```
+
+---
+
+## Output Format Reference
+
+| Scenario | Print format |
+|----------|-------------|
+| Single implementation, many sizes | `n={n:>7d}: {elapsed*1000:>10.2f} ms \| CPU Δ={...:+.1f} MB GPU Δ={...:+.1f} MB peak GPU={...:.1f} MB` |
+| Two implementations compared | `=== time: {ms:.6f} ms` then error & memory lines indented 3 spaces |
+| Markdown report path | `Markdown report saved: outputs/benchmarks/_.md` |
+| Markdown table 1 (Time & Memory) | `| sample_size | impl | cost_time_ms | cpu_delta_mb | gpu_delta_mb | peak_gpu_mb |` |
+| Markdown table 2 (Success & Metrics) | `| sample_size | impl | success_rate | translation_err_mm | rotation_err_deg | ... |` |
+| Markdown table 3 (Leaderboard) | `| rank | algorithm | overall_success_rate | ... |` (sorted by `overall_success_rate` descending) |
+| Section header | `\n=== Benchmark ===` |
+| Top-level separator | `"=" * 60` |
+
+---
+
+## Measuring Memory Usage
+
+Always measure **both GPU VRAM and CPU RAM** alongside wall-clock time. Use the helpers below.
+
+### GPU VRAM (via PyTorch CUDA)
+
+```python
+import torch
+
+def get_gpu_memory_mb() -> float:
+ """Return current GPU VRAM allocated by PyTorch in MB."""
+ if torch.cuda.is_available():
+ return torch.cuda.memory_allocated() / 1024 ** 2
+ return 0.0
+
+# Usage pattern inside a benchmark loop:
+torch.cuda.reset_peak_memory_stats() # reset peak counter before timed block
+mem_before = get_gpu_memory_mb()
+
+start = time.perf_counter()
+result = obj.compute(inputs)
+elapsed = time.perf_counter() - start
+
+mem_after = get_gpu_memory_mb()
+peak_vram = torch.cuda.max_memory_allocated() / 1024 ** 2 # peak during timed block
+
+print(
+ f" n={n:>7d}: {elapsed*1000:>10.2f} ms | "
+ f"VRAM delta={mem_after - mem_before:+.1f} MB peak={peak_vram:.1f} MB"
+)
+```
+
+### CPU RAM (via `psutil`)
+
+```python
+import psutil, os
+
+def get_cpu_memory_mb() -> float:
+ """Return current process RSS (resident set size) in MB."""
+ process = psutil.Process(os.getpid())
+ return process.memory_info().rss / 1024 ** 2
+
+# Usage pattern:
+mem_before = get_cpu_memory_mb()
+
+start = time.perf_counter()
+result = obj.compute(inputs)
+elapsed = time.perf_counter() - start
+
+mem_after = get_cpu_memory_mb()
+
+print(
+ f" n={n:>7d}: {elapsed*1000:>10.2f} ms | "
+ f"RAM delta={mem_after - mem_before:+.1f} MB"
+)
+```
+
+### Combined Helper (recommended)
+
+For benchmarks that use both CPU and GPU, combine into a single snapshot:
+
+```python
+import os, psutil, torch
+
+def memory_snapshot() -> dict:
+ """Return a dict with current CPU RSS and GPU allocated memory in MB."""
+ process = psutil.Process(os.getpid())
+ cpu_mb = process.memory_info().rss / 1024 ** 2
+ gpu_mb = torch.cuda.memory_allocated() / 1024 ** 2 if torch.cuda.is_available() else 0.0
+ return {"cpu_mb": cpu_mb, "gpu_mb": gpu_mb}
+
+# Usage:
+torch.cuda.reset_peak_memory_stats()
+before = memory_snapshot()
+
+start = time.perf_counter()
+result = obj.compute(inputs)
+elapsed = time.perf_counter() - start
+
+after = memory_snapshot()
+peak_gpu = torch.cuda.max_memory_allocated() / 1024 ** 2
+
+print(
+ f" n={n:>7d}: {elapsed*1000:>10.2f} ms | "
+ f"CPU Δ={after['cpu_mb'] - before['cpu_mb']:+.1f} MB "
+ f"GPU Δ={after['gpu_mb'] - before['gpu_mb']:+.1f} MB peak GPU={peak_gpu:.1f} MB"
+)
+```
+
+> Add `psutil` to the project's dev-dependencies if not already present (`pip install psutil`).
+
+---
+
+## Common Imports
+
+```python
+import os
+import time
+import psutil
+import numpy as np
+import torch
+import warp as wp # only when GPU kernels are benchmarked
+from scipy.spatial.transform import Rotation # only when needed
+from typing import Tuple, List # or use built-in generics (Python ≥ 3.10)
+```
+
+---
+
+## Quick Checklist
+
+Before finishing a benchmark script:
+
+- [ ] Apache 2.0 copyright header is present
+- [ ] Module-level docstring with `Run:` line
+- [ ] Each function has a one-line docstring
+- [ ] Setup code is **outside** the timed block
+- [ ] Timing uses `time.perf_counter()` (or `time.time()` when comparing GPU/CPU coarsely)
+- [ ] CPU RAM measured with `psutil` (delta MB before/after timed block)
+- [ ] GPU VRAM measured with `torch.cuda.memory_allocated()` + `torch.cuda.max_memory_allocated()` (delta + peak)
+- [ ] `torch.cuda.reset_peak_memory_stats()` called before each timed block
+- [ ] Accuracy metrics reported alongside timing (for solver benchmarks)
+- [ ] Graceful skip for benchmarks that need unavailable hardware
+- [ ] `run_all_benchmarks()` orchestrator with formatted separators
+- [ ] Results are written to exactly one Markdown report file per run
+- [ ] Report contains exactly three Markdown tables: `Time & Memory`, `Success & Other Metrics`, and `Leaderboard`
+- [ ] `Time & Memory` table includes `cost_time_ms`, `cpu_delta_mb`, `gpu_delta_mb`, `peak_gpu_mb`
+- [ ] `Success & Other Metrics` table includes `success_rate` and domain-specific quality metrics
+- [ ] `Leaderboard` table ranks algorithms by overall success rate in descending order
+- [ ] `Leaderboard` table includes all benchmarked algorithms (missing entries are backfilled from aggregate summaries if needed)
+- [ ] Console log includes final report path
+- [ ] `if __name__ == "__main__":` entry point
+- [ ] `black .` formatting applied
diff --git a/.claude/skills/pr/SKILL.md b/skills/pr/SKILL.md
similarity index 83%
rename from .claude/skills/pr/SKILL.md
rename to skills/pr/SKILL.md
index 59c3d3b6..e31b1628 100644
--- a/.claude/skills/pr/SKILL.md
+++ b/skills/pr/SKILL.md
@@ -1,6 +1,6 @@
---
name: pr
-description: Create a pull request for EmbodiChain following the project's PR template and conventions
+description: Create a pull request for EmbodiChain following the project's PR template and conventions, including selecting proper GitHub repository labels
---
# EmbodiChain Pull Request Creator
@@ -99,6 +99,36 @@ Use the gh CLI with the proper PR template:
gh pr create --title "" --body ""
```
+### 9. Select and Apply Labels
+
+After creating the PR, select proper labels from the repository label list and apply them.
+
+First, list available labels:
+
+```bash
+gh label list
+```
+
+Then choose labels based on change type and scope. Typical mapping:
+
+- Bug fix: `bug`
+- Enhancement: `enhancement`
+- New feature: `feature`
+- Documentation update: `docs`
+- Affected area labels when available (for example): `physics`, `robot`, `agent`, `dataset`, `dexsim`
+
+Apply labels to the PR:
+
+```bash
+gh pr edit --add-label "bug" --add-label "env"
+```
+
+If needed, remove incorrect labels:
+
+```bash
+gh pr edit --remove-label ""
+```
+
## PR Template
Use this template for the PR body:
@@ -161,6 +191,8 @@ Fixes #
| `git checkout -b branch-name` | Create branch |
| `git push -u origin branch` | Push to remote |
| `gh pr create` | Create PR |
+| `gh label list` | List repository labels |
+| `gh pr edit --add-label ...` | Apply labels to PR |
## Notes
diff --git a/skills/pre-commit-check/SKILL.md b/skills/pre-commit-check/SKILL.md
new file mode 100644
index 00000000..41ec4d3d
--- /dev/null
+++ b/skills/pre-commit-check/SKILL.md
@@ -0,0 +1,158 @@
+---
+name: pre-commit-check
+description: Use before committing or creating a PR for EmbodiChain to verify code style, headers, annotations, exports, and docstrings pass CI checks
+---
+
+# Pre-Commit Check
+
+Run all local checks that the CI pipeline enforces, catching issues before pushing.
+
+## When to Use
+
+- Before creating a commit or PR
+- User says "check my changes", "pre-commit", "verify before commit", "ready to push"
+- After making any code changes to `.py` files
+
+## Steps
+
+### 1. Identify Changed Files
+
+```bash
+git diff --name-only HEAD
+git diff --name-only --cached
+git status --short
+```
+
+Collect all changed/added `.py` files.
+
+### 2. Run Black Formatting Check
+
+This is the **first CI gate** and will cause immediate failure:
+
+```bash
+black --check --diff --color ./
+```
+
+If it fails, run `black .` and review the formatting changes.
+
+### 3. Check Apache 2.0 Copyright Header
+
+Every `.py` file must begin with the 15-line copyright block. For each changed/new `.py` file, verify the first line is:
+
+```
+# ----------------------------------------------------------------------------
+```
+
+The full header template:
+
+```python
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+```
+
+### 4. Check `from __future__ import annotations`
+
+Every `.py` file must have this import (after the header, before other imports). This enables `A | B` syntax and forward references.
+
+### 5. Check `__all__` in Public Modules
+
+For any new or modified module under `embodichain/`, verify it defines `__all__` listing all public symbols. Example:
+
+```python
+__all__ = ["MyClass", "my_function"]
+```
+
+Skip this check for `__init__.py` files that only re-export via `from . import *`.
+
+### 6. Check Docstrings on Public APIs
+
+For any new public function, class, or method:
+- Must have a Google-style docstring
+- Must include `Args:` section if it takes parameters
+- Must include `Returns:` section if it returns a value
+- Use `.. attention::` or `.. tip::` directives for non-obvious behavior
+
+### 7. Check Type Annotations
+
+For any new public API:
+- All parameters must have type hints
+- Return type must be annotated
+- Use `A | B` over `Union[A, B]`
+- Use `TYPE_CHECKING` guard for imports that would cause circular dependencies
+
+### 8. Check `@configclass` Usage
+
+For any new configuration class:
+- Must use `@configclass` decorator (not bare `@dataclass`)
+- Must use `from dataclasses import MISSING` for required fields
+- Import from `embodichain.utils import configclass`
+
+### 9. Check Test Coverage
+
+For any new public module or function:
+- A corresponding test must exist at `tests//test_.py`
+- Test file must also have the Apache 2.0 header
+- Report if tests are missing
+
+### 10. Summary Report
+
+Output a pass/fail summary:
+
+```
+Pre-Commit Check Results
+========================
+[PASS] Black formatting
+[PASS] Apache 2.0 headers (5/5 files)
+[FAIL] from __future__ import annotations — missing in: foo.py
+[PASS] __all__ exports
+[PASS] Docstrings on public APIs
+[PASS] Type annotations
+[PASS] @configclass usage
+[WARN] Missing tests for: bar.py
+
+Fix the above issues before committing.
+```
+
+## What CI Checks
+
+The project's CI pipeline (`.github/workflows/main.yml`) runs:
+
+1. **lint** job: `black --check --diff --color ./`
+2. **test** job: `pytest tests`
+3. **build** job: Sphinx docs build
+
+This skill covers items 1 and 2 locally. Docs build is heavier and typically only needed for documentation changes.
+
+## Common Mistakes
+
+| Mistake | Fix |
+|---------|-----|
+| Running `black` on only one file | Run `black .` on the whole project — CI checks everything |
+| Forgetting test Apache header | Test files also need the 15-line copyright block |
+| Using `Union[A, B]` | Use `A \| B` (with `from __future__ import annotations`) |
+| Using bare `@dataclass` | Use `@configclass` from `embodichain.utils` |
+| Missing `__all__` in new module | Add `__all__` with all public symbols |
+
+## Quick Reference
+
+| Check | Command/Method |
+|-------|---------------|
+| Black formatting | `black --check --diff --color ./` |
+| Auto-fix formatting | `black .` |
+| Header check | Verify first line is `# ---...---` |
+| `__future__` import | Grep for `from __future__ import annotations` |
+| `__all__` export | Grep for `__all__` in module |
+| Run tests | `pytest tests/` |
diff --git a/tests/agents/test_language_support.py b/tests/agents/test_language_support.py
new file mode 100644
index 00000000..112adb34
--- /dev/null
+++ b/tests/agents/test_language_support.py
@@ -0,0 +1,325 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Tests for language support in ODS and VLA training."""
+
+import pytest
+import torch
+import tempfile
+from pathlib import Path
+
+from embodichain.lab.gym.envs.managers import (
+ LanguageCfg,
+ LanguageManager,
+ LanguageData,
+ HierarchicalLanguageData,
+ FileBasedLanguageProvider,
+ TemplateBasedLanguageProvider,
+)
+from embodichain.lab.gym.utils.gym_utils import _init_language_buffer
+
+
+class MockEnv:
+ """Mock environment for testing."""
+
+ task_name = "test_task"
+ task_description = "Complete the test task."
+
+
+class TestLanguageData:
+ """Tests for LanguageData and HierarchicalLanguageData."""
+
+ def test_language_data_creation(self):
+ """Test creating LanguageData objects."""
+ tokens = torch.tensor([1, 2, 3, 0, 0], dtype=torch.int64)
+ mask = torch.tensor([1, 1, 1, 0, 0], dtype=torch.int64)
+
+ data = LanguageData(
+ tokens=tokens,
+ attention_mask=mask,
+ raw_text="Test instruction",
+ instruction_type="imperative",
+ )
+
+ assert data.tokens.shape == (5,)
+ assert data.attention_mask.shape == (5,)
+ assert data.raw_text == "Test instruction"
+ assert data.instruction_type == "imperative"
+
+ def test_hierarchical_language_data_creation(self):
+ """Test creating HierarchicalLanguageData."""
+ task_tokens = torch.tensor([1, 2, 3, 0], dtype=torch.int64)
+ task_mask = torch.tensor([1, 1, 1, 0], dtype=torch.int64)
+
+ task_data = LanguageData(
+ tokens=task_tokens,
+ attention_mask=task_mask,
+ raw_text="Task description",
+ )
+
+ subtask_tokens = torch.tensor([4, 5, 0, 0], dtype=torch.int64)
+ subtask_mask = torch.tensor([1, 1, 0, 0], dtype=torch.int64)
+
+ subtask_data = LanguageData(
+ tokens=subtask_tokens,
+ attention_mask=subtask_mask,
+ raw_text="Subtask description",
+ )
+
+ hierarchical = HierarchicalLanguageData(
+ task_level=[task_data],
+ subtask_level=[subtask_data],
+ primitive_level=[],
+ )
+
+ assert len(hierarchical.task_level) == 1
+ assert len(hierarchical.subtask_level) == 1
+ assert len(hierarchical.primitive_level) == 0
+
+ def test_hierarchical_language_data_flatten(self):
+ """Test flattening hierarchical language data."""
+ task_data = LanguageData(
+ tokens=torch.tensor([1, 2, 0], dtype=torch.int64),
+ attention_mask=torch.tensor([1, 1, 0], dtype=torch.int64),
+ raw_text="Task",
+ )
+
+ hierarchical = HierarchicalLanguageData(
+ task_level=[task_data],
+ subtask_level=[],
+ primitive_level=[],
+ )
+
+ flattened = hierarchical.flatten()
+ assert "task" in flattened
+ assert "subtask" in flattened
+ assert "primitive" in flattened
+
+
+class TestLanguageBuffer:
+ """Tests for language buffer initialization."""
+
+ def test_init_language_buffer(self):
+ """Test initializing language buffer tensors."""
+ language_cfg = {
+ "hierarchy_levels": ["task", "subtask"],
+ "max_tokens": 256,
+ "max_instructions_per_level": 3,
+ "pad_token_id": 0,
+ "mode": "tokens",
+ }
+
+ buffer = _init_language_buffer(
+ language_cfg, batch_size=4, max_episode_steps=100, device="cpu"
+ )
+
+ # Check that expected keys are present
+ assert "task_level_tokens" in buffer
+ assert "task_level_attention_mask" in buffer
+ assert "subtask_level_tokens" in buffer
+ assert "subtask_level_attention_mask" in buffer
+
+ # Check tensor shapes
+ assert buffer["task_level_tokens"].shape == (4, 100, 3, 256)
+ assert buffer["task_level_attention_mask"].shape == (4, 100, 3, 256)
+ assert buffer["task_level_count"].shape == (4, 100)
+
+ # Check global fields
+ assert "instruction_counts" in buffer
+ assert buffer["instruction_counts"].shape == (4, 100, 3)
+ assert "change_points" in buffer
+ assert buffer["change_points"].shape == (4, 100, 3)
+ assert "hierarchy_depth" in buffer
+ assert buffer["hierarchy_depth"].shape == (4, 100)
+
+
+class TestLanguageManager:
+ """Tests for LanguageManager."""
+
+ def test_language_manager_initialization(self):
+ """Test initializing LanguageManager."""
+ cfg = LanguageCfg(
+ mode="tokens",
+ hierarchy_levels=["task", "subtask"],
+ max_tokens=256,
+ tokenizer="gpt2",
+ )
+
+ env = MockEnv()
+
+ # Test with a simple tokenizer that doesn't require external dependencies
+ try:
+ manager = LanguageManager(cfg, env)
+ assert manager.cfg == cfg
+ assert manager.env == env
+ except (ImportError, RuntimeError) as e:
+ pytest.skip(f"Tokenizer not available: {e}")
+
+ def test_create_language_data(self):
+ """Test creating LanguageData from raw text."""
+ cfg = LanguageCfg(
+ mode="tokens",
+ max_tokens=256,
+ tokenizer="gpt2",
+ )
+
+ env = MockEnv()
+
+ try:
+ manager = LanguageManager(cfg, env)
+ data = manager.create_language_data("Test instruction")
+ assert isinstance(data, LanguageData)
+ assert data.raw_text == "Test instruction"
+ except (ImportError, RuntimeError) as e:
+ pytest.skip(f"Tokenizer not available: {e}")
+
+ def test_create_hierarchical_language_data(self):
+ """Test creating hierarchical language data."""
+ cfg = LanguageCfg(
+ mode="tokens",
+ max_tokens=256,
+ tokenizer="gpt2",
+ )
+
+ env = MockEnv()
+
+ try:
+ manager = LanguageManager(cfg, env)
+ data = manager.create_hierarchical_language_data(
+ task_texts="Pick up the block.",
+ subtask_texts=["Move to block.", "Grasp block."],
+ primitive_texts=["Close gripper."],
+ )
+
+ assert isinstance(data, HierarchicalLanguageData)
+ assert len(data.task_level) == 1
+ assert len(data.subtask_level) == 2
+ assert len(data.primitive_level) == 1
+ except (ImportError, RuntimeError) as e:
+ pytest.skip(f"Tokenizer not available: {e}")
+
+ def test_to_buffer_format(self):
+ """Test converting hierarchical data to buffer format."""
+ cfg = LanguageCfg(
+ mode="tokens",
+ hierarchy_levels=["task", "subtask"],
+ max_tokens=256,
+ max_instructions_per_level=3,
+ tokenizer="gpt2",
+ )
+
+ env = MockEnv()
+
+ try:
+ manager = LanguageManager(cfg, env)
+ data = manager.create_hierarchical_language_data(
+ task_texts="Task description.",
+ subtask_texts=["Step 1.", "Step 2."],
+ )
+
+ buffer_format = data.to_buffer_format(cfg)
+
+ assert "task_level_tokens" in buffer_format
+ assert "subtask_level_tokens" in buffer_format
+ assert "instruction_counts" in buffer_format
+
+ # Check shapes
+ assert buffer_format["task_level_tokens"].shape == (3, 256)
+ assert buffer_format["subtask_level_tokens"].shape == (3, 256)
+ except (ImportError, RuntimeError) as e:
+ pytest.skip(f"Tokenizer not available: {e}")
+
+
+class TestFileBasedLanguageProvider:
+ """Tests for FileBasedLanguageProvider."""
+
+ def test_file_provider_initialization(self):
+ """Test initializing file-based provider."""
+ cfg = LanguageCfg(
+ mode="tokens",
+ max_tokens=256,
+ tokenizer="gpt2",
+ )
+
+ # Create a temporary YAML file
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+ f.write("""
+test_task:
+ task:
+ - "Test task description."
+ subtask:
+ - "Step 1."
+ - "Step 2."
+""")
+ temp_path = f.name
+
+ try:
+ provider = FileBasedLanguageProvider(cfg, temp_path)
+ assert provider.config_path == Path(temp_path)
+ assert "test_task" in provider.get_available_tasks()
+ finally:
+ Path(temp_path).unlink()
+
+
+class TestTemplateBasedLanguageProvider:
+ """Tests for TemplateBasedLanguageProvider."""
+
+ def test_template_provider_initialization(self):
+ """Test initializing template-based provider."""
+ cfg = LanguageCfg(
+ mode="tokens",
+ max_tokens=256,
+ tokenizer="gpt2",
+ )
+
+ templates = {
+ "test_task": {
+ "task": "Complete the {object} task.",
+ "subtasks": ["Move to {object}.", "Grasp {object}."],
+ }
+ }
+
+ provider = TemplateBasedLanguageProvider(cfg, templates)
+ assert "test_task" in provider.get_available_tasks()
+
+ def test_template_provider_get_language(self):
+ """Test getting language from templates."""
+ cfg = LanguageCfg(
+ mode="tokens",
+ max_tokens=256,
+ tokenizer="gpt2",
+ )
+
+ templates = {
+ "test_task": {
+ "task": "Pick up the {color} {object}.",
+ "subtasks": [
+ "Move to {color} {object}.",
+ "Grasp {color} {object}.",
+ ],
+ }
+ }
+
+ provider = TemplateBasedLanguageProvider(cfg, templates)
+
+ context = {"color": "red", "object": "block"}
+ language_data = provider.get_language("test_task", context)
+
+ assert isinstance(language_data, HierarchicalLanguageData)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py
index 37dd34fa..4701540f 100644
--- a/tests/agents/test_shared_rollout.py
+++ b/tests/agents/test_shared_rollout.py
@@ -21,6 +21,7 @@
import torch
from tensordict import TensorDict
+from embodichain.lab.sim.cfg import RenderCfg
from embodichain.agents.rl.buffer import RolloutBuffer
from embodichain.agents.rl.collector import SyncCollector
from embodichain.agents.rl.utils import flatten_dict_observation
@@ -186,7 +187,7 @@ def test_embodied_env_writes_next_fields_into_external_rollout():
env_cfg.sim_cfg = SimulationManagerCfg(
headless=True,
sim_device=torch.device("cpu"),
- enable_rt=False,
+ render_cfg=RenderCfg(renderer="hybrid"),
gpu_id=0,
)
diff --git a/tests/benchmark/test_leaderboard.py b/tests/benchmark/test_leaderboard.py
new file mode 100644
index 00000000..4412d746
--- /dev/null
+++ b/tests/benchmark/test_leaderboard.py
@@ -0,0 +1,72 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from benchmark.rl.metrics import build_leaderboard
+
+
+def test_build_leaderboard_ranks_higher_success_first():
+ aggregate_results = [
+ {
+ "algorithm": "ppo",
+ "task": "cart_pole",
+ "final_success_rate_mean": 0.8,
+ "final_success_rate_stable_mean": 0.7,
+ "final_reward_mean": 10.0,
+ "steps_to_success_threshold_mean": 100.0,
+ },
+ {
+ "algorithm": "ppo",
+ "task": "push_cube",
+ "final_success_rate_mean": 0.6,
+ "final_success_rate_stable_mean": 0.5,
+ "final_reward_mean": 5.0,
+ "steps_to_success_threshold_mean": 200.0,
+ },
+ {
+ "algorithm": "grpo",
+ "task": "cart_pole",
+ "final_success_rate_mean": 0.7,
+ "final_success_rate_stable_mean": 0.8,
+ "final_reward_mean": 8.0,
+ "steps_to_success_threshold_mean": 150.0,
+ },
+ {
+ "algorithm": "grpo",
+ "task": "push_cube",
+ "final_success_rate_mean": 0.5,
+ "final_success_rate_stable_mean": 0.7,
+ "final_reward_mean": 4.0,
+ "steps_to_success_threshold_mean": 250.0,
+ },
+ ]
+ run_results = [
+ {"algorithm": "ppo", "final_success_rate": 0.8},
+ {"algorithm": "ppo", "final_success_rate": 0.6},
+ {"algorithm": "grpo", "final_success_rate": 0.7},
+ {"algorithm": "grpo", "final_success_rate": 0.5},
+ ]
+
+ leaderboard = build_leaderboard(aggregate_results, run_results=run_results)
+
+ assert leaderboard[0]["algorithm"] == "grpo"
+ assert leaderboard[0]["rank"] == 1
+ assert "avg_success_rate_stable" in leaderboard[0]
+ assert "steps_to_success_threshold" in leaderboard[0]
+ assert "success_rate_std" in leaderboard[0]
+ assert "tasks" in leaderboard[0]
+ assert leaderboard[1]["algorithm"] == "ppo"
diff --git a/tests/benchmark/test_metrics.py b/tests/benchmark/test_metrics.py
new file mode 100644
index 00000000..2d4d163b
--- /dev/null
+++ b/tests/benchmark/test_metrics.py
@@ -0,0 +1,108 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from benchmark.rl.metrics import (
+ aggregate_runs,
+ compute_final_metric_stable,
+ compute_steps_to_threshold_first_hit,
+ compute_steps_to_threshold_sustained,
+)
+
+
+def test_compute_steps_to_threshold_first_hit_returns_first_matching_step():
+ eval_history = [
+ {"global_step": 128.0, "eval/success_rate": 0.2},
+ {"global_step": 256.0, "eval/success_rate": 0.75},
+ {"global_step": 384.0, "eval/success_rate": 0.81},
+ ]
+
+ assert (
+ compute_steps_to_threshold_first_hit(eval_history, "eval/success_rate", 0.8)
+ == 384
+ )
+
+
+def test_compute_steps_to_threshold_sustained_requires_consecutive_hits():
+ eval_history = [
+ {"global_step": 100.0, "eval/success_rate": 0.81},
+ {"global_step": 200.0, "eval/success_rate": 0.70},
+ {"global_step": 300.0, "eval/success_rate": 0.82},
+ {"global_step": 400.0, "eval/success_rate": 0.84},
+ {"global_step": 500.0, "eval/success_rate": 0.83},
+ ]
+
+ assert (
+ compute_steps_to_threshold_sustained(
+ eval_history, "eval/success_rate", 0.8, sustain_count=3
+ )
+ == 300
+ )
+
+
+def test_compute_final_metric_stable_uses_last_window():
+ eval_history = [
+ {"global_step": 100.0, "eval/success_rate": 0.2},
+ {"global_step": 200.0, "eval/success_rate": 0.4},
+ {"global_step": 300.0, "eval/success_rate": 0.6},
+ {"global_step": 400.0, "eval/success_rate": 0.8},
+ ]
+
+ assert compute_final_metric_stable(eval_history, "eval/success_rate", 2) == 0.7
+
+
+def test_aggregate_runs_groups_by_task_and_algorithm():
+ run_results = [
+ {
+ "task": "cart_pole",
+ "algorithm": "ppo",
+ "seed": 0,
+ "final_reward": 1.0,
+ "final_success_rate": 0.9,
+ "final_success_rate_stable": 0.85,
+ "final_episode_length": 50.0,
+ "training_fps": 100.0,
+ "environment_fps": 500.0,
+ "peak_gpu_memory_mb": 0.0,
+ "steps_to_success_threshold": 1000,
+ "steps_to_success_threshold_first_hit": 800,
+ },
+ {
+ "task": "cart_pole",
+ "algorithm": "ppo",
+ "seed": 1,
+ "final_reward": 3.0,
+ "final_success_rate": 0.7,
+ "final_success_rate_stable": 0.75,
+ "final_episode_length": 40.0,
+ "training_fps": 200.0,
+ "environment_fps": 700.0,
+ "peak_gpu_memory_mb": 0.0,
+ "steps_to_success_threshold": 2000,
+ "steps_to_success_threshold_first_hit": 1200,
+ },
+ ]
+
+ summaries = aggregate_runs(run_results)
+
+ assert len(summaries) == 1
+ assert summaries[0]["task"] == "cart_pole"
+ assert summaries[0]["algorithm"] == "ppo"
+ assert summaries[0]["final_reward_mean"] == 2.0
+ assert summaries[0]["final_success_rate_stable_mean"] == 0.8
+ assert summaries[0]["steps_to_success_threshold_mean"] == 1500
+ assert summaries[0]["steps_to_success_threshold_first_hit_mean"] == 1000
diff --git a/tests/benchmark/test_plots.py b/tests/benchmark/test_plots.py
new file mode 100644
index 00000000..484da225
--- /dev/null
+++ b/tests/benchmark/test_plots.py
@@ -0,0 +1,67 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from benchmark.rl.plots import build_plot_artifacts
+
+
+def test_build_plot_artifacts_writes_svg_files(tmp_path):
+ run_results = [
+ {
+ "task": "cart_pole",
+ "algorithm": "ppo",
+ "eval_history": [
+ {
+ "global_step": 100.0,
+ "eval/success_rate": 0.2,
+ "eval/avg_reward": 1.0,
+ },
+ {
+ "global_step": 200.0,
+ "eval/success_rate": 0.8,
+ "eval/avg_reward": 2.0,
+ },
+ ],
+ },
+ {
+ "task": "cart_pole",
+ "algorithm": "grpo",
+ "eval_history": [
+ {
+ "global_step": 100.0,
+ "eval/success_rate": 0.1,
+ "eval/avg_reward": 0.5,
+ },
+ {
+ "global_step": 200.0,
+ "eval/success_rate": 0.6,
+ "eval/avg_reward": 1.5,
+ },
+ ],
+ },
+ ]
+ leaderboard = [
+ {"algorithm": "ppo", "score": 0.8},
+ {"algorithm": "grpo", "score": 0.6},
+ ]
+
+ artifacts = build_plot_artifacts(run_results, leaderboard, tmp_path)
+
+ assert "cart_pole_success_rate" in artifacts
+ assert "leaderboard_score" in artifacts
+ for path in artifacts.values():
+ assert path.endswith(".svg")
diff --git a/tests/benchmark/test_reporting.py b/tests/benchmark/test_reporting.py
new file mode 100644
index 00000000..55784b11
--- /dev/null
+++ b/tests/benchmark/test_reporting.py
@@ -0,0 +1,97 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from benchmark.rl.reporting import generate_markdown_report
+
+
+def test_generate_markdown_report_writes_expected_sections(tmp_path):
+ run_results = [
+ {
+ "task": "cart_pole",
+ "algorithm": "ppo",
+ "seed": 0,
+ "final_reward": 1.5,
+ "final_success_rate": 0.8,
+ "final_success_rate_stable": 0.75,
+ "steps_to_success_threshold": 256,
+ "steps_to_success_threshold_first_hit": 128,
+ "checkpoint_path": "outputs/checkpoint.pt",
+ }
+ ]
+ aggregate_results = [
+ {
+ "task": "cart_pole",
+ "algorithm": "ppo",
+ "num_runs": 1,
+ "final_reward_mean": 1.5,
+ "final_success_rate_mean": 0.8,
+ "final_success_rate_stable_mean": 0.75,
+ "final_success_rate_std": 0.1,
+ "training_fps_mean": 100.0,
+ "environment_fps_mean": 500.0,
+ "peak_gpu_memory_mb_mean": 0.0,
+ "steps_to_success_threshold_mean": 256.0,
+ "steps_to_success_threshold_first_hit_mean": 128.0,
+ },
+ {
+ "task": "cart_pole",
+ "algorithm": "grpo",
+ "num_runs": 1,
+ "final_reward_mean": 1.7,
+ "final_success_rate_mean": 0.85,
+ "final_success_rate_stable_mean": 0.8,
+ "final_success_rate_std": 0.05,
+ "training_fps_mean": 90.0,
+ "environment_fps_mean": 480.0,
+ "peak_gpu_memory_mb_mean": 0.0,
+ "steps_to_success_threshold_mean": 200.0,
+ "steps_to_success_threshold_first_hit_mean": 160.0,
+ },
+ ]
+ leaderboard = [
+ {
+ "rank": 1,
+ "algorithm": "ppo",
+ "score": 0.8,
+ "steps_to_success_threshold": 256.0,
+ "success_rate_std": 0.1,
+ "avg_success_rate": 0.8,
+ "avg_success_rate_stable": 0.75,
+ "avg_final_reward": 1.5,
+ "tasks_covered": 1,
+ }
+ ]
+ plot_artifacts = {"cart_pole_success_rate": str(tmp_path / "plot.svg")}
+ (tmp_path / "plot.svg").write_text("", encoding="utf-8")
+
+ output_path = tmp_path / "benchmark_report.md"
+ generate_markdown_report(
+ run_results,
+ aggregate_results,
+ leaderboard,
+ plot_artifacts,
+ {"device": "cpu", "iterations": 10},
+ output_path,
+ )
+ report = output_path.read_text(encoding="utf-8")
+ assert "RL Benchmark Report" in report
+ assert "Benchmark Overview" in report
+ assert "Leaderboard" in report
+ assert "Plots" in report
+ assert "cart_pole" in report
+ assert "grpo" in report
diff --git a/tests/common.py b/tests/common.py
index 962d9f2d..bbbdc8fc 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -17,7 +17,6 @@
from unittest import TestLoader
from fnmatch import fnmatchcase
-
__all__ = ["UnittestMetaclass", "OrderedTestLoader"]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..d0824fd0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,86 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+import os
+import pytest
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--renderer",
+ action="store",
+ default="hybrid",
+ help="Specify the renderer backend: hybrid, or fast-rt",
+ )
+
+
+def pytest_configure(config):
+ renderer = config.getoption("--renderer")
+ if renderer:
+ if renderer not in ["hybrid", "fast-rt"]:
+ pytest.exit(
+ f"Invalid renderer: {renderer}. Must be one of 'hybrid', 'fast-rt'"
+ )
+
+ # Override the global default renderer in the simulation config
+ from embodichain.lab.sim import cfg
+
+ cfg.DEFAULT_RENDERER = renderer
+
+ # PREVENT IMPLICIT INITIALIZATION BY EXPLICITLY INITIALIZING DEXSIM HERE
+ import dexsim
+ import dexsim.types
+
+ # Map string to dexsim configuration types
+ renderer_map = {
+ "hybrid": dexsim.types.Renderer.HYBRID,
+ "fast-rt": dexsim.types.Renderer.FASTRT,
+ }
+ backend_map = {
+ "hybrid": dexsim.types.Backend.VULKAN,
+ "fast-rt": dexsim.types.Backend.VULKAN,
+ }
+
+ if dexsim.get_world_num() == 0:
+ sim_config = dexsim.WorldConfig()
+ sim_config.renderer = renderer_map.get(
+ renderer, dexsim.types.Renderer.HYBRID
+ )
+ sim_config.backend = backend_map.get(renderer, dexsim.types.Backend.VULKAN)
+ sim_config.open_windows = False
+ # This triggers initialization with the correct properties immediately.
+ dexsim.init_sim_engine(sim_config)
+
+
+@pytest.fixture(autouse=True, scope="function")
+def wait_scene_destruction_after_test():
+ """Ensure C++ engine scenes are fully destructed globally after each test exits."""
+ yield
+
+ # [Improvement - delayed destruction]: top-level dequeue and traceback cleanup.
+ # Pytest retains Tracebacks on failure; breaking the exception stack ensures
+ # that local variables of temporary objects on the stack can be garbage collected.
+ import sys
+ import gc
+
+ sys.last_traceback = None
+ sys.last_value = None
+ sys.last_type = None
+
+ # [Core fix]: drain the cleanup queue to consume SimManager and related objects
+ from embodichain.lab.sim.sim_manager import SimulationManager
+
+ SimulationManager.flush_cleanup_queue()
diff --git a/tests/gym/envs/managers/test_dataset_functors.py b/tests/gym/envs/managers/test_dataset_functors.py
index d18010fc..1acd54b6 100644
--- a/tests/gym/envs/managers/test_dataset_functors.py
+++ b/tests/gym/envs/managers/test_dataset_functors.py
@@ -22,7 +22,6 @@
from unittest.mock import MagicMock, Mock, patch
-
# Skip all tests if LeRobot is not available
try:
from embodichain.lab.gym.envs.managers.datasets import (
diff --git a/tests/gym/envs/test_base_env.py b/tests/gym/envs/test_base_env.py
index fbf3c0de..27767bef 100644
--- a/tests/gym/envs/test_base_env.py
+++ b/tests/gym/envs/test_base_env.py
@@ -116,15 +116,18 @@ def _extend_obs(self, obs, **kwargs):
class BaseEnvTest:
"""Shared test logic for CPU and CUDA."""
- def setup_simulation(self, sim_device):
- self.env = gym.make(
+ @classmethod
+ def setup_simulation_hook(cls, sim_device):
+ if hasattr(cls, "env"):
+ return
+ cls.env = gym.make(
"RandomReach-v1",
num_envs=NUM_ENVS,
headless=True,
device=sim_device,
)
- self.device = self.env.get_wrapper_attr("device")
- self.num_envs = self.env.get_wrapper_attr("num_envs")
+ cls.device = cls.env.get_wrapper_attr("device")
+ cls.num_envs = cls.env.get_wrapper_attr("num_envs")
def test_env_rollout(self):
"""Test environment rollout."""
@@ -168,19 +171,39 @@ def test_env_rollout(self):
assert obs.get("robot") is not None, "Expected 'robot' in the obs dict"
def teardown_method(self):
+ pass
+
+ @classmethod
+ def teardown_class(cls):
"""Clean up resources after each test method."""
- self.env.close()
+ if hasattr(cls, "env") and cls.env is not None:
+ cls.env.close()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ import gc
+
+ gc.collect()
+# @pytest.mark.skip(reason="Skipping tests temporarily")
class TestBaseEnvCPU(BaseEnvTest):
def setup_method(self):
- self.setup_simulation("cpu")
+ pass
+ @classmethod
+ def setup_class(cls):
+ cls.setup_simulation("cpu")
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
+
+# @pytest.mark.skip(reason="Skipping tests temporarily")
class TestBaseEnvCUDA(BaseEnvTest):
def setup_method(self):
- self.setup_simulation("cuda")
+ pass
+
+ @classmethod
+ def setup_class(cls):
+ cls.setup_simulation("cuda")
if __name__ == "__main__":
@@ -189,3 +212,21 @@ def setup_method(self):
test_cpu.setup_method()
test_cpu.test_env_rollout()
test_cpu.teardown_method()
+
+# Patch BaseEnvTest
+import sys
+
+
+def new_setup_simulation(cls, sim_device):
+ print(">>> ENTERING setup_simulation", file=sys.stderr)
+ if hasattr(cls, "env"):
+ return
+ cls.env = gym.make(
+ "RandomReach-v1", num_envs=NUM_ENVS, headless=True, device=sim_device
+ )
+ cls.device = cls.env.get_wrapper_attr("device")
+ cls.num_envs = cls.env.get_wrapper_attr("num_envs")
+ print(">>> EXITING setup_simulation", file=sys.stderr)
+
+
+BaseEnvTest.setup_simulation = classmethod(new_setup_simulation)
diff --git a/tests/gym/envs/test_embodied_env.py b/tests/gym/envs/test_embodied_env.py
index feebdeda..9539381e 100644
--- a/tests/gym/envs/test_embodied_env.py
+++ b/tests/gym/envs/test_embodied_env.py
@@ -20,6 +20,7 @@
import numpy as np
import gymnasium as gym
+from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.gym.envs import EmbodiedEnvCfg
from embodichain.lab.sim.objects import RigidObject, Robot
from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES
@@ -27,7 +28,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.data import get_data_path
-NUM_ENVS = 10
+NUM_ENVS = 2
urdf_path = get_data_path("UniversalRobots/UR5/UR5.urdf")
METADATA = {
@@ -119,13 +120,14 @@
class EmbodiedEnvTest:
"""Shared test logic for CPU and CUDA."""
- def setup_simulation(self, sim_device, enable_rt):
+ def setup_simulation(self, sim_device):
cfg: EmbodiedEnvCfg = config_to_cfg(
METADATA, manager_modules=DEFAULT_MANAGER_MODULES
)
cfg.num_envs = NUM_ENVS
cfg.sim_cfg = SimulationManagerCfg(
- headless=True, sim_device=sim_device, enable_rt=enable_rt
+ headless=True,
+ sim_device=sim_device,
)
self.env = gym.make(id=METADATA["id"], cfg=cfg)
@@ -159,22 +161,23 @@ def test_env_rollout(self):
def teardown_method(self):
"""Clean up resources after each test method."""
- self.env.close()
+ if hasattr(self, "env") and self.env is not None:
+ self.env.close()
+ import embodichain.lab.sim as om
+ om.SimulationManager.flush_cleanup_queue()
+ import gc
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
-class TestCPU(EmbodiedEnvTest):
- def setup_method(self):
- self.setup_simulation("cpu", enable_rt=False)
+ gc.collect()
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
-class TestCPURT(EmbodiedEnvTest):
+# @pytest.mark.skip(reason="Skipping tests temporarily")
+class TestCPU(EmbodiedEnvTest):
def setup_method(self):
- self.setup_simulation("cpu", enable_rt=True)
+ self.setup_simulation("cpu")
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
+# @pytest.mark.skip(reason="Skipping tests temporarily")
class TestCUDA(EmbodiedEnvTest):
def setup_method(self):
- self.setup_simulation("cuda", enable_rt=False)
+ self.setup_simulation("cuda")
diff --git a/tests/sim/atomic_actions/__init__.py b/tests/sim/atomic_actions/__init__.py
new file mode 100644
index 00000000..0671165d
--- /dev/null
+++ b/tests/sim/atomic_actions/__init__.py
@@ -0,0 +1,17 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Tests for atomic actions module."""
diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py
new file mode 100644
index 00000000..ba7324cc
--- /dev/null
+++ b/tests/sim/atomic_actions/test_actions.py
@@ -0,0 +1,304 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Tests for atomic action implementations (MoveAction, PickUpAction, PlaceAction)."""
+
+from __future__ import annotations
+
+import pytest
+import torch
+from unittest.mock import MagicMock, Mock
+
+from embodichain.lab.sim.atomic_actions.core import (
+ ActionCfg,
+ Affordance,
+ ObjectSemantics,
+)
+from embodichain.lab.sim.atomic_actions.actions import (
+ MoveAction,
+ MoveActionCfg,
+ PickUpAction,
+ PickUpActionCfg,
+ PlaceAction,
+ PlaceActionCfg,
+)
+
+# ---------------------------------------------------------------------------
+# Mock Helpers
+# ---------------------------------------------------------------------------
+
+NUM_ENVS = 2 # number of parallel environments used in tests
+ARM_DOF = 6 # typical arm joint count
+HAND_DOF = 2 # typical hand joint count
+TOTAL_DOF = ARM_DOF + HAND_DOF
+
+
+def _make_mock_robot(
+ num_envs: int = NUM_ENVS,
+ arm_dof: int = ARM_DOF,
+ hand_dof: int = HAND_DOF,
+) -> Mock:
+ """Create a mock Robot with arm and hand control parts."""
+ robot = Mock()
+ robot.device = torch.device("cpu")
+ robot.dof = arm_dof + hand_dof
+
+ def get_qpos(name=None):
+ if name == "arm":
+ return torch.zeros(num_envs, arm_dof)
+ if name == "hand":
+ return torch.zeros(num_envs, hand_dof)
+ # Full qpos
+ return torch.zeros(num_envs, arm_dof + hand_dof)
+
+ robot.get_qpos = get_qpos
+
+ def get_joint_ids(name=None):
+ if name == "arm":
+ return list(range(arm_dof))
+ if name == "hand":
+ return list(range(arm_dof, arm_dof + hand_dof))
+ return list(range(arm_dof + hand_dof))
+
+ robot.get_joint_ids = get_joint_ids
+
+ # compute_ik: return success and identity-like qpos
+ def compute_ik(pose=None, qpos_seed=None, name=None, joint_seed=None):
+ seed = joint_seed if joint_seed is not None else qpos_seed
+ if seed is None:
+ seed = torch.zeros(num_envs, arm_dof)
+ success = torch.ones(num_envs, dtype=torch.bool)
+ return success, seed.clone()
+
+ robot.compute_ik = compute_ik
+
+ # compute_fk: return identity-like poses
+ def compute_fk(qpos=None, name=None, to_matrix=True):
+ n = qpos.shape[0] if qpos is not None else num_envs
+ poses = torch.eye(4).unsqueeze(0).repeat(n, 1, 1)
+ return poses
+
+ robot.compute_fk = compute_fk
+
+ return robot
+
+
+def _make_mock_motion_generator(robot: Mock | None = None) -> Mock:
+ """Create a mock MotionGenerator."""
+ mg = Mock()
+ mg.robot = robot or _make_mock_robot()
+ mg.device = mg.robot.device
+ return mg
+
+
+# ---------------------------------------------------------------------------
+# MoveAction
+# ---------------------------------------------------------------------------
+
+
+class TestMoveActionHelpers:
+ """Tests for MoveAction helper methods that don't need simulation."""
+
+ def setup_method(self):
+ self.robot = _make_mock_robot()
+ self.mg = _make_mock_motion_generator(self.robot)
+ self.cfg = MoveActionCfg(sample_interval=50)
+ self.action = MoveAction(self.mg, cfg=self.cfg)
+
+ def test_init_sets_attributes(self):
+ assert self.action.n_envs == NUM_ENVS
+ assert self.action.dof == ARM_DOF
+ assert self.action.device == torch.device("cpu")
+
+ def test_resolve_pose_target_from_4x4(self):
+ target = torch.eye(4)
+ is_success, result = self.action._resolve_pose_target(
+ target, action_name="TestAction"
+ )
+ assert is_success is True
+ assert result.shape == (NUM_ENVS, 4, 4)
+ # Single pose should be repeated for all envs
+ for i in range(NUM_ENVS):
+ assert torch.equal(result[i], torch.eye(4))
+
+ def test_resolve_pose_target_from_batched(self):
+ target = torch.eye(4).unsqueeze(0).repeat(NUM_ENVS, 1, 1)
+ target[:, 2, 3] = 0.5 # offset z for each env
+ is_success, result = self.action._resolve_pose_target(
+ target, action_name="TestAction"
+ )
+ assert is_success is True
+ assert result.shape == (NUM_ENVS, 4, 4)
+ for i in range(NUM_ENVS):
+ assert result[i, 2, 3].item() == pytest.approx(0.5)
+
+ def test_resolve_start_qpos_defaults_to_current(self):
+ result = self.action._resolve_start_qpos(None)
+ assert result.shape == (NUM_ENVS, ARM_DOF)
+
+ def test_resolve_start_qpos_broadcasts_single(self):
+ single = torch.ones(ARM_DOF)
+ result = self.action._resolve_start_qpos(single)
+ assert result.shape == (NUM_ENVS, ARM_DOF)
+ for i in range(NUM_ENVS):
+ assert torch.equal(result[i], single)
+
+ def test_compute_three_phase_waypoints_sums_to_sample_interval(self):
+ hand_interp_steps = 5
+ first, second, third = self.action._compute_three_phase_waypoints(
+ hand_interp_steps,
+ first_phase_name="approach",
+ third_phase_name="lift",
+ )
+ assert first + second + third == self.cfg.sample_interval
+ assert first >= 2
+ assert third >= 2
+
+ def test_interpolate_hand_qpos_shape(self):
+ n_waypoints = 10
+ start = torch.zeros(HAND_DOF)
+ end = torch.ones(HAND_DOF)
+ result = self.action._interpolate_hand_qpos(start, end, n_waypoints)
+ assert result.shape == (n_waypoints, HAND_DOF)
+ # First and last should match endpoints
+ assert torch.allclose(result[0], start)
+ assert torch.allclose(result[-1], end)
+
+ def test_interpolate_hand_qpos_linear(self):
+ """Verify linear interpolation between two hand configs."""
+ n_waypoints = 3
+ start = torch.tensor([0.0, 0.0])
+ end = torch.tensor([1.0, 1.0])
+ result = self.action._interpolate_hand_qpos(start, end, n_waypoints)
+ expected_mid = torch.tensor([0.5, 0.5])
+ assert torch.allclose(result[1], expected_mid, atol=1e-6)
+
+
+# ---------------------------------------------------------------------------
+# PickUpAction
+# ---------------------------------------------------------------------------
+
+
+class TestPickUpActionInit:
+ """Tests for PickUpAction initialization and config validation."""
+
+ def setup_method(self):
+ self.robot = _make_mock_robot()
+ self.mg = _make_mock_motion_generator(self.robot)
+
+ def _make_cfg(self, **overrides):
+ defaults = dict(
+ hand_open_qpos=torch.tensor([0.0, 0.0]),
+ hand_close_qpos=torch.tensor([0.025, 0.025]),
+ control_part="arm",
+ hand_control_part="hand",
+ pre_grasp_distance=0.15,
+ lift_height=0.15,
+ approach_direction=torch.tensor([0.0, 0.0, -1.0]),
+ )
+ defaults.update(overrides)
+ return PickUpActionCfg(**defaults)
+
+ def test_init_sets_hand_joint_ids(self):
+ cfg = self._make_cfg()
+ action = PickUpAction(self.mg, cfg=cfg)
+ assert action.hand_joint_ids == list(range(ARM_DOF, ARM_DOF + HAND_DOF))
+ assert action.joint_ids == list(range(ARM_DOF)) + list(
+ range(ARM_DOF, ARM_DOF + HAND_DOF)
+ )
+ assert action.dof == TOTAL_DOF
+
+
+# ---------------------------------------------------------------------------
+# PlaceAction
+# ---------------------------------------------------------------------------
+
+
+class TestPlaceActionInit:
+ """Tests for PlaceAction initialization."""
+
+ def setup_method(self):
+ self.robot = _make_mock_robot()
+ self.mg = _make_mock_motion_generator(self.robot)
+
+ def _make_cfg(self, **overrides):
+ defaults = dict(
+ hand_open_qpos=torch.tensor([0.0, 0.0]),
+ hand_close_qpos=torch.tensor([0.025, 0.025]),
+ control_part="arm",
+ hand_control_part="hand",
+ lift_height=0.15,
+ )
+ defaults.update(overrides)
+ return PlaceActionCfg(**defaults)
+
+ def test_init_sets_hand_joint_ids(self):
+ cfg = self._make_cfg()
+ action = PlaceAction(self.mg, cfg=cfg)
+ assert action.hand_joint_ids == list(range(ARM_DOF, ARM_DOF + HAND_DOF))
+ assert action.dof == TOTAL_DOF
+
+
+# ---------------------------------------------------------------------------
+# AtomicAction._apply_offset
+# ---------------------------------------------------------------------------
+
+
+class TestAtomicActionApplyOffset:
+ """Tests for the shared _apply_offset method inherited from AtomicAction."""
+
+ def setup_method(self):
+ self.robot = _make_mock_robot()
+ self.mg = _make_mock_motion_generator(self.robot)
+ self.cfg = MoveActionCfg()
+ self.action = MoveAction(self.mg, cfg=self.cfg)
+
+ def test_apply_offset_batched(self):
+ # [N, 4, 4] poses, [N, 3] offsets
+ poses = torch.eye(4).unsqueeze(0).repeat(3, 1, 1)
+ offsets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+ result = self.action._apply_offset(poses, offsets)
+ assert result.shape == (3, 4, 4)
+ assert result[0, :3, 3].tolist() == pytest.approx([1.0, 0.0, 0.0])
+ assert result[1, :3, 3].tolist() == pytest.approx([0.0, 1.0, 0.0])
+ assert result[2, :3, 3].tolist() == pytest.approx([0.0, 0.0, 1.0])
+
+ def test_apply_offset_broadcasts_single_offset(self):
+ # [N, 4, 4] poses, [3] single offset broadcast to all
+ poses = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)
+ offset = torch.tensor([0.1, 0.2, 0.3])
+ result = self.action._apply_offset(poses, offset)
+ assert result.shape == (2, 4, 4)
+ for i in range(2):
+ assert result[i, :3, 3].tolist() == pytest.approx([0.1, 0.2, 0.3])
+
+ def test_apply_offset_preserves_rotation(self):
+ """Offset only affects translation; rotation part stays unchanged."""
+ poses = torch.eye(4).unsqueeze(0).repeat(1, 1, 1)
+ # Set a non-trivial rotation
+ poses[0, 0, 1] = -1.0
+ poses[0, 1, 0] = 1.0
+ offset = torch.tensor([1.0, 2.0, 3.0])
+ result = self.action._apply_offset(poses, offset)
+ # Rotation block should be unchanged
+ assert torch.equal(result[0, :3, :3], poses[0, :3, :3])
+
+
+if __name__ == "__main__":
+ # For visual debugging
+ test = TestMoveActionHelpers()
+ test.setup_method()
+ test.test_compute_three_phase_waypoints_sums_to_sample_interval()
diff --git a/tests/sim/atomic_actions/test_core.py b/tests/sim/atomic_actions/test_core.py
new file mode 100644
index 00000000..7cebaa7b
--- /dev/null
+++ b/tests/sim/atomic_actions/test_core.py
@@ -0,0 +1,171 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Tests for atomic action core module (Affordance, InteractionPoints, ObjectSemantics, ActionCfg)."""
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+from embodichain.lab.sim.atomic_actions.core import (
+ ActionCfg,
+ Affordance,
+ InteractionPoints,
+ ObjectSemantics,
+)
+
+# ---------------------------------------------------------------------------
+# Affordance
+# ---------------------------------------------------------------------------
+
+
+class TestAffordance:
+ """Tests for the Affordance base dataclass."""
+
+ def test_default_values(self):
+ aff = Affordance()
+ assert aff.object_label == ""
+ assert aff.geometry == {}
+ assert aff.custom_config == {}
+
+ def test_mesh_vertices_returns_tensor(self):
+ vertices = torch.randn(10, 3)
+ aff = Affordance(geometry={"mesh_vertices": vertices})
+ assert torch.equal(aff.mesh_vertices, vertices)
+
+ def test_mesh_vertices_returns_none_when_missing(self):
+ aff = Affordance()
+ assert aff.mesh_vertices is None
+
+ def test_mesh_vertices_raises_on_wrong_type(self):
+ aff = Affordance(geometry={"mesh_vertices": [1, 2, 3]})
+ with pytest.raises(TypeError, match="must be a torch.Tensor"):
+ _ = aff.mesh_vertices
+
+ def test_mesh_triangles_returns_tensor(self):
+ triangles = torch.randint(0, 10, (5, 3))
+ aff = Affordance(geometry={"mesh_triangles": triangles})
+ assert torch.equal(aff.mesh_triangles, triangles)
+
+ def test_mesh_triangles_returns_none_when_missing(self):
+ aff = Affordance()
+ assert aff.mesh_triangles is None
+
+ def test_mesh_triangles_raises_on_wrong_type(self):
+ aff = Affordance(geometry={"mesh_triangles": "bad"})
+ with pytest.raises(TypeError, match="must be a torch.Tensor"):
+ _ = aff.mesh_triangles
+
+ def test_custom_config_get_set(self):
+ aff = Affordance()
+ aff.set_custom_config("key_a", 42)
+ assert aff.get_custom_config("key_a") == 42
+ assert aff.get_custom_config("missing") is None
+ assert aff.get_custom_config("missing", "default") == "default"
+
+ def test_get_batch_size_returns_one(self):
+ # Base Affordance always returns 1
+ assert Affordance().get_batch_size() == 1
+
+
+# ---------------------------------------------------------------------------
+# InteractionPoints
+# ---------------------------------------------------------------------------
+
+
+class TestInteractionPoints:
+ """Tests for InteractionPoints affordance."""
+
+ def test_default_points_shape(self):
+ ip = InteractionPoints()
+ assert ip.points.shape == (1, 3)
+
+ def test_get_batch_size_matches_points(self):
+ points = torch.randn(5, 3)
+ ip = InteractionPoints(points=points)
+ assert ip.get_batch_size() == 5
+
+ def test_get_points_by_type_returns_matching_subset(self):
+ points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+ ip = InteractionPoints(points=points, point_types=["push", "poke", "push"])
+ result = ip.get_points_by_type("push")
+ assert result is not None
+ assert result.shape == (2, 3)
+ assert torch.equal(result[0], points[0])
+ assert torch.equal(result[1], points[2])
+
+ def test_get_points_by_type_returns_none_for_missing_type(self):
+ ip = InteractionPoints(points=torch.zeros(2, 3), point_types=["push", "push"])
+ assert ip.get_points_by_type("poke") is None
+
+ def test_get_approach_direction_from_normals(self):
+ normals = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+ ip = InteractionPoints(points=torch.zeros(2, 3), normals=normals)
+ # Approach is opposite of normal
+ assert torch.equal(ip.get_approach_direction(0), torch.tensor([0.0, 0.0, -1.0]))
+ assert torch.equal(ip.get_approach_direction(1), torch.tensor([-1.0, 0.0, 0.0]))
+
+ def test_get_approach_direction_default_without_normals(self):
+ ip = InteractionPoints(points=torch.zeros(1, 3))
+ direction = ip.get_approach_direction(0)
+ assert torch.equal(direction, torch.tensor([0.0, 0.0, 1.0]))
+
+
+# ---------------------------------------------------------------------------
+# ObjectSemantics
+# ---------------------------------------------------------------------------
+
+
+class TestObjectSemantics:
+ """Tests for ObjectSemantics dataclass."""
+
+ def test_post_init_binds_label_and_geometry(self):
+ geometry = {"bounding_box": [0.1, 0.2, 0.3]}
+ aff = Affordance()
+ sem = ObjectSemantics(
+ affordance=aff,
+ geometry=geometry,
+ label="mug",
+ )
+ assert sem.affordance.object_label == "mug"
+ assert sem.affordance.geometry is geometry
+
+ def test_default_optional_fields(self):
+ sem = ObjectSemantics(
+ affordance=Affordance(),
+ geometry={},
+ )
+ assert sem.label == "none"
+ assert sem.properties == {}
+ assert sem.entity is None
+
+
+# ---------------------------------------------------------------------------
+# ActionCfg
+# ---------------------------------------------------------------------------
+
+
+class TestActionCfg:
+ """Tests for ActionCfg defaults."""
+
+ def test_default_values(self):
+ cfg = ActionCfg()
+ assert cfg.name == "default"
+ assert cfg.control_part == "arm"
+ assert cfg.interpolation_type == "linear"
+ assert cfg.velocity_limit is None
+ assert cfg.acceleration_limit is None
diff --git a/tests/sim/atomic_actions/test_engine.py b/tests/sim/atomic_actions/test_engine.py
new file mode 100644
index 00000000..52dc034d
--- /dev/null
+++ b/tests/sim/atomic_actions/test_engine.py
@@ -0,0 +1,191 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""Tests for atomic action engine (registry, SemanticAnalyzer, AtomicActionEngine)."""
+
+from __future__ import annotations
+
+import pytest
+import torch
+from unittest.mock import MagicMock, Mock
+
+from embodichain.lab.sim.atomic_actions.core import (
+ ActionCfg,
+ Affordance,
+ ObjectSemantics,
+)
+from embodichain.lab.sim.atomic_actions.engine import (
+ AtomicActionEngine,
+ SemanticAnalyzer,
+ get_registered_actions,
+ register_action,
+ unregister_action,
+)
+
+# ---------------------------------------------------------------------------
+# Global Action Registry
+# ---------------------------------------------------------------------------
+
+
+class TestGlobalRegistry:
+ """Tests for register_action / unregister_action / get_registered_actions."""
+
+ def teardown_method(self):
+ # Clean up any test registrations
+ unregister_action("_test_dummy")
+
+ def test_register_and_retrieve(self):
+ mock_cls = Mock()
+ register_action("_test_dummy", mock_cls)
+ registry = get_registered_actions()
+ assert "_test_dummy" in registry
+ assert registry["_test_dummy"] is mock_cls
+
+ def test_unregister_removes_entry(self):
+ register_action("_test_dummy", Mock())
+ unregister_action("_test_dummy")
+ assert "_test_dummy" not in get_registered_actions()
+
+ def test_unregister_nonexistent_is_noop(self):
+ # Should not raise
+ unregister_action("_nonexistent_action")
+
+ def test_get_registered_actions_returns_copy(self):
+ """Mutating the returned dict should not affect the global registry."""
+ result = get_registered_actions()
+ result["_should_not_persist"] = Mock()
+ assert "_should_not_persist" not in get_registered_actions()
+
+
+# ---------------------------------------------------------------------------
+# SemanticAnalyzer
+# ---------------------------------------------------------------------------
+
+
+class TestSemanticAnalyzer:
+ """Tests for SemanticAnalyzer."""
+
+ def setup_method(self):
+ self.analyzer = SemanticAnalyzer()
+
+ def test_analyze_returns_object_semantics(self):
+ sem = self.analyzer.analyze("mug")
+ assert isinstance(sem, ObjectSemantics)
+ assert sem.label == "mug"
+ assert isinstance(sem.affordance, Affordance)
+
+ def test_analyze_caches_by_default(self):
+ sem1 = self.analyzer.analyze("bottle")
+ sem2 = self.analyzer.analyze("bottle")
+ assert sem1 is sem2
+
+ def test_analyze_bypasses_cache_with_geometry(self):
+ sem1 = self.analyzer.analyze("bottle")
+ sem2 = self.analyzer.analyze(
+ "bottle", geometry={"bounding_box": [0.2, 0.2, 0.2]}
+ )
+ assert sem1 is not sem2
+
+ def test_analyze_no_cache(self):
+ sem1 = self.analyzer.analyze("cup", use_cache=False)
+ sem2 = self.analyzer.analyze("cup", use_cache=False)
+ assert sem1 is not sem2
+
+ def test_clear_cache(self):
+ self.analyzer.analyze("can")
+ self.analyzer.clear_cache()
+ # After clearing, a new object should be created
+ sem1 = self.analyzer.analyze("can")
+ sem2 = self.analyzer.analyze("can")
+ assert sem1 is sem2 # re-cached after clear
+
+
+# ---------------------------------------------------------------------------
+# AtomicActionEngine._resolve_target
+# ---------------------------------------------------------------------------
+
+
+class TestResolveTarget:
+ """Tests for AtomicActionEngine._resolve_target with various input types."""
+
+ def setup_method(self):
+ self.robot = Mock()
+ self.robot.device = torch.device("cpu")
+ self.robot.dof = 6
+ self.robot.get_qpos.return_value = torch.zeros(1, 6)
+ self.robot.get_joint_ids.return_value = list(range(6))
+
+ self.mg = Mock()
+ self.mg.robot = self.robot
+ self.mg.device = torch.device("cpu")
+
+ self.engine = AtomicActionEngine(self.mg, actions_cfg_list=[])
+
+ def test_tensor_passthrough(self):
+ tensor = torch.eye(4)
+ result = self.engine._resolve_target(tensor)
+ assert result is tensor
+
+ def test_object_semantics_passthrough(self):
+ sem = ObjectSemantics(affordance=Affordance(), geometry={})
+ result = self.engine._resolve_target(sem)
+ assert result is sem
+
+ def test_string_resolved_via_semantic_analyzer(self):
+ result = self.engine._resolve_target("mug")
+ assert isinstance(result, ObjectSemantics)
+ assert result.label == "mug"
+
+ def test_dict_with_pose_key(self):
+ pose = torch.eye(4)
+ result = self.engine._resolve_target({"pose": pose})
+ assert result is pose
+
+ def test_dict_with_pose_raises_on_non_tensor(self):
+ with pytest.raises(TypeError, match="must be a torch.Tensor"):
+ self.engine._resolve_target({"pose": "not_a_tensor"})
+
+ def test_dict_with_semantics_key(self):
+ sem = ObjectSemantics(affordance=Affordance(), geometry={}, label="bottle")
+ result = self.engine._resolve_target({"semantics": sem})
+ assert result is sem
+
+ def test_dict_with_semantics_raises_on_wrong_type(self):
+ with pytest.raises(TypeError, match="must be an ObjectSemantics"):
+ self.engine._resolve_target({"semantics": "wrong"})
+
+ def test_dict_with_label_uses_analyzer(self):
+ result = self.engine._resolve_target({"label": "apple"})
+ assert isinstance(result, ObjectSemantics)
+ assert result.label == "apple"
+
+ def test_dict_without_label_raises(self):
+ with pytest.raises(ValueError, match="must provide 'label'"):
+ self.engine._resolve_target({"geometry": {}})
+
+ def test_dict_with_non_string_label_raises(self):
+ with pytest.raises(TypeError, match="must be a string"):
+ self.engine._resolve_target({"label": 123})
+
+ def test_unsupported_type_raises(self):
+ with pytest.raises(TypeError, match="target must be"):
+ self.engine._resolve_target(42)
+
+
+if __name__ == "__main__":
+ test = TestSemanticAnalyzer()
+ test.setup_method()
+ test.test_analyze_returns_object_semantics()
diff --git a/tests/sim/objects/test_articulation.py b/tests/sim/objects/test_articulation.py
index 8140b775..6f2dc692 100644
--- a/tests/sim/objects/test_articulation.py
+++ b/tests/sim/objects/test_articulation.py
@@ -248,6 +248,13 @@ def test_get_joint_drive_with_joint_ids(self):
def teardown_method(self):
"""Clean up resources after each test method."""
self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ self.__dict__.clear()
+ import gc
+
+ gc.collect()
class TestArticulationCPU(BaseArticulationTest):
@@ -255,7 +262,6 @@ def setup_method(self):
self.setup_simulation("cpu")
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
class TestArticulationCUDA(BaseArticulationTest):
def setup_method(self):
self.setup_simulation("cuda")
diff --git a/tests/sim/objects/test_cloth_object.py b/tests/sim/objects/test_cloth_object.py
index d7182b66..afa182e5 100644
--- a/tests/sim/objects/test_cloth_object.py
+++ b/tests/sim/objects/test_cloth_object.py
@@ -68,7 +68,6 @@ def setup_simulation(self):
headless=True,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device="cuda",
- enable_rt=False, # Enable ray tracing for better visuals
num_envs=4,
arena_space=3.0,
)
@@ -133,6 +132,13 @@ def test_get_current_vertex_positions(self):
def teardown_method(self):
"""Clean up resources after each test method."""
self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ self.__dict__.clear()
+ import gc
+
+ gc.collect()
class TestSoftObjectCUDA(BaseSoftObjectTest):
diff --git a/tests/sim/objects/test_light.py b/tests/sim/objects/test_light.py
index ac3b70cc..7e9d58c4 100644
--- a/tests/sim/objects/test_light.py
+++ b/tests/sim/objects/test_light.py
@@ -152,3 +152,10 @@ def test_set_and_get_local_pose_matrix_and_vector(self):
def teardown_method(self):
"""Clean up resources after each test method."""
self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ self.__dict__.clear()
+ import gc
+
+ gc.collect()
diff --git a/tests/sim/objects/test_rigid_object.py b/tests/sim/objects/test_rigid_object.py
index 55bc73a9..5beebe26 100644
--- a/tests/sim/objects/test_rigid_object.py
+++ b/tests/sim/objects/test_rigid_object.py
@@ -29,6 +29,8 @@
from embodichain.data import get_data_path
from dexsim.types import ActorType
+from embodichain.lab.sim.cfg import RenderCfg, RigidObjectCfg
+
DUCK_PATH = "ToyDuck/toy_duck.glb"
TABLE_PATH = "ShopTableSimple/shop_table_simple.ply"
CHAIR_PATH = "Chair/chair.glb"
@@ -44,7 +46,7 @@ def setup_simulation(self, sim_device):
headless=True, sim_device=sim_device, num_envs=NUM_ARENAS
)
self.sim = SimulationManager(config)
-
+ self.sim.enable_physics(False)
duck_path = get_data_path(DUCK_PATH)
assert os.path.isfile(duck_path)
table_path = get_data_path(TABLE_PATH)
@@ -235,6 +237,44 @@ def test_set_velocity(self):
duck_ang_vel, ang_vel
), f"Angular velocity not set correctly: expected {ang_vel}, got {duck_ang_vel}"
+ def test_get_acceleration(self):
+ """Test that lin_acc, ang_acc, and acc return correct shapes and values."""
+
+ # Apply a force to generate non-zero acceleration
+ force = (
+ torch.tensor([10.0, 0.0, 0.0], device=self.sim.device)
+ .unsqueeze(0)
+ .repeat(NUM_ARENAS, 1)
+ )
+ self.duck.add_force_torque(force=force)
+ self.sim.update(0.01)
+
+ # Read back accelerations
+ duck_lin_acc = self.duck.body_data.lin_acc
+ duck_ang_acc = self.duck.body_data.ang_acc
+ duck_acc = self.duck.body_data.acc
+
+ assert duck_lin_acc.shape == (
+ NUM_ARENAS,
+ 3,
+ ), f"Linear acceleration shape mismatch: expected ({NUM_ARENAS}, 3), got {duck_lin_acc.shape}"
+ assert duck_ang_acc.shape == (
+ NUM_ARENAS,
+ 3,
+ ), f"Angular acceleration shape mismatch: expected ({NUM_ARENAS}, 3), got {duck_ang_acc.shape}"
+ assert duck_acc.shape == (
+ NUM_ARENAS,
+ 6,
+ ), f"Concatenated acceleration shape mismatch: expected ({NUM_ARENAS}, 6), got {duck_acc.shape}"
+
+ # Verify concatenated acceleration matches individual components
+ assert torch.allclose(
+ duck_acc[:, :3], duck_lin_acc
+ ), "First 3 columns of acc should match lin_acc"
+ assert torch.allclose(
+ duck_acc[:, 3:], duck_ang_acc
+ ), "Last 3 columns of acc should match ang_acc"
+
def test_set_visual_material(self):
"""Test that set_material correctly assigns the material to the duck."""
@@ -541,6 +581,13 @@ def test_misc_properties(self):
def teardown_method(self):
"""Clean up resources after each test method."""
self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ self.__dict__.clear()
+ import gc
+
+ gc.collect()
class TestRigidObjectCPU(BaseRigidObjectTest):
@@ -548,7 +595,6 @@ def setup_method(self):
self.setup_simulation("cpu")
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
class TestRigidObjectCUDA(BaseRigidObjectTest):
def setup_method(self):
self.setup_simulation("cuda")
diff --git a/tests/sim/objects/test_rigid_object_group.py b/tests/sim/objects/test_rigid_object_group.py
index b6802743..896f5ad3 100644
--- a/tests/sim/objects/test_rigid_object_group.py
+++ b/tests/sim/objects/test_rigid_object_group.py
@@ -119,6 +119,13 @@ def test_set_visible(self):
def teardown_method(self):
"""Clean up resources after each test method."""
self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ self.__dict__.clear()
+ import gc
+
+ gc.collect()
class TestRigidObjectGroupCPU(BaseRigidObjectGroupTest):
@@ -126,7 +133,6 @@ def setup_method(self):
self.setup_simulation("cpu")
-# TODO: Fix CUDA tests issue.
@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
class TestRigidObjectGroupCUDA(BaseRigidObjectGroupTest):
def setup_method(self):
diff --git a/tests/sim/objects/test_robot.py b/tests/sim/objects/test_robot.py
index 784aeaee..83b1414d 100644
--- a/tests/sim/objects/test_robot.py
+++ b/tests/sim/objects/test_robot.py
@@ -24,7 +24,6 @@
from embodichain.lab.sim.robots.dexforce_w1 import DexforceW1Cfg
from embodichain.data import get_data_path
-
# Define control parts
CONTROL_PARTS = {
"left_arm": [
@@ -50,10 +49,13 @@
# Base test class for CPU and CUDA
class BaseRobotTest:
- def setup_simulation(self, sim_device):
+ @classmethod
+ def setup_simulation(cls, sim_device):
+ if hasattr(cls, "sim"):
+ return
# Set up simulation with specified device (CPU or CUDA)
config = SimulationManagerCfg(headless=True, sim_device=sim_device, num_envs=10)
- self.sim = SimulationManager(config)
+ cls.sim = SimulationManager(config)
cfg = DexforceW1Cfg.from_dict(
{
@@ -63,11 +65,11 @@ def setup_simulation(self, sim_device):
}
)
- self.robot: Robot = self.sim.add_robot(cfg=cfg)
+ cls.robot: Robot = cls.sim.add_robot(cfg=cfg)
# Initialize GPU physics if needed
- if sim_device == "cuda" and getattr(self.sim, "is_use_gpu_physics", False):
- self.sim.init_gpu_physics()
+ if sim_device == "cuda" and getattr(cls.sim, "is_use_gpu_physics", False):
+ cls.sim.init_gpu_physics()
def test_get_joint_ids(self):
left_joint_ids = self.robot.get_joint_ids("left_arm")
@@ -139,6 +141,7 @@ def test_compute_fk(self):
],
],
dtype=torch.float32,
+ device=self.sim.device,
).unsqueeze_(0)
assert torch.allclose(
@@ -287,8 +290,20 @@ def test_robot_cfg_merge(self):
), "Solver config merge failed."
def teardown_method(self):
- """Clean up resources after each test method."""
- self.sim.destroy()
+ pass
+
+ @classmethod
+ def teardown_class(cls):
+ """Clean up resources after each test class."""
+ if hasattr(cls, "sim"):
+ cls.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ del cls.sim
+ import gc
+
+ gc.collect()
def test_set_physical_visible(self):
self.robot.set_physical_visible(
@@ -311,7 +326,6 @@ def setup_method(self):
self.setup_simulation("cpu")
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
class TestRobotCUDA(BaseRobotTest):
def setup_method(self):
self.setup_simulation("cuda")
@@ -319,6 +333,6 @@ def setup_method(self):
if __name__ == "__main__":
# Run tests directly
- test_cpu = TestRobotCPU()
+ test_cpu = TestRobotCUDA()
test_cpu.setup_method()
- test_cpu.test_fk("left_arm")
+ test_cpu.test_compute_jacobian()
diff --git a/tests/sim/objects/test_soft_object.py b/tests/sim/objects/test_soft_object.py
index b3955d88..06b3c1dc 100644
--- a/tests/sim/objects/test_soft_object.py
+++ b/tests/sim/objects/test_soft_object.py
@@ -18,6 +18,7 @@
from dexsim.utility.path import get_resources_data_path
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.cfg import (
+ RenderCfg,
SoftbodyVoxelAttributesCfg,
SoftbodyPhysicalAttributesCfg,
)
@@ -39,7 +40,6 @@ def setup_simulation(self):
headless=True,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device="cuda",
- enable_rt=False, # Enable ray tracing for better visuals
num_envs=4,
arena_space=3.0,
)
@@ -91,6 +91,13 @@ def test_remove(self):
def teardown_method(self):
"""Clean up resources after each test method."""
self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ self.__dict__.clear()
+ import gc
+
+ gc.collect()
class TestSoftObjectCUDA(BaseSoftObjectTest):
diff --git a/tests/sim/objects/test_usd.py b/tests/sim/objects/test_usd.py
index 350c9daf..a5558a39 100644
--- a/tests/sim/objects/test_usd.py
+++ b/tests/sim/objects/test_usd.py
@@ -23,6 +23,7 @@
)
from embodichain.lab.sim.objects import Articulation, RigidObject
from embodichain.lab.sim.cfg import (
+ RenderCfg,
ArticulationCfg,
RigidObjectCfg,
JointDrivePropertiesCfg,
@@ -39,7 +40,9 @@ class BaseUsdTest:
def setup_simulation(self, sim_device):
config = SimulationManagerCfg(
- headless=True, sim_device=sim_device, num_envs=NUM_ARENAS, enable_rt=False
+ headless=True,
+ sim_device=sim_device,
+ num_envs=NUM_ARENAS,
)
self.sim = SimulationManager(config)
@@ -166,8 +169,16 @@ def export_usd(self):
def teardown_method(self):
"""Clean up resources after each test method."""
self.sim.destroy()
+ import embodichain.lab.sim as om
+ om.SimulationManager.flush_cleanup_queue()
+ self.__dict__.clear()
+ import gc
+ gc.collect()
+
+
+@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
class TestUsdCPU(BaseUsdTest):
def setup_method(self):
self.setup_simulation("cpu")
diff --git a/tests/sim/planners/test_motion_generator.py b/tests/sim/planners/test_motion_generator.py
index 511189d6..300d191b 100644
--- a/tests/sim/planners/test_motion_generator.py
+++ b/tests/sim/planners/test_motion_generator.py
@@ -33,6 +33,7 @@
MoveType,
MovePart,
)
+from embodichain.lab.sim.cfg import RenderCfg
def to_numpy(tensor):
@@ -45,8 +46,10 @@ def to_numpy(tensor):
class BaseTestMotionGenerator(object):
- @classmethod
- def setup_class(cls):
+ def setup_simulation(self):
+ cls = type(self)
+ if hasattr(cls, "robot_sim"):
+ return
cls.config = SimulationManagerCfg(headless=True, sim_device="cpu")
cls.robot_sim = SimulationManager(cls.config)
cls.robot_sim.set_manual_update(False)
@@ -157,11 +160,15 @@ def _execute_trajectory(self, qpos_list, forward=True, delay=0.01):
@classmethod
def teardown_class(cls):
- try:
+ if hasattr(cls, "robot_sim"):
cls.robot_sim.destroy()
- print("robot_sim destroyed successfully")
- except Exception as e:
- print(f"Error during robot_sim.destroy(): {e}")
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ del cls.robot_sim
+ import gc
+
+ gc.collect()
def _execute_forward_trajectory(self, robot, qpos_list, delay=0.1):
"""Helper method to execute trajectory"""
@@ -183,6 +190,12 @@ def _execute_backward_trajectory(self, robot, qpos_list, delay=0.1):
class TestMotionGenerator(BaseTestMotionGenerator):
"""Test suite for MotionGenerator trajectory generation"""
+ def setup_method(self):
+ self.setup_simulation()
+
+ def teardown_method(self):
+ pass
+
@pytest.mark.parametrize("is_linear", [True, False])
def test_create_trajectory_with_xpos(self, is_linear):
"""Test trajectory generation with cartesian positions"""
diff --git a/tests/sim/planners/test_toppra_planner.py b/tests/sim/planners/test_toppra_planner.py
index d46f7e12..604581df 100644
--- a/tests/sim/planners/test_toppra_planner.py
+++ b/tests/sim/planners/test_toppra_planner.py
@@ -17,11 +17,14 @@
from embodichain.lab.sim.planners.utils import PlanState, TrajectorySampleMethod
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.robots import CobotMagicCfg
+from embodichain.lab.sim.cfg import RenderCfg
class TestToppraPlanner:
- @classmethod
- def setup_class(cls):
+ def setup_simulation(self):
+ cls = type(self)
+ if hasattr(cls, "sim"):
+ return
cls.sim_config = SimulationManagerCfg(headless=True, sim_device="cpu")
cls.sim = SimulationManager(cls.sim_config)
@@ -32,16 +35,28 @@ def setup_class(cls):
}
cls.robot = cls.sim.add_robot(cfg=CobotMagicCfg.from_dict(cfg_dict))
- @classmethod
- def teardown_class(cls):
- cls.sim.destroy()
-
def setup_method(self):
+ self.setup_simulation()
cfg = ToppraPlannerCfg(
robot_uid="CobotMagic_toppra",
)
self.planner = ToppraPlanner(cfg=cfg)
+ def teardown_method(self):
+ pass
+
+ @classmethod
+ def teardown_class(cls):
+ if hasattr(cls, "sim"):
+ cls.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ del cls.sim
+ import gc
+
+ gc.collect()
+
def test_initialization(self):
assert self.planner.device == torch.device("cpu")
diff --git a/tests/sim/sensors/test_camera.py b/tests/sim/sensors/test_camera.py
index 0a70d35a..d95f0c4f 100644
--- a/tests/sim/sensors/test_camera.py
+++ b/tests/sim/sensors/test_camera.py
@@ -23,19 +23,21 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.sensors import Camera, SensorCfg, CameraCfg
from embodichain.lab.sim.objects import Articulation
-from embodichain.lab.sim.cfg import ArticulationCfg
+from embodichain.lab.sim.cfg import ArticulationCfg, RenderCfg
from embodichain.data import get_data_path
-
NUM_ENVS = 4
ART_PATH = "SlidingBoxDrawer/SlidingBoxDrawer.urdf"
class CameraTest:
- def setup_simulation(self, sim_device, enable_rt):
+ def setup_simulation(self, sim_device, renderer="hybrid"):
# Setup SimulationManager
config = SimulationManagerCfg(
- headless=True, sim_device=sim_device, enable_rt=enable_rt, num_envs=NUM_ENVS
+ headless=True,
+ sim_device=sim_device,
+ render_cfg=RenderCfg(renderer=renderer),
+ num_envs=NUM_ENVS,
)
self.sim = SimulationManager(config)
# Create batch of cameras
@@ -137,30 +139,46 @@ def test_set_intrinsics(self):
def teardown_method(self):
"""Clean up resources after each test method."""
- self.sim.destroy()
+ if (
+ hasattr(self, "camera")
+ and getattr(self.camera, "uid", None) is not None
+ and hasattr(self, "sim")
+ ):
+ self.sim.remove_asset(self.camera.uid)
+ if hasattr(self, "sim"):
+ self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ import gc
+ gc.collect()
-class TestCameraRaster(CameraTest):
+
+class TestCameraHybrid(CameraTest):
def setup_method(self):
- self.setup_simulation("cpu", enable_rt=False)
+
+ self.setup_simulation("cpu", renderer="hybrid")
-class TestCameraRaster(CameraTest):
+class TestCameraHybridCUDA(CameraTest):
def setup_method(self):
- self.setup_simulation("cuda", enable_rt=False)
+
+ self.setup_simulation("cuda", renderer="hybrid")
class TestCameraFastRT(CameraTest):
def setup_method(self):
- self.setup_simulation("cpu", enable_rt=True)
+ self.setup_simulation("cpu", renderer="fast-rt")
-class TestCameraFastRT(CameraTest):
+class TestCameraFastRTCUDA(CameraTest):
def setup_method(self):
- self.setup_simulation("cuda", enable_rt=True)
+
+ self.setup_simulation("cuda", renderer="fast-rt")
if __name__ == "__main__":
- test = CameraTest()
- test.setup_simulation("cpu", enable_rt=False)
+ test = TestCameraFastRT()
+ test.setup_method()
test.test_attach_to_parent()
diff --git a/tests/sim/sensors/test_contact.py b/tests/sim/sensors/test_contact.py
index 07ad6c9a..aa38fc22 100644
--- a/tests/sim/sensors/test_contact.py
+++ b/tests/sim/sensors/test_contact.py
@@ -23,6 +23,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.cfg import (
+ RenderCfg,
RigidBodyAttributesCfg,
)
from embodichain.lab.sim.sensors import (
@@ -38,7 +39,7 @@
class ContactTest:
- def setup_simulation(self, sim_device, enable_rt):
+ def setup_simulation(self, sim_device, renderer="hybrid"):
sim_cfg = SimulationManagerCfg(
width=1920,
height=1080,
@@ -46,7 +47,7 @@ def setup_simulation(self, sim_device, enable_rt):
headless=True,
physics_dt=1.0 / 100.0, # Physics timestep (100 Hz)
sim_device=sim_device,
- enable_rt=enable_rt, # Enable ray tracing for better visuals
+ render_cfg=RenderCfg(renderer=renderer),
)
# Create the simulation instance
@@ -63,9 +64,9 @@ def setup_simulation(self, sim_device, enable_rt):
contact_filter_art_cfg.link_name_list = ["finger1_link", "finger2_link"]
contact_filter_cfg.articulation_cfg_list = [contact_filter_art_cfg]
contact_filter_cfg.filter_need_both_actor = True
- self.contact_sensor = self.sim.add_sensor(sensor_cfg=contact_filter_cfg)
self.to_grasp_pose(cube2)
+ self.contact_sensor = self.sim.add_sensor(sensor_cfg=contact_filter_cfg)
def create_cube(self, uid: str, position: list = (0.0, 0.0, 0)) -> RigidObject:
"""create cube
@@ -78,7 +79,7 @@ def create_cube(self, uid: str, position: list = (0.0, 0.0, 0)) -> RigidObject:
Returns:
RigidObject: rigid object
"""
- cube_size = (0.025, 0.025, 0.025)
+ cube_size = (0.05, 0.05, 0.05)
cube: RigidObject = self.sim.add_rigid_object(
cfg=RigidObjectCfg(
uid=uid,
@@ -175,12 +176,14 @@ def to_grasp_pose(self, cube: RigidObject):
approach_xpos = target_xpos.clone()
approach_xpos[:, 2, 3] += 0.1
- is_success, approach_qpos = self.robot.compute_ik(
+ is_success_approach, approach_qpos = self.robot.compute_ik(
pose=approach_xpos, joint_seed=rest_arm_qpos, name="arm"
)
- is_success, target_qpos = self.robot.compute_ik(
+ print(f"Approach IK success: {is_success_approach}")
+ is_success_target, target_qpos = self.robot.compute_ik(
pose=target_xpos, joint_seed=approach_qpos, name="arm"
)
+ print(f"Target IK success: {is_success_target}")
self.robot.set_qpos(approach_qpos, joint_ids=arm_ids)
self.sim.update(step=40)
@@ -192,11 +195,22 @@ def to_grasp_pose(self, cube: RigidObject):
.repeat(self.sim.num_envs, 1)
)
self.robot.set_qpos(hand_close_qpos, joint_ids=gripper_ids)
- self.sim.update(step=20)
+ self.sim.update(step=200)
+
+ finger1_pose = self.robot.get_link_pose("finger1_link")
+ finger2_pose = self.robot.get_link_pose("finger2_link")
+ cube_pose = cube.get_local_pose()
+ print(f"Finger 1 pose: {finger1_pose[0][:3]}")
+ print(f"Finger 2 pose: {finger2_pose[0][:3]}")
+ print(f"Cube pose at end of grasp: {cube_pose[0][:3]}")
def test_fetch_contact(self):
- self.sim.update(step=1)
- self.contact_sensor.update()
+ # In a test suite, run multiple steps until contact is actually detected
+ for i in range(50):
+ self.sim.update(step=20)
+ self.contact_sensor.update()
+ if getattr(self.contact_sensor, "total_current_contacts", 0) > 0:
+ break
contact_report = self.contact_sensor.get_data()
# Check that contact data has correct shape (num_envs, max_contacts_per_env, ...)
@@ -230,7 +244,13 @@ def test_fetch_contact(self):
finger1_user_ids = (
self.sim.get_robot("UR10_PGI").get_user_ids("finger1_link").reshape(-1)
)
- filter_user_ids = torch.cat([cube2_user_ids, finger1_user_ids])
+ filter_user_ids = torch.cat(
+ [
+ cube2_user_ids,
+ self.sim.get_robot("UR10_PGI").get_user_ids("finger1_link").reshape(-1),
+ self.sim.get_robot("UR10_PGI").get_user_ids("finger2_link").reshape(-1),
+ ]
+ )
filter_contact_report = self.contact_sensor.filter_by_user_ids(filter_user_ids)
n_filtered_contact = filter_contact_report["position"].shape[0]
assert n_filtered_contact > 0, "No contact detected between gripper and cube."
@@ -241,27 +261,46 @@ def test_fetch_contact(self):
def teardown_method(self):
"""Clean up resources after each test method."""
- self.sim.destroy()
+ if (
+ hasattr(self, "contact_sensor")
+ and getattr(self.contact_sensor, "uid", None) is not None
+ and hasattr(self, "sim")
+ ):
+ self.sim.remove_asset(self.contact_sensor.uid)
+ if hasattr(self, "sim"):
+ self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ import gc
+ gc.collect()
-class TestContactRaster(ContactTest):
+
+class TestContactHybrid(ContactTest):
def setup_method(self):
- self.setup_simulation("cpu", enable_rt=False)
+
+ self.setup_simulation("cpu", renderer="hybrid")
-class TestContactRasterCuda(ContactTest):
+@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
+class TestContactHybridCuda(ContactTest):
def setup_method(self):
- self.setup_simulation("cuda", enable_rt=False)
+
+ self.setup_simulation("cuda", renderer="hybrid")
class TestContactFastRT(ContactTest):
def setup_method(self):
- self.setup_simulation("cpu", enable_rt=True)
+ self.setup_simulation("cpu", renderer="fast-rt")
-class TestContactFastRTCuda(ContactTest):
+
+@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
+class TestContactFastRTCUDA(ContactTest):
def setup_method(self):
- self.setup_simulation("cuda", enable_rt=True)
+
+ self.setup_simulation("cuda", renderer="fast-rt")
def test_contact_sensor_from_dict():
@@ -295,6 +334,6 @@ def test_contact_sensor_from_dict():
if __name__ == "__main__":
- test = ContactTest()
- test.setup_simulation("cuda", enable_rt=True)
+ test = TestContactHybridCuda()
+ test.setup_simulation("cuda", renderer="hybrid")
test.test_fetch_contact()
diff --git a/tests/sim/sensors/test_stereo.py b/tests/sim/sensors/test_stereo.py
index d74b9f77..58c5caed 100644
--- a/tests/sim/sensors/test_stereo.py
+++ b/tests/sim/sensors/test_stereo.py
@@ -16,18 +16,22 @@
import pytest
import torch
+
+from embodichain.lab.sim.cfg import RenderCfg
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.sensors import StereoCamera, SensorCfg
-
NUM_ENVS = 4
class StereoCameraTest:
- def setup_simulation(self, sim_device, enable_rt):
+ def setup_simulation(self, sim_device, renderer="hybrid"):
# Setup SimulationManager
config = SimulationManagerCfg(
- headless=True, sim_device=sim_device, enable_rt=enable_rt, num_envs=NUM_ENVS
+ headless=True,
+ sim_device=sim_device,
+ num_envs=NUM_ENVS,
+ render_cfg=RenderCfg(renderer=renderer),
)
self.sim = SimulationManager(config)
# Create batch of cameras
@@ -138,24 +142,41 @@ def test_set_intrinsics(self):
def teardown_method(self):
"""Clean up resources after each test method."""
- self.sim.destroy()
+ if (
+ hasattr(self, "camera")
+ and getattr(self.camera, "uid", None) is not None
+ and hasattr(self, "sim")
+ ):
+ self.sim.remove_asset(self.camera.uid)
+ if hasattr(self, "sim"):
+ self.sim.destroy()
+ import embodichain.lab.sim as om
+
+ om.SimulationManager.flush_cleanup_queue()
+ import gc
+ gc.collect()
-class TestStereoCameraRaster(StereoCameraTest):
+
+class TestStereoCameraHybrid(StereoCameraTest):
def setup_method(self):
- self.setup_simulation("cpu", enable_rt=False)
+
+ self.setup_simulation("cpu", renderer="hybrid")
-class TestStereoCameraRaster(StereoCameraTest):
+class TestStereoCameraHybridCUDA(StereoCameraTest):
def setup_method(self):
- self.setup_simulation("cuda", enable_rt=False)
+
+ self.setup_simulation("cuda", renderer="hybrid")
class TestStereoCameraFastRT(StereoCameraTest):
def setup_method(self):
- self.setup_simulation("cpu", enable_rt=True)
+ self.setup_simulation("cpu", renderer="fast-rt")
-class TestStereoCameraFastRT(StereoCameraTest):
+
+class TestStereoCameraFastRTCUDA(StereoCameraTest):
def setup_method(self):
- self.setup_simulation("cuda", enable_rt=True)
+
+ self.setup_simulation("cuda", renderer="fast-rt")
diff --git a/tests/sim/solvers/test_differential_solver.py b/tests/sim/solvers/test_differential_solver.py
index ace1c5d1..0e22a567 100644
--- a/tests/sim/solvers/test_differential_solver.py
+++ b/tests/sim/solvers/test_differential_solver.py
@@ -21,7 +21,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot
-from embodichain.lab.sim.cfg import RobotCfg
+from embodichain.lab.sim.cfg import RobotCfg, RenderCfg
from embodichain.data import get_data_path
diff --git a/tests/sim/solvers/test_opw_solver.py b/tests/sim/solvers/test_opw_solver.py
index fe04f4b4..8153489d 100644
--- a/tests/sim/solvers/test_opw_solver.py
+++ b/tests/sim/solvers/test_opw_solver.py
@@ -21,6 +21,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot
from embodichain.lab.sim.robots import CobotMagicCfg
+from embodichain.lab.sim.cfg import RenderCfg
def grid_sample_qpos_from_limits(
@@ -28,6 +29,7 @@ def grid_sample_qpos_from_limits(
steps_per_joint: int = 4,
device=None,
max_samples: int = 4096,
+ safe_margin: float = 5 / 180 * np.pi, # 5 degrees in radians
) -> torch.Tensor:
"""Generate grid samples for qpos from qpos_limits.
@@ -44,8 +46,8 @@ def grid_sample_qpos_from_limits(
device = qpos_limits.device
limits = qpos_limits.squeeze(0) if qpos_limits.dim() == 3 else qpos_limits
- lows = limits[:, 0].to(device)
- highs = limits[:, 1].to(device)
+ lows = limits[:, 0].to(device) + safe_margin * 1.01
+ highs = limits[:, 1].to(device) - safe_margin * 1.01
# create per-joint linspaces
grids = [
@@ -98,12 +100,20 @@ def setup_simulation(self, sim_device):
"end_link_name": "left_link6",
"root_link_name": "left_arm_base",
"tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]],
+ "qpos_limits": [
+ [-2.618, 0.0, -2.967, -1.745, -1.22, -2.0944],
+ [2.618, 3.14159, 0.0, 1.745, 1.22, 2.0944],
+ ],
},
"right_arm": {
"class_type": "OPWSolver",
"end_link_name": "right_link6",
"root_link_name": "right_arm_base",
"tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]],
+ "qpos_limits": [
+ [-2.618, 0.0, -2.967, -1.745, -1.22, -2.0944],
+ [2.618, 3.14159, 0.0, 1.745, 1.22, 2.0944],
+ ],
},
},
}
@@ -165,7 +175,7 @@ def test_ik(self, arm_name: str):
device=self.robot.device,
)
res, ik_qpos = self.robot.compute_ik(
- pose=invalid_pose, joint_seed=ik_qpos, name=arm_name
+ pose=invalid_pose, joint_seed=ik_qpos[:, 0, :], name=arm_name
)
dof = ik_qpos.shape[-1]
assert res[0] == False
@@ -181,7 +191,6 @@ def setup_method(self):
self.setup_simulation("cpu")
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
class TestOPWSolverCUDA(BaseSolverTest):
def setup_method(self):
self.setup_simulation("cuda")
diff --git a/tests/sim/solvers/test_pink_solver.py b/tests/sim/solvers/test_pink_solver.py
index a8fda5fd..d5589fde 100644
--- a/tests/sim/solvers/test_pink_solver.py
+++ b/tests/sim/solvers/test_pink_solver.py
@@ -21,7 +21,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot
-from embodichain.lab.sim.cfg import RobotCfg
+from embodichain.lab.sim.cfg import RobotCfg, RenderCfg
from embodichain.data import get_data_path
diff --git a/tests/sim/solvers/test_pinocchio_solver.py b/tests/sim/solvers/test_pinocchio_solver.py
index 34c91c47..698cb1f9 100644
--- a/tests/sim/solvers/test_pinocchio_solver.py
+++ b/tests/sim/solvers/test_pinocchio_solver.py
@@ -21,7 +21,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot
-from embodichain.lab.sim.cfg import RobotCfg
+from embodichain.lab.sim.cfg import RobotCfg, RenderCfg
from embodichain.data import get_data_path
diff --git a/tests/sim/solvers/test_pytorch_solver.py b/tests/sim/solvers/test_pytorch_solver.py
index 5339c130..64bafee8 100644
--- a/tests/sim/solvers/test_pytorch_solver.py
+++ b/tests/sim/solvers/test_pytorch_solver.py
@@ -21,8 +21,48 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot
-from embodichain.lab.sim.cfg import RobotCfg
+from embodichain.lab.sim.cfg import RobotCfg, RenderCfg
from embodichain.data import get_data_path
+from embodichain.utils.utility import reset_all_seeds
+
+
+def grid_sample_qpos_from_limits(
+ qpos_limits: torch.Tensor,
+ steps_per_joint: int = 4,
+ device=None,
+ max_samples: int = 4096,
+) -> torch.Tensor:
+ """Generate grid samples for qpos from qpos_limits.
+
+ Args:
+ qpos_limits: tensor of shape (1, n, 2) or (n, 2) where each row is [low, high].
+ steps_per_joint: number of values per joint (defaults to 2: low and high).
+ device: torch device to place the samples on.
+ max_samples: cap the number of returned samples (take first N if grid is larger).
+
+ Returns:
+ Tensor of shape (N, n) where N <= max_samples.
+ """
+ if device is None:
+ device = qpos_limits.device
+
+ limits = qpos_limits.squeeze(0) if qpos_limits.dim() == 3 else qpos_limits
+ lows = limits[:, 0].to(device) + 1e-2
+ highs = limits[:, 1].to(device) - 1e-2
+
+ # create per-joint linspaces
+ grids = [
+ torch.linspace(l.item(), h.item(), steps_per_joint, device=device)
+ for l, h in zip(lows, highs)
+ ]
+
+ # meshgrid and stack
+ mesh = torch.meshgrid(*grids, indexing="ij")
+ stacked = torch.stack([m.reshape(-1) for m in mesh], dim=1)
+
+ if stacked.shape[0] > max_samples:
+ return stacked[:max_samples]
+ return stacked
# Base test class for CPU and CUDA
@@ -50,11 +90,13 @@ def setup_simulation(self, solver_type: str):
"end_link_name": "left_ee",
"root_link_name": "left_arm_base",
"ik_nearest_weight": [1.0, 1.0, 1.0, 0.9, 0.9, 0.1, 0.1],
+ "num_samples": 30,
},
"right_arm": {
"class_type": solver_type,
"end_link_name": "right_ee",
"root_link_name": "right_arm_base",
+ "num_samples": 30,
},
},
}
@@ -66,27 +108,46 @@ def setup_simulation(self, solver_type: str):
@pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"])
def test_ik(self, arm_name: str):
- # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed
+ reset_all_seeds(0)
+ qpos_limit = torch.tensor(
+ [
+ [0.2, 0.8],
+ [0.2, 0.8],
+ [0.2, 0.8],
+ [0.2, 0.8],
+ [0.2, 0.8],
+ [0.2, 0.8],
+ [0.2, 0.8],
+ ]
+ )
+ # generate a small grid of qpos samples from the joint limits (low/high)
+ sample_qpos = grid_sample_qpos_from_limits(
+ qpos_limit, steps_per_joint=3, device=self.robot.device, max_samples=200
+ )
+ sample_qpos = sample_qpos[None, :, :]
- qpos_fk = torch.tensor(
- [[0.0, 0.0, 0.0, -np.pi / 4, 0.0, 0.0, 0.0]], dtype=torch.float32
+ fk_xpos = self.robot.compute_batch_fk(
+ qpos=sample_qpos, name=arm_name, to_matrix=True
+ )
+ fk_xpos_xyzquat = self.robot.compute_batch_fk(
+ qpos=sample_qpos, name=arm_name, to_matrix=False
)
- fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True)
+ res, ik_qpos = self.robot.compute_batch_ik(
+ pose=fk_xpos, joint_seed=sample_qpos, name=arm_name
+ )
- res, ik_qpos = self.robot.compute_ik(pose=fk_xpos, name=arm_name)
+ res, ik_qpos_xyzquat = self.robot.compute_batch_ik(
+ pose=fk_xpos_xyzquat, joint_seed=sample_qpos, name=arm_name
+ )
- if ik_qpos.dim() == 3:
- ik_xpos = self.robot.compute_fk(
- qpos=ik_qpos[0][0], name=arm_name, to_matrix=True
- )
- else:
- ik_xpos = self.robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True)
+ ik_xpos = self.robot.compute_batch_fk(
+ qpos=ik_qpos_xyzquat, name=arm_name, to_matrix=True
+ )
assert torch.allclose(
- fk_xpos, ik_xpos, atol=1e-2, rtol=1e-2
- ), f"FK and IK results do not match for {arm_name}"
-
+ fk_xpos, ik_xpos, atol=5e-3, rtol=5e-3
+ ), f"FK and IK xpos do not match for {arm_name}"
# test for failed xpos
invalid_pose = torch.tensor(
[
@@ -101,10 +162,10 @@ def test_ik(self, arm_name: str):
device=self.robot.device,
)
res, ik_qpos = self.robot.compute_ik(
- pose=invalid_pose, joint_seed=ik_qpos, name=arm_name
+ pose=invalid_pose, joint_seed=ik_qpos[:, 0, :], name=arm_name
)
dof = ik_qpos.shape[-1]
- assert res[0] == False
+ assert res[0].item() == False
assert ik_qpos.shape == (1, dof)
def teardown_method(self):
diff --git a/tests/sim/solvers/test_srs_solver.py b/tests/sim/solvers/test_srs_solver.py
index a4a375ed..cfd970e0 100644
--- a/tests/sim/solvers/test_srs_solver.py
+++ b/tests/sim/solvers/test_srs_solver.py
@@ -21,7 +21,7 @@
from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
from embodichain.lab.sim.objects import Robot
-from embodichain.lab.sim.cfg import RobotCfg
+from embodichain.lab.sim.cfg import RobotCfg, RenderCfg
from embodichain.data import get_data_path
from embodichain.lab.sim.solvers.srs_solver import SRSSolver, SRSSolverCfg
@@ -73,7 +73,7 @@ def setup_solver(self, solver_type: str, device: str = "cpu"):
)
cfg.urdf_path = urdf
cfg.dh_params = arm_params.dh_params
- cfg.qpos_limits = arm_params.qpos_limits
+ cfg.user_qpos_limits = arm_params.qpos_limits
cfg.T_e_oe = arm_params.T_e_oe
cfg.T_b_ob = arm_params.T_b_ob
cfg.link_lengths = arm_params.link_lengths
@@ -289,7 +289,6 @@ def setup_method(self):
self.setup_simulation(solver_type="SRSSolver", device="cpu")
-@pytest.mark.skip(reason="Skipping CUDA tests temporarily")
class TestSRSCUDARobotSolver(BaseRobotSolverTest):
def setup_method(self):
self.setup_simulation(solver_type="SRSSolver", device="cuda")