diff --git a/docs/source/overview/sim/atomic_actions.md b/docs/source/overview/sim/atomic_actions.md index 979df571..9d2523d5 100644 --- a/docs/source/overview/sim/atomic_actions.md +++ b/docs/source/overview/sim/atomic_actions.md @@ -238,4 +238,4 @@ is_success, traj = engine.execute_static(target_list=[target_pose]) - {doc}`planners/motion_generator` — the trajectory planner used by every action - {doc}`sim_robot` — how control parts and IK solvers are configured -- Tutorial: `scripts/tutorials/sim/atomic_actions.py` +- Tutorial: `scripts/tutorials/atomic_action/atomic_actions.py` diff --git a/docs/source/tutorial/atomic_actions.rst b/docs/source/tutorial/atomic_actions.rst index 10b8e97c..411f374f 100644 --- a/docs/source/tutorial/atomic_actions.rst +++ b/docs/source/tutorial/atomic_actions.rst @@ -30,13 +30,13 @@ For the full design overview, architecture diagram, and extension guide see The Code -------- -The tutorial corresponds to the ``atomic_actions.py`` script in the ``scripts/tutorials/sim`` +The tutorial corresponds to the ``atomic_actions.py`` script in the ``scripts/tutorials/atomic_action`` directory. .. dropdown:: Code for atomic_actions.py :icon: code - .. literalinclude:: ../../../scripts/tutorials/sim/atomic_actions.py + .. literalinclude:: ../../../scripts/tutorials/atomic_action/atomic_actions.py :language: python :linenos: diff --git a/embodichain/lab/sim/atomic_actions/__init__.py b/embodichain/lab/sim/atomic_actions/__init__.py index cf1e60ce..9bc7a741 100644 --- a/embodichain/lab/sim/atomic_actions/__init__.py +++ b/embodichain/lab/sim/atomic_actions/__init__.py @@ -33,9 +33,11 @@ MoveAction, PickUpAction, PlaceAction, + UprightAction, MoveActionCfg, PickUpActionCfg, PlaceActionCfg, + UprightActionCfg, ) from .engine import ( AtomicActionEngine, @@ -47,7 +49,6 @@ __all__ = [ # Core classes "Affordance", - "GraspPose", "InteractionPoints", "ObjectSemantics", "ActionCfg", @@ -56,9 +57,11 @@ "MoveAction", "PickUpAction", "PlaceAction", + "UprightAction", "MoveActionCfg", "PickUpActionCfg", "PlaceActionCfg", + "UprightActionCfg", # Engine "AtomicActionEngine", "register_action", diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py index 1aa8901a..ce1e42a0 100644 --- a/embodichain/lab/sim/atomic_actions/actions.py +++ b/embodichain/lab/sim/atomic_actions/actions.py @@ -17,7 +17,7 @@ from __future__ import annotations import torch -from typing import Optional, Union, TYPE_CHECKING, Any +from typing import Optional, Union, TYPE_CHECKING from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType from embodichain.lab.sim.planners.motion_generator import MotionGenOptions @@ -231,11 +231,26 @@ def _interpolate_hand_qpos( n_waypoints: int, ) -> torch.Tensor: """Interpolate hand joint positions between two gripper states.""" - weights = torch.linspace(0, 1, steps=n_waypoints, device=self.device) - hand_qpos_list = [ - torch.lerp(start_hand_qpos, end_hand_qpos, weight) for weight in weights - ] - return torch.stack(hand_qpos_list, dim=0) + start_hand_qpos = start_hand_qpos.to(self.device) + end_hand_qpos = end_hand_qpos.to(self.device) + + if start_hand_qpos.dim() == 1: + start_hand_qpos = start_hand_qpos.unsqueeze(0) + if end_hand_qpos.dim() == 1: + end_hand_qpos = end_hand_qpos.unsqueeze(0) + + weights = torch.linspace( + 0, + 1, + steps=n_waypoints, + device=self.device, + dtype=start_hand_qpos.dtype, + ) + return torch.lerp( + start_hand_qpos.unsqueeze(1), + end_hand_qpos.unsqueeze(1), + weights[None, :, None], + ) def execute( self, @@ -262,9 +277,7 @@ def execute( # TODO: warning and fallback if no valid grasp pose found if not is_success: - logger.log_warning( - "Failed to resolve grasp pose, using default approach pose" - ) + logger.log_warning("Failed to resolve move target pose.") return False, torch.empty(0), self.arm_joint_ids target_states_list = [ @@ -326,6 +339,26 @@ def __init__( self.arm_dof = len(self.arm_joint_ids) self.dof = len(self.joint_ids) + def _expand_hand_qpos(self, hand_qpos: torch.Tensor) -> torch.Tensor: + """Resolve hand qpos to batched shape ``(n_envs, hand_dof)``.""" + hand_dof = len(self.hand_joint_ids) + hand_qpos = hand_qpos.to(device=self.device, dtype=torch.float32) + if hand_qpos.shape == (hand_dof,): + return hand_qpos.unsqueeze(0).repeat(self.n_envs, 1) + if hand_qpos.shape == (self.n_envs, hand_dof): + return hand_qpos + logger.log_error( + f"hand_qpos must have shape ({hand_dof},) or " + f"({self.n_envs}, {hand_dof}), but got {hand_qpos.shape}", + ValueError, + ) + + def _repeat_hand_qpos( + self, hand_qpos: torch.Tensor, n_waypoints: int + ) -> torch.Tensor: + """Repeat hand qpos across trajectory waypoints.""" + return self._expand_hand_qpos(hand_qpos).unsqueeze(1).repeat(1, n_waypoints, 1) + def execute( self, target: Union[ObjectSemantics, torch.Tensor], @@ -353,11 +386,10 @@ def execute( target, action_name=self.__class__.__name__ ) - # TODO: warning and fallback if no valid grasp pose found + if isinstance(is_success, torch.Tensor): + is_success = torch.all(is_success).item() if not is_success: - logger.log_warning( - "Failed to resolve grasp pose, using default approach pose" - ) + logger.log_warning("Failed to resolve grasp pose for all environments.") return False, torch.empty(0), self.joint_ids # Compute pre-grasp pose @@ -517,6 +549,732 @@ def validate(self, target, start_qpos=None, **kwargs): return True +@configclass +class UprightActionCfg(PickUpActionCfg): + name: str = "upright" + """Name of the action, used for identification and logging.""" + + place_clearance: float = 0.005 + """Clearance (m) between the upright object bottom and the support plane.""" + + upright_axis_sign: float = 1.0 + """Direction of the object's local Z axis after upright placement. + + Use ``1.0`` to align local +Z with world +Z. Use ``-1.0`` when the mesh's + local +Z points toward the physical bottom and local -Z should face upward. + """ + + place_press_depth: float = 0.002 + """Additional downward displacement (m) after pre-place to make support contact.""" + + place_press_steps: int = 4 + """Number of closed-hand waypoints used for the downward place press.""" + + upright_hold_steps: int = 0 + """Number of closed-hand waypoints to hold after upright placement.""" + + place_hold_steps: int = 8 + """Number of closed-hand waypoints to hold the object after pressing it down.""" + + release_interp_steps: int = 12 + """Number of waypoints for the slow hand release phase.""" + + release_retreat_distance: float = 0.08 + """Horizontal distance (m) to retreat after releasing the upright object.""" + + release_retreat_lift: float = 0.01 + """Small upward offset (m) added during release retreat.""" + + use_grasp_width_qpos: bool = False + """Whether to map selected grasp open length into a dynamic hand close qpos.""" + + gripper_max_open_width: float = 0.088 + """Maximum total gripper opening width (m) used for width-to-qpos mapping.""" + + grasp_squeeze_width: float = 0.003 + """Width margin (m) subtracted from the selected grasp width before closing.""" + + final_approach_steps: int = 12 + """Number of waypoints for the slow final approach from pre-grasp to grasp.""" + + final_approach_preclose_width_margin: float = 0.010 + """Extra opening width (m) kept around the selected grasp width during final approach.""" + + grasp_hold_steps: int = 4 + """Number of closed-hand waypoints to hold the grasp before lifting.""" + + min_dynamic_hand_close_qpos: torch.Tensor | None = None + """Optional minimum hand qpos used when mapping grasp width into close qpos.""" + + max_grasp_open_length: float | None = None + """Optional maximum selected grasp opening length (m) for upright placement.""" + + max_grasp_axis_approach_dot: float | None = None + """Optional maximum absolute dot between grasp X axis and approach direction.""" + + max_grasp_axis_upright_axis_dot: float | None = None + """Optional maximum absolute dot between grasp X axis and object upright axis.""" + + upright_yaw_offsets: tuple[float, ...] = ( + 0.0, + 0.5 * np.pi, + -0.5 * np.pi, + np.pi, + ) + """Yaw offsets (rad) to try after aligning the object upright axis.""" + + +class UprightAction(PickUpAction): + def __init__( + self, + motion_generator: MotionGenerator, + cfg: UprightActionCfg | None = None, + ): + """ + Initialize the atomic action. + Args: + motion_generator: The motion generator instance to use for planning. + cfg: Configuration for the action. + """ + super().__init__( + motion_generator, cfg=cfg if cfg is not None else UprightActionCfg() + ) + + def _resolve_grasp_pose( + self, semantics: ObjectSemantics + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not isinstance(semantics.affordance, AntipodalAffordance): + logger.log_error( + "Grasp pose affordance must be of type AntipodalAffordance" + ) + if semantics.entity is None: + logger.log_error( + "ObjectSemantics must be associated with an entity to get object pose" + ) + obj_poses = semantics.entity.get_local_pose(to_matrix=True) + if semantics.affordance.generator is None: + semantics.affordance._init_generator() + generator = semantics.affordance.generator + if generator is None: + logger.log_error("Failed to initialize antipodal grasp generator") + + n_envs = obj_poses.shape[0] + approach_direction = self.approach_direction.to( + device=self.device, dtype=torch.float32 + ) + approach_direction = approach_direction / approach_direction.norm().clamp( + min=1e-6 + ) + max_open_length = self.cfg.max_grasp_open_length + max_approach_axis_dot = self.cfg.max_grasp_axis_approach_dot + max_upright_axis_dot = self.cfg.max_grasp_axis_upright_axis_dot + + is_success = torch.zeros(n_envs, dtype=torch.bool, device=self.device) + grasp_xpos = torch.eye(4, dtype=torch.float32, device=self.device).repeat( + n_envs, 1, 1 + ) + open_length = torch.zeros(n_envs, dtype=torch.float32, device=self.device) + init_qpos = self.robot.get_qpos(name=self.cfg.control_part) + world_z = torch.tensor([0.0, 0.0, 1.0], device=self.device) + upright_obj_pose_candidates = self._build_upright_object_pose_candidates( + semantics, obj_poses + ) + selected_upright_obj_poses = upright_obj_pose_candidates[:, 0].clone() + + for env_idx in range(n_envs): + ( + has_candidates, + candidate_grasp_xpos, + candidate_open_length, + candidate_cost, + ) = generator.get_valid_grasp_poses(obj_poses[env_idx], approach_direction) + if not has_candidates: + logger.log_warning( + f"No valid grasp candidates found for {env_idx}-th object." + ) + continue + + candidate_grasp_xpos = candidate_grasp_xpos.to( + device=self.device, dtype=torch.float32 + ) + candidate_open_length = candidate_open_length.to( + device=self.device, dtype=torch.float32 + ) + candidate_cost = candidate_cost.to(device=self.device, dtype=torch.float32) + candidate_mask = torch.ones( + candidate_grasp_xpos.shape[0], dtype=torch.bool, device=self.device + ) + if max_open_length is not None: + candidate_mask &= candidate_open_length <= max_open_length + + grasp_axis_dot = torch.abs( + (candidate_grasp_xpos[:, :3, 0] * approach_direction).sum(dim=1) + ) + if max_approach_axis_dot is not None: + candidate_mask &= grasp_axis_dot <= max_approach_axis_dot + + upright_axis = torch.nn.functional.normalize( + obj_poses[env_idx, :3, 2], dim=0 + ) + grasp_upright_axis_dot = torch.abs( + (candidate_grasp_xpos[:, :3, 0] * upright_axis).sum(dim=1) + ) + if max_upright_axis_dot is not None: + candidate_mask &= grasp_upright_axis_dot <= max_upright_axis_dot + + if not bool(torch.any(candidate_mask).item()): + logger.log_warning( + "No grasp candidates remain after upright grasp filtering " + f"for {env_idx}-th object." + ) + continue + + candidate_grasp_xpos = candidate_grasp_xpos[candidate_mask] + candidate_open_length = candidate_open_length[candidate_mask] + candidate_cost = candidate_cost[candidate_mask] + n_candidate = candidate_grasp_xpos.shape[0] + + pre_grasp_xpos = self._apply_offset( + pose=candidate_grasp_xpos, + offset=-candidate_grasp_xpos[:, :3, 2] * self.cfg.pre_grasp_distance, + ) + lift_xpos = self._apply_offset( + pose=candidate_grasp_xpos, + offset=world_z * self.cfg.lift_height, + ) + obj_pose_repeat = obj_poses[env_idx].unsqueeze(0).repeat(n_candidate, 1, 1) + obj_to_grasp = torch.bmm( + self._invert_pose(obj_pose_repeat), candidate_grasp_xpos + ) + + base_ik_success = torch.ones( + n_candidate, dtype=torch.bool, device=self.device + ) + qpos_seed = init_qpos[env_idx : env_idx + 1, None, :].repeat( + 1, n_candidate, 1 + ) + for target_xpos in ( + pre_grasp_xpos, + candidate_grasp_xpos, + lift_xpos, + ): + target_success, target_qpos = self.robot.compute_batch_ik( + pose=target_xpos.unsqueeze(0), + name=self.cfg.control_part, + joint_seed=qpos_seed, + env_ids=[env_idx], + ) + base_ik_success &= target_success[0] + qpos_seed = target_qpos + + n_upright_pose = upright_obj_pose_candidates.shape[1] + upright_obj_pose_repeat = ( + upright_obj_pose_candidates[env_idx] + .unsqueeze(1) + .repeat(1, n_candidate, 1, 1) + .reshape(-1, 4, 4) + ) + obj_to_grasp_repeat = ( + obj_to_grasp.unsqueeze(0) + .repeat(n_upright_pose, 1, 1, 1) + .reshape(-1, 4, 4) + ) + upright_lift_obj_xpos = self._apply_offset( + pose=upright_obj_pose_repeat, + offset=world_z * self.cfg.lift_height, + ) + upright_lift_xpos = torch.bmm(upright_lift_obj_xpos, obj_to_grasp_repeat) + upright_place_xpos = torch.bmm(upright_obj_pose_repeat, obj_to_grasp_repeat) + press_xpos = self._apply_offset( + pose=upright_place_xpos, + offset=-world_z + * (self.cfg.place_clearance + self.cfg.place_press_depth), + ) + + ik_success = base_ik_success.repeat(n_upright_pose) + upright_qpos_seed = qpos_seed.repeat(1, n_upright_pose, 1) + for target_xpos in (upright_lift_xpos, upright_place_xpos, press_xpos): + target_success, target_qpos = self.robot.compute_batch_ik( + pose=target_xpos.unsqueeze(0), + name=self.cfg.control_part, + joint_seed=upright_qpos_seed, + env_ids=[env_idx], + ) + ik_success &= target_success[0] + upright_qpos_seed = target_qpos + + flat_candidate_cost = candidate_cost.repeat(n_upright_pose) + masked_cost = torch.where( + ik_success, + flat_candidate_cost, + torch.full_like(flat_candidate_cost, float("inf")), + ) + best_cost, best_flat_idx = masked_cost.min(dim=0) + best_upright_idx = torch.div( + best_flat_idx, n_candidate, rounding_mode="floor" + ) + best_idx = best_flat_idx % n_candidate + + if not torch.isfinite(best_cost): + logger.log_warning( + "No upright grasp candidates remain after IK feasibility " + f"filtering for {env_idx}-th object." + ) + continue + + is_success[env_idx] = True + grasp_xpos[env_idx] = candidate_grasp_xpos[best_idx] + open_length[env_idx] = candidate_open_length[best_idx] + selected_upright_obj_poses[env_idx] = upright_obj_pose_candidates[ + env_idx, best_upright_idx + ] + + self._selected_upright_obj_xpos = selected_upright_obj_poses + return is_success, grasp_xpos, open_length + + @staticmethod + def _invert_pose(pose: torch.Tensor) -> torch.Tensor: + """Invert a batched homogeneous transform.""" + inv_pose = pose.clone() + rot_t = pose[:, :3, :3].transpose(1, 2) + inv_pose[:, :3, :3] = rot_t + inv_pose[:, :3, 3] = -torch.bmm(rot_t, pose[:, :3, 3:4]).squeeze(-1) + return inv_pose + + def _build_upright_object_pose( + self, semantics: ObjectSemantics, obj_poses: torch.Tensor + ) -> torch.Tensor: + """Build a target object pose whose configured local Z direction is upright.""" + world_z = torch.tensor([0.0, 0.0, 1.0], device=self.device) + axis_sign = 1.0 if self.cfg.upright_axis_sign >= 0.0 else -1.0 + projected_x = obj_poses[:, :3, 0].clone() + projected_x[:, 2] = 0.0 + projected_x_norm = projected_x.norm(dim=1, keepdim=True) + + fallback_x = obj_poses[:, :3, 1].clone() + fallback_x[:, 2] = 0.0 + fallback_x_norm = fallback_x.norm(dim=1, keepdim=True) + fallback_x = fallback_x / fallback_x_norm.clamp(min=1e-6) + + default_x = torch.tensor([1.0, 0.0, 0.0], device=self.device).repeat( + self.n_envs, 1 + ) + upright_x = torch.where( + projected_x_norm > 1e-6, + projected_x / projected_x_norm.clamp(min=1e-6), + torch.where(fallback_x_norm > 1e-6, fallback_x, default_x), + ) + upright_z = axis_sign * world_z.repeat(self.n_envs, 1) + upright_y = torch.cross(upright_z, upright_x, dim=1) + upright_y = upright_y / upright_y.norm(dim=1, keepdim=True).clamp(min=1e-6) + upright_x = torch.cross(upright_y, upright_z, dim=1) + upright_x = upright_x / upright_x.norm(dim=1, keepdim=True).clamp(min=1e-6) + + upright_pose = obj_poses.clone() + upright_pose[:, :3, 0] = upright_x + upright_pose[:, :3, 1] = upright_y + upright_pose[:, :3, 2] = upright_z + + mesh_vertices = semantics.geometry.get("mesh_vertices") + if isinstance(mesh_vertices, torch.Tensor) and mesh_vertices.numel() > 0: + mesh_vertices = mesh_vertices.to(device=self.device, dtype=torch.float32) + vertical_offsets = torch.matmul( + mesh_vertices, upright_pose[:, 2, :3].transpose(0, 1) + ) + local_bottom_z = vertical_offsets.min(dim=0).values + upright_pose[:, 2, 3] = self.cfg.place_clearance - local_bottom_z + return upright_pose + + def _build_upright_object_pose_candidates( + self, semantics: ObjectSemantics, obj_poses: torch.Tensor + ) -> torch.Tensor: + """Build upright target poses with alternative yaw rotations.""" + base_pose = self._build_upright_object_pose(semantics, obj_poses) + yaw_offsets = torch.as_tensor( + self.cfg.upright_yaw_offsets, device=self.device, dtype=torch.float32 + ) + cos_yaw = torch.cos(yaw_offsets).view(1, -1, 1) + sin_yaw = torch.sin(yaw_offsets).view(1, -1, 1) + + base_x = base_pose[:, None, :3, 0] + base_y = base_pose[:, None, :3, 1] + candidates = base_pose[:, None, :, :].repeat(1, yaw_offsets.numel(), 1, 1) + candidates[:, :, :3, 0] = cos_yaw * base_x + sin_yaw * base_y + candidates[:, :, :3, 1] = -sin_yaw * base_x + cos_yaw * base_y + return candidates + + def _compute_hand_qpos_for_width(self, target_width: torch.Tensor) -> torch.Tensor: + """Map desired total gripper width to batched hand qpos.""" + target_width = target_width.to(device=self.device, dtype=torch.float32).view( + self.n_envs, 1 + ) + target_width = target_width.clamp(min=0.0, max=self.cfg.gripper_max_open_width) + closing_distance = 0.5 * (self.cfg.gripper_max_open_width - target_width).clamp( + min=0.0 + ) + hand_qpos_limits = self.robot.get_qpos_limits( + name=self.cfg.hand_control_part + ).to(self.device) + lower_limits = hand_qpos_limits[:, :, 0] + upper_limits = hand_qpos_limits[:, :, 1] + hand_open_qpos = self._expand_hand_qpos(self.hand_open_qpos) + dynamic_qpos = hand_open_qpos + closing_distance.repeat( + 1, len(self.hand_joint_ids) + ) + return torch.max(torch.min(dynamic_qpos, upper_limits), lower_limits) + + def _compute_dynamic_hand_close_qpos( + self, grasp_open_length: torch.Tensor + ) -> torch.Tensor: + """Map selected grasp width to batched hand close qpos for parallel grippers.""" + fallback_qpos = self._expand_hand_qpos(self.hand_close_qpos) + if not self.cfg.use_grasp_width_qpos: + return fallback_qpos + + grasp_open_length = grasp_open_length.to( + device=self.device, dtype=torch.float32 + ).view(self.n_envs, 1) + target_width = (grasp_open_length - self.cfg.grasp_squeeze_width).clamp(min=0.0) + dynamic_qpos = self._compute_hand_qpos_for_width(target_width) + if self.cfg.min_dynamic_hand_close_qpos is not None: + min_close_qpos = self._expand_hand_qpos( + self.cfg.min_dynamic_hand_close_qpos + ) + dynamic_qpos = torch.max(dynamic_qpos, min_close_qpos) + return dynamic_qpos + + def _compute_final_approach_hand_qpos( + self, grasp_open_length: torch.Tensor, hand_close_qpos: torch.Tensor + ) -> torch.Tensor: + """Pre-close the gripper during final approach without reaching squeeze force.""" + hand_open_qpos = self._expand_hand_qpos(self.hand_open_qpos) + if not self.cfg.use_grasp_width_qpos: + return hand_open_qpos + + grasp_open_length = grasp_open_length.to( + device=self.device, dtype=torch.float32 + ).view(self.n_envs, 1) + target_width = grasp_open_length + self.cfg.final_approach_preclose_width_margin + preclose_qpos = self._compute_hand_qpos_for_width(target_width) + return torch.max(torch.min(preclose_qpos, hand_close_qpos), hand_open_qpos) + + def execute( + self, + target: Union[ObjectSemantics, torch.Tensor], + start_qpos: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[bool, torch.Tensor, list[float]]: + """Pick up an object, rotate it upright, place it down, and release it.""" + if not isinstance(target, ObjectSemantics): + return super().execute(target=target, start_qpos=start_qpos, **kwargs) + + is_success, grasp_xpos, grasp_open_length = self._resolve_grasp_pose(target) + obj_poses = target.entity.get_local_pose(to_matrix=True) + if not torch.all(is_success).item(): + logger.log_warning( + "Failed to resolve upright grasp pose for all environments." + ) + return False, torch.empty(0), self.joint_ids + + world_z = torch.tensor([0, 0, 1], device=self.device, dtype=torch.float32) + hand_close_qpos = self._compute_dynamic_hand_close_qpos(grasp_open_length) + final_approach_hand_qpos = self._compute_final_approach_hand_qpos( + grasp_open_length, hand_close_qpos + ) + pre_grasp_xpos = self._apply_offset( + pose=grasp_xpos, + offset=-grasp_xpos[:, :3, 2] * self.cfg.pre_grasp_distance, + ) + lift_xpos = self._apply_offset( + pose=grasp_xpos, + offset=world_z * self.cfg.lift_height, + ) + + obj_to_grasp = torch.bmm(self._invert_pose(obj_poses), grasp_xpos) + upright_obj_xpos = getattr(self, "_selected_upright_obj_xpos", None) + if upright_obj_xpos is None or upright_obj_xpos.shape != obj_poses.shape: + upright_obj_xpos = self._build_upright_object_pose(target, obj_poses) + upright_lift_obj_xpos = self._apply_offset( + pose=upright_obj_xpos, + offset=world_z * self.cfg.lift_height, + ) + upright_lift_xpos = torch.bmm(upright_lift_obj_xpos, obj_to_grasp) + upright_place_xpos = torch.bmm(upright_obj_xpos, obj_to_grasp) + press_down_distance = self.cfg.place_clearance + self.cfg.place_press_depth + press_xpos = self._apply_offset( + pose=upright_place_xpos, + offset=-world_z * press_down_distance, + ) + retreat_direction = -press_xpos[:, :3, 2] + retreat_direction[:, 2] = 0.0 + retreat_direction_norm = retreat_direction.norm(dim=1, keepdim=True) + retreat_direction = torch.where( + retreat_direction_norm > 1e-6, + retreat_direction / retreat_direction_norm.clamp(min=1e-6), + -press_xpos[:, :3, 0], + ) + retreat_direction[:, 2] = 0.0 + retreat_direction = retreat_direction / retreat_direction.norm( + dim=1, keepdim=True + ).clamp(min=1e-6) + retreat_offset = ( + retreat_direction * self.cfg.release_retreat_distance + + world_z * self.cfg.release_retreat_lift + ) + retreat_xpos = self._apply_offset( + pose=press_xpos, + offset=retreat_offset, + ) + + start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof) + n_close_waypoint = self.cfg.hand_interp_steps + n_final_approach_waypoint = max(2, self.cfg.final_approach_steps) + n_grasp_hold_waypoint = max(0, self.cfg.grasp_hold_steps) + n_press_waypoint = max(2, self.cfg.place_press_steps) + n_upright_hold_waypoint = max(0, self.cfg.upright_hold_steps) + n_hold_waypoint = max(0, self.cfg.place_hold_steps) + n_open_waypoint = max(2, self.cfg.release_interp_steps) + motion_waypoints = ( + self.cfg.sample_interval + - n_close_waypoint + - n_final_approach_waypoint + - n_grasp_hold_waypoint + - n_upright_hold_waypoint + - n_press_waypoint + - n_hold_waypoint + - n_open_waypoint + ) + if motion_waypoints < 6: + logger.log_error( + "Not enough waypoints for upright action. Please increase " + "sample_interval or decrease hand/press/upright-hold/hold/release " + "steps.", + ValueError, + ) + n_pre_approach_waypoint = max(2, int(np.round(motion_waypoints * 0.25))) + n_upright_waypoint = max(2, int(np.round(motion_waypoints * 0.60))) + n_retreat_waypoint = ( + self.cfg.sample_interval + - n_close_waypoint + - n_final_approach_waypoint + - n_grasp_hold_waypoint + - n_upright_hold_waypoint + - n_press_waypoint + - n_hold_waypoint + - n_open_waypoint + - n_pre_approach_waypoint + - n_upright_waypoint + ) + if n_retreat_waypoint < 2: + retreat_deficit = 2 - n_retreat_waypoint + n_retreat_waypoint = 2 + n_upright_waypoint = max(2, n_upright_waypoint - retreat_deficit) + target_states_list = [ + [ + PlanState(xpos=pre_grasp_xpos[i], move_type=MoveType.EEF_MOVE), + ] + for i in range(self.n_envs) + ] + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + start_qpos, + n_pre_approach_waypoint, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan approach trajectory.") + return False, torch.empty(0), self.joint_ids + approach_trajectory = torch.zeros( + size=(self.n_envs, n_pre_approach_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + approach_trajectory[:, :, : self.arm_dof] = plan_traj + approach_trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + self.hand_open_qpos, n_pre_approach_waypoint + ) + + pre_grasp_qpos = approach_trajectory[:, -1, : self.arm_dof] + target_states_list = [ + [PlanState(xpos=grasp_xpos[i], move_type=MoveType.EEF_MOVE)] + for i in range(self.n_envs) + ] + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + pre_grasp_qpos, + n_final_approach_waypoint, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan final approach trajectory.") + return False, torch.empty(0), self.joint_ids + final_approach_trajectory = torch.zeros( + size=(self.n_envs, n_final_approach_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + final_approach_trajectory[:, :, : self.arm_dof] = plan_traj + final_approach_hand_path = self._interpolate_hand_qpos( + self.hand_open_qpos, + final_approach_hand_qpos, + n_final_approach_waypoint, + ) + final_approach_trajectory[:, :, self.arm_dof :] = final_approach_hand_path + + grasp_qpos = final_approach_trajectory[:, -1, : self.arm_dof] + hand_close_path = self._interpolate_hand_qpos( + final_approach_hand_qpos, + hand_close_qpos, + n_close_waypoint, + ) + close_trajectory = torch.zeros( + size=(self.n_envs, n_close_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + close_trajectory[:, :, : self.arm_dof] = grasp_qpos.unsqueeze(1) + close_trajectory[:, :, self.arm_dof :] = hand_close_path + + closed_grasp_qpos = close_trajectory[:, -1, : self.arm_dof] + grasp_hold_trajectory = torch.zeros( + size=(self.n_envs, n_grasp_hold_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + if n_grasp_hold_waypoint > 0: + grasp_hold_trajectory[:, :, : self.arm_dof] = closed_grasp_qpos.unsqueeze(1) + grasp_hold_trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + hand_close_qpos, n_grasp_hold_waypoint + ) + + target_states_list = [ + [ + PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE), + PlanState(xpos=upright_lift_xpos[i], move_type=MoveType.EEF_MOVE), + PlanState(xpos=upright_place_xpos[i], move_type=MoveType.EEF_MOVE), + ] + for i in range(self.n_envs) + ] + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + closed_grasp_qpos, + n_upright_waypoint, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan upright trajectory.") + return False, torch.empty(0), self.joint_ids + upright_trajectory = torch.zeros( + size=(self.n_envs, n_upright_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + upright_trajectory[:, :, : self.arm_dof] = plan_traj + upright_trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + hand_close_qpos, n_upright_waypoint + ) + + place_qpos = upright_trajectory[:, -1, : self.arm_dof] + upright_hold_trajectory = torch.zeros( + size=(self.n_envs, n_upright_hold_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + if n_upright_hold_waypoint > 0: + upright_hold_trajectory[:, :, : self.arm_dof] = place_qpos.unsqueeze(1) + upright_hold_trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + hand_close_qpos, n_upright_hold_waypoint + ) + + target_states_list = [ + [PlanState(xpos=press_xpos[i], move_type=MoveType.EEF_MOVE)] + for i in range(self.n_envs) + ] + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + place_qpos, + n_press_waypoint, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan place press trajectory.") + return False, torch.empty(0), self.joint_ids + press_trajectory = torch.zeros( + size=(self.n_envs, n_press_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + press_trajectory[:, :, : self.arm_dof] = plan_traj + press_trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + hand_close_qpos, n_press_waypoint + ) + + press_qpos = press_trajectory[:, -1, : self.arm_dof] + hold_trajectory = torch.zeros( + size=(self.n_envs, n_hold_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + if n_hold_waypoint > 0: + hold_trajectory[:, :, : self.arm_dof] = press_qpos.unsqueeze(1) + hold_trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + hand_close_qpos, n_hold_waypoint + ) + + hand_open_path = self._interpolate_hand_qpos( + hand_close_qpos, + self.hand_open_qpos, + n_open_waypoint, + ) + open_trajectory = torch.zeros( + size=(self.n_envs, n_open_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + open_trajectory[:, :, : self.arm_dof] = press_qpos.unsqueeze(1) + open_trajectory[:, :, self.arm_dof :] = hand_open_path + + target_states_list = [ + [PlanState(xpos=retreat_xpos[i], move_type=MoveType.EEF_MOVE)] + for i in range(self.n_envs) + ] + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + press_qpos, + n_retreat_waypoint, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan retreat trajectory.") + return False, torch.empty(0), self.joint_ids + retreat_trajectory = torch.zeros( + size=(self.n_envs, n_retreat_waypoint, self.dof), + dtype=torch.float32, + device=self.device, + ) + retreat_trajectory[:, :, : self.arm_dof] = plan_traj + retreat_trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + self.hand_open_qpos, n_retreat_waypoint + ) + + trajectory = torch.cat( + [ + approach_trajectory, + final_approach_trajectory, + close_trajectory, + grasp_hold_trajectory, + upright_trajectory, + upright_hold_trajectory, + press_trajectory, + hold_trajectory, + open_trajectory, + retreat_trajectory, + ], + dim=1, + ) + return True, trajectory, self.joint_ids + + @configclass class PlaceActionCfg(GraspActionCfg): name: str = "place" @@ -576,9 +1334,7 @@ def execute( # TODO: warning and fallback if no valid grasp pose found if not is_success: - logger.log_warning( - "Failed to resolve grasp pose, using default approach pose" - ) + logger.log_warning("Failed to resolve place target pose.") return False, torch.empty(0), self.joint_ids # compute waypoint number for each phase diff --git a/scripts/tutorials/sim/atomic_actions.py b/scripts/tutorials/atomic_action/atomic_actions.py similarity index 100% rename from scripts/tutorials/sim/atomic_actions.py rename to scripts/tutorials/atomic_action/atomic_actions.py diff --git a/scripts/tutorials/atomic_action/upright_atomic_action.py b/scripts/tutorials/atomic_action/upright_atomic_action.py new file mode 100644 index 00000000..8a5f5d42 --- /dev/null +++ b/scripts/tutorials/atomic_action/upright_atomic_action.py @@ -0,0 +1,740 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates the creation and simulation of a robot that uprights a fallen +bottle in a simulated environment using the SimulationManager and atomic actions. +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[3] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +import torch + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot, RigidObject +from embodichain.lab.sim.shapes import CubeCfg, MeshCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg, ToppraPlannerCfg +from embodichain.data import get_data_path +from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser +from embodichain.utils import logger +from embodichain.lab.sim.cfg import ( + RenderCfg, + JointDrivePropertiesCfg, + RobotCfg, + LightCfg, + RigidBodyAttributesCfg, + RigidObjectCfg, + URDFCfg, +) +from embodichain.lab.sim.atomic_actions import ( + AntipodalAffordance, + ObjectSemantics, + UprightAction, + UprightActionCfg, +) +from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import ( + GraspGenerator, + GraspGeneratorCfg, + AntipodalSamplerCfg, +) +from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( + GripperCollisionCfg, +) + +GRIPPER_URDF_PATH = "DH_PGI_140_80/DH_PGI_140_80.urdf" +GRIPPER_HAND_JOINT_PATTERN = "GRIPPER_FINGER1_JOINT_1" +GRIPPER_MAX_OPEN_WIDTH = 0.080 +GRIPPER_FINGER_LENGTH = 0.088 +GRIPPER_ROOT_Z_WIDTH = 0.096 +GRIPPER_Y_THICKNESS = 0.040 +GRIPPER_TCP_Z = 0.15 +PGI_SAMPLE_INTERVAL = 120 +PGI_HAND_CLOSE_STEPS = 12 +PGI_GRASP_HOLD_STEPS = 20 +BOTTLE_LABEL = "bottle" +BOTTLE_APPROACH_DIRECTION = (0.0, 0.0, -1.0) +BOTTLE_GRASP_SQUEEZE_WIDTH = 0.020 +BOTTLE_MAX_GRASP_OPEN_LENGTH = 0.060 +BOTTLE_MAX_GRASP_AXIS_APPROACH_DOT = 0.080 +BOTTLE_MAX_GRASP_AXIS_UPRIGHT_AXIS_DOT = 0.35 +BOTTLE_MIN_DYNAMIC_HAND_CLOSE_QPOS = 0.024 +BOTTLE_GRASP_COLLISION_THRESHOLD = -0.004 + + +def parse_arguments(): + """ + Parse command-line arguments to configure the simulation. + + Returns: + argparse.Namespace: Parsed arguments including simulation and rendering + options. + """ + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + add_env_launcher_args_to_parser(parser) + parser.add_argument( + "--n_sample", + type=int, + default=10000, + help="Number of surface samples for antipodal grasp generation.", + ) + parser.add_argument( + "--force_reannotate", + action="store_true", + help=( + "Force grasp region re-annotation instead of reusing cached antipodal " + "pairs." + ), + ) + parser.add_argument( + "--debug_hand_state", + action="store_true", + help="Log planned hand targets and simulated hand qpos during execution.", + ) + parser.add_argument( + "--diagnose_grasp", + action="store_true", + help="Plan once and print grasp/TCP diagnostics without opening the viewer.", + ) + parser.add_argument( + "--auto_play", + action="store_true", + help="Run the viewer demo without waiting for keyboard input.", + ) + return parser.parse_args() + + +def initialize_simulation(args) -> SimulationManager: + """ + Initialize the simulation environment based on the provided arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + SimulationManager: Configured simulation manager instance. + """ + config = SimulationManagerCfg( + headless=True, + sim_device=args.device, + render_cfg=RenderCfg(renderer=args.renderer), + physics_dt=1.0 / 100.0, + arena_space=2.5, + ) + sim = SimulationManager(config) + + sim.add_light( + cfg=LightCfg( + uid="main_light", + color=(0.6, 0.6, 0.6), + intensity=30.0, + init_pos=(1.0, 0, 3.0), + ) + ) + + return sim + + +def create_robot(sim: SimulationManager, position=[0.0, 0.0, 0.0]) -> Robot: + """ + Create and configure a robot with an arm and a dexterous hand in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The configured robot instance added to the simulation. + """ + # Retrieve URDF paths for the robot arm and hand + ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + gripper_urdf_path = get_data_path(GRIPPER_URDF_PATH) + # Configure the robot with its components and control properties + cfg = RobotCfg( + uid="UR10", + urdf_cfg=URDFCfg( + components=[ + {"component_type": "arm", "urdf_path": ur10_urdf_path}, + {"component_type": "hand", "urdf_path": gripper_urdf_path}, + ] + ), + drive_pros=JointDrivePropertiesCfg( + stiffness={"JOINT[0-9]": 1e4, GRIPPER_HAND_JOINT_PATTERN: 1e3}, + damping={"JOINT[0-9]": 1e3, GRIPPER_HAND_JOINT_PATTERN: 1e2}, + max_effort={"JOINT[0-9]": 1e5, GRIPPER_HAND_JOINT_PATTERN: 1e4}, + drive_type="force", + ), + control_parts={ + "arm": ["JOINT[0-9]"], + "hand": [GRIPPER_HAND_JOINT_PATTERN], + }, + solver_cfg={ + "arm": PytorchSolverCfg( + end_link_name="ee_link", + root_link_name="base_link", + tcp=[ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, GRIPPER_TCP_Z], + [0.0, 0.0, 0.0, 1.0], + ], + ) + }, + init_qpos=[0, -1.57, 1.57, -1.57, -1.57, 0.0, 0.0, 0.0], + init_pos=position, + ) + return sim.add_robot(cfg=cfg) + + +def create_fallen_bottle(sim: SimulationManager) -> RigidObject: + # Use a slightly smaller and closer bottle for the UR10 gripper demo. + bottle_scale = 0.0008 + bottle_cfg = RigidObjectCfg( + uid="bottle", + shape=MeshCfg( + fpath=get_data_path("ScannedBottle/yibao.ply"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.02, + dynamic_friction=0.97, + static_friction=0.99, + ), + max_convex_hull_num=16, + init_pos=[-0.4294, -0.0825, -0.0997], + init_rot=[90.0, 135.0, 0.0], + body_scale=(bottle_scale, bottle_scale, bottle_scale), + ) + return sim.add_rigid_object(cfg=bottle_cfg) + + +def settle_object(sim: SimulationManager, obj: RigidObject, step: int = 5) -> None: + """Settle an object through the same explicit sequence on CPU and CUDA.""" + if sim.device.type == "cuda": + sim.init_gpu_physics() + + obj.reset() + sim.update(step=step) + obj.clear_dynamics() + + +def release_object_after_grasp(obj: RigidObject) -> None: + """Clear residual motion after the gripper has closed on the object.""" + obj.clear_dynamics() + + +def build_grasp_generator_cfg(args: argparse.Namespace) -> GraspGeneratorCfg: + return GraspGeneratorCfg( + viser_port=11801, + antipodal_sampler_cfg=AntipodalSamplerCfg( + n_sample=args.n_sample, + max_length=GRIPPER_MAX_OPEN_WIDTH, + min_length=0.003, + ), + is_partial_annotate=False, + is_filter_ground_collision=False, + ) + + +def build_gripper_collision_cfg() -> GripperCollisionCfg: + return GripperCollisionCfg( + max_open_length=GRIPPER_MAX_OPEN_WIDTH, + finger_length=GRIPPER_FINGER_LENGTH, + y_thickness=GRIPPER_Y_THICKNESS, + root_z_width=GRIPPER_ROOT_Z_WIDTH, + open_check_margin=0.002, + point_sample_dense=0.012, + ) + + +def get_hand_open_close_qpos( + robot: Robot, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + hand_limits = robot.get_qpos_limits(name="hand")[0].to( + device=device, dtype=torch.float32 + ) + return hand_limits[:, 0], hand_limits[:, 1] + + +def format_tensor(tensor: torch.Tensor) -> str: + rounded = (tensor.detach().cpu() * 10000.0).round() / 10000.0 + return str(rounded.tolist()) + + +def log_hand_setup(robot: Robot, hand_open: torch.Tensor, hand_close: torch.Tensor): + hand_joint_ids = robot.get_joint_ids(name="hand") + hand_joint_names = [robot.joint_names[joint_id] for joint_id in hand_joint_ids] + logger.log_info(f"Hand joint ids: {hand_joint_ids}") + logger.log_info(f"Hand joint names: {hand_joint_names}") + logger.log_info(f"Robot mimic ids: {robot.mimic_ids}") + logger.log_info(f"Robot mimic parents: {robot.mimic_parents}") + logger.log_info(f"Robot mimic multipliers: {robot.mimic_multipliers}") + logger.log_info(f"Robot mimic offsets: {robot.mimic_offsets}") + logger.log_info( + f"Hand qpos limits: {format_tensor(robot.get_qpos_limits(name='hand')[0])}" + ) + logger.log_info(f"Hand open qpos: {format_tensor(hand_open)}") + logger.log_info(f"Hand close qpos: {format_tensor(hand_close)}") + + +def log_hand_execution_state( + robot: Robot, + step_idx: int, + total_steps: int, + action_hand_target: torch.Tensor, +) -> None: + sim_hand_target = robot.get_qpos(name="hand", target=True) + actual_hand_qpos = robot.get_qpos(name="hand") + tracking_error = actual_hand_qpos - sim_hand_target + logger.log_info( + "Hand state " + f"step={step_idx}/{total_steps - 1}, " + f"action_target={format_tensor(action_hand_target[0])}, " + f"sim_target={format_tensor(sim_hand_target[0])}, " + f"actual={format_tensor(actual_hand_qpos[0])}, " + f"actual_minus_target={format_tensor(tracking_error[0])}" + ) + + +def log_object_execution_state( + obj: RigidObject, + step_idx: int, + total_steps: int, +) -> None: + obj_pose = obj.get_local_pose(to_matrix=True) + logger.log_info( + "Object state " + f"step={step_idx}/{total_steps - 1}, " + f"pos={format_tensor(obj_pose[0, :3, 3])}, " + f"z_axis={format_tensor(obj_pose[0, :3, 2])}" + ) + + +def get_upright_segment_lengths(action: UprightAction) -> dict[str, int]: + n_close = action.cfg.hand_interp_steps + n_final = max(2, action.cfg.final_approach_steps) + n_hold = max(0, action.cfg.grasp_hold_steps) + n_press = max(2, action.cfg.place_press_steps) + n_upright_hold = max(0, action.cfg.upright_hold_steps) + n_place_hold = max(0, action.cfg.place_hold_steps) + n_open = max(2, action.cfg.release_interp_steps) + motion_waypoints = ( + action.cfg.sample_interval + - n_close + - n_final + - n_hold + - n_upright_hold + - n_press + - n_place_hold + - n_open + ) + n_pre = max(2, int(round(motion_waypoints * 0.25))) + n_upright = max(2, int(round(motion_waypoints * 0.60))) + n_retreat = ( + action.cfg.sample_interval + - n_close + - n_final + - n_hold + - n_upright_hold + - n_press + - n_place_hold + - n_open + - n_pre + - n_upright + ) + if n_retreat < 2: + retreat_deficit = 2 - n_retreat + n_retreat = 2 + n_upright = max(2, n_upright - retreat_deficit) + return { + "pre": n_pre, + "final": n_final, + "close": n_close, + "grasp_hold": n_hold, + "upright": n_upright, + "upright_hold": n_upright_hold, + "press": n_press, + "place_hold": n_place_hold, + "open": n_open, + "retreat": n_retreat, + } + + +def log_tcp_alignment( + robot: Robot, + traj: torch.Tensor, + grasp_xpos: torch.Tensor, + arm_dof: int, + index: int, + label: str, +) -> None: + arm_qpos = traj[:, index, :arm_dof] + tcp_xpos = robot.compute_fk(qpos=arm_qpos, name="arm", to_matrix=True) + pos_error = torch.norm(tcp_xpos[:, :3, 3] - grasp_xpos[:, :3, 3], dim=1) + rot_delta = torch.bmm(tcp_xpos[:, :3, :3].transpose(1, 2), grasp_xpos[:, :3, :3]) + trace = rot_delta[:, 0, 0] + rot_delta[:, 1, 1] + rot_delta[:, 2, 2] + rot_error = torch.acos(((trace - 1.0) * 0.5).clamp(-1.0, 1.0)) + logger.log_info( + f"{label}: index={index}, " + f"tcp_pos={format_tensor(tcp_xpos[0, :3, 3])}, " + f"pos_error={format_tensor(pos_error)}, " + f"rot_error_rad={format_tensor(rot_error)}, " + f"tcp_rot={format_tensor(tcp_xpos[0, :3, :3])}, " + f"target_rot={format_tensor(grasp_xpos[0, :3, :3])}, " + f"hand_target={format_tensor(traj[0, index, arm_dof:])}" + ) + + +def log_selected_gripper_clearance( + semantics: ObjectSemantics, + obj_pose: torch.Tensor, + grasp_xpos: torch.Tensor, + grasp_open_length: torch.Tensor, +) -> None: + generator = semantics.affordance.generator + if generator is None: + return + + collision_checker = getattr(generator, "_collision_checker", None) + if collision_checker is None: + return + + gripper_pc_world = collision_checker._get_gripper_pc(grasp_xpos, grasp_open_length) + ground_height = collision_checker.get_ground_height(obj_pose[0]) + min_z = gripper_pc_world[:, :, 2].min(dim=1).values + max_z = gripper_pc_world[:, :, 2].max(dim=1).values + is_colliding, min_distance = collision_checker.query( + obj_pose=obj_pose[0], + grasp_poses=grasp_xpos, + open_lengths=grasp_open_length, + is_filter_ground_collision=True, + collision_threshold=BOTTLE_GRASP_COLLISION_THRESHOLD, + ) + logger.log_info(f"Selected gripper pc min z: {format_tensor(min_z)}") + logger.log_info(f"Selected gripper pc max z: {format_tensor(max_z)}") + logger.log_info(f"Selected grasp ground height: {ground_height:.4f}") + logger.log_info( + f"Selected grasp min collision distance: {format_tensor(min_distance)}" + ) + logger.log_info( + f"Selected grasp collision threshold: {BOTTLE_GRASP_COLLISION_THRESHOLD:.4f}" + ) + logger.log_info( + f"Selected grasp collision flag: {is_colliding.detach().cpu().tolist()}" + ) + + +def log_grasp_direction_probe(semantics: ObjectSemantics) -> None: + generator = semantics.affordance.generator + if generator is None: + return + + obj_pose = semantics.entity.get_local_pose(to_matrix=True)[0] + hit_point_pairs = getattr(generator, "_hit_point_pairs", None) + if hit_point_pairs is None or hit_point_pairs.numel() == 0: + logger.log_info("Probe grasp direction: no antipodal pairs available") + return + + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + origin_points_world = generator._apply_transform(origin_points, obj_pose) + hit_points_world = generator._apply_transform(hit_points, obj_pose) + centers = (origin_points_world + hit_points_world) * 0.5 + grasp_x = torch.nn.functional.normalize( + hit_points_world - origin_points_world, dim=-1 + ) + open_lengths = torch.norm(origin_points_world - hit_points_world, dim=-1) + + probe_directions = { + "top_down": [0.0, 0.0, -1.0], + "from_robot_x": [-1.0, 0.0, 0.0], + } + for label, direction in probe_directions.items(): + approach_direction = torch.tensor( + direction, dtype=torch.float32, device=generator.device + ) + cos_angle = torch.clamp((grasp_x * approach_direction).sum(dim=-1), -1.0, 1.0) + positive_angle = torch.abs(torch.acos(cos_angle)) + angle_mask = ( + positive_angle - torch.pi / 2 + ).abs() <= generator.cfg.max_deviation_angle + width_mask = open_lengths <= GRIPPER_MAX_OPEN_WIDTH + candidate_mask = angle_mask & width_mask + logger.log_info( + f"Probe grasp direction {label}: " + f"angle_count={int(angle_mask.sum().item())}, " + f"width_angle_count={int(candidate_mask.sum().item())}" + ) + if torch.any(candidate_mask): + candidate_grasp_poses = GraspGenerator._grasp_pose_from_approach_direction( + grasp_x[candidate_mask], + approach_direction, + centers[candidate_mask], + ) + candidate_open_lengths = open_lengths[candidate_mask] + is_colliding, min_distance = generator._collision_checker.query( + obj_pose, + candidate_grasp_poses, + candidate_open_lengths, + is_filter_ground_collision=True, + collision_threshold=BOTTLE_GRASP_COLLISION_THRESHOLD, + ) + collision_free_count = int((~is_colliding).sum().item()) + logger.log_info( + f"Probe grasp direction {label}: " + f"collision_free_count={collision_free_count}, " + f"collision_threshold={BOTTLE_GRASP_COLLISION_THRESHOLD:.4f}, " + f"min_distance={format_tensor(min_distance.min().unsqueeze(0))}, " + f"max_distance={format_tensor(min_distance.max().unsqueeze(0))}" + ) + if collision_free_count > 0: + grasp_xpos = candidate_grasp_poses[~is_colliding][0] + open_length = candidate_open_lengths[~is_colliding][0] + gripper_pc_world = generator._collision_checker._get_gripper_pc( + grasp_xpos.unsqueeze(0), + open_length.unsqueeze(0), + ) + ground_height = generator._collision_checker.get_ground_height(obj_pose) + min_z = gripper_pc_world[:, :, 2].min(dim=1).values + logger.log_info( + f"Probe grasp direction {label}: " + f"pos={format_tensor(grasp_xpos[:3, 3])}, " + f"open_length={open_length.item():.4f}, " + f"min_z={format_tensor(min_z)}, " + f"ground_height={ground_height:.4f}" + ) + + +def diagnose_upright_plan( + robot: Robot, + action: UprightAction, + semantics: ObjectSemantics, +) -> None: + is_success, grasp_xpos, grasp_open_length = action._resolve_grasp_pose(semantics) + if not torch.all(is_success).item(): + obj_pose = semantics.entity.get_local_pose(to_matrix=True) + logger.log_info(f"Object pos: {format_tensor(obj_pose[0, :3, 3])}") + log_grasp_direction_probe(semantics) + logger.log_warning("Failed to resolve grasp pose during diagnostics.") + return + + hand_close_qpos = action._compute_dynamic_hand_close_qpos(grasp_open_length) + final_approach_qpos = action._compute_final_approach_hand_qpos( + grasp_open_length, hand_close_qpos + ) + obj_pose = semantics.entity.get_local_pose(to_matrix=True) + approach_direction = action.approach_direction / action.approach_direction.norm() + grasp_axis_dot = torch.abs((grasp_xpos[:, :3, 0] * approach_direction).sum(dim=1)) + bottle_axis = torch.nn.functional.normalize(obj_pose[:, :3, 2], dim=1) + grasp_axis_bottle_dot = torch.abs((grasp_xpos[:, :3, 0] * bottle_axis).sum(dim=1)) + + logger.log_info(f"Object pos: {format_tensor(obj_pose[0, :3, 3])}") + logger.log_info(f"Grasp pos: {format_tensor(grasp_xpos[0, :3, 3])}") + logger.log_info(f"Grasp rotation columns: {format_tensor(grasp_xpos[0, :3, :3])}") + logger.log_info(f"Grasp open length: {format_tensor(grasp_open_length)}") + logger.log_info(f"Grasp axis approach dot: {format_tensor(grasp_axis_dot)}") + logger.log_info(f"Grasp axis bottle dot: {format_tensor(grasp_axis_bottle_dot)}") + log_selected_gripper_clearance(semantics, obj_pose, grasp_xpos, grasp_open_length) + logger.log_info( + f"Final approach hand qpos: {format_tensor(final_approach_qpos[0])}" + ) + logger.log_info(f"Close hand qpos: {format_tensor(hand_close_qpos[0])}") + + is_success, traj, joint_ids = action.execute(semantics) + if not is_success: + logger.log_warning("Failed to plan upright trajectory during diagnostics.") + return + + arm_dof = len(robot.get_joint_ids(name="arm")) + segments = get_upright_segment_lengths(action) + logger.log_info(f"Action joint ids: {joint_ids}") + logger.log_info(f"Upright trajectory segments: {segments}") + logger.log_info(f"Trajectory shape: {tuple(traj.shape)}") + + grasp_idx = segments["pre"] + segments["final"] - 1 + close_end_idx = grasp_idx + segments["close"] + hold_end_idx = close_end_idx + segments["grasp_hold"] + log_tcp_alignment(robot, traj, grasp_xpos, arm_dof, grasp_idx, "grasp") + log_tcp_alignment(robot, traj, grasp_xpos, arm_dof, close_end_idx, "close_end") + log_tcp_alignment(robot, traj, grasp_xpos, arm_dof, hold_end_idx, "hold_end") + + +def create_object_semantics( + obj: RigidObject, args: argparse.Namespace +) -> ObjectSemantics: + return ObjectSemantics( + label=BOTTLE_LABEL, + geometry={ + "mesh_vertices": obj.get_vertices(env_ids=[0], scale=True)[0], + "mesh_triangles": obj.get_triangles(env_ids=[0])[0], + }, + affordance=AntipodalAffordance( + object_label=BOTTLE_LABEL, + force_reannotate=args.force_reannotate, + custom_config={ + "gripper_collision_cfg": build_gripper_collision_cfg(), + "generator_cfg": build_grasp_generator_cfg(args), + }, + ), + entity=obj, + ) + + +def run_upright_demo( + args: argparse.Namespace, sim: SimulationManager, robot: Robot +) -> None: + + sim.open_window() + + obj = create_fallen_bottle(sim) + settle_object(sim, obj, step=5) + semantics = create_object_semantics(obj, args) + motion_gen = MotionGenerator( + cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=robot.uid)) + ) + hand_open, hand_close = get_hand_open_close_qpos(robot, sim.device) + if args.debug_hand_state: + log_hand_setup(robot, hand_open, hand_close) + + upright_action = UprightAction( + motion_generator=motion_gen, + cfg=UprightActionCfg( + control_part="arm", + hand_control_part="hand", + hand_open_qpos=hand_open, + hand_close_qpos=hand_close, + approach_direction=torch.tensor( + BOTTLE_APPROACH_DIRECTION, dtype=torch.float32, device=sim.device + ), + pre_grasp_distance=0.15, + lift_height=0.15, + sample_interval=PGI_SAMPLE_INTERVAL, + hand_interp_steps=PGI_HAND_CLOSE_STEPS, + upright_axis_sign=-1.0, + place_press_depth=0.0, + place_press_steps=4, + upright_hold_steps=3, + place_hold_steps=8, + release_interp_steps=12, + release_retreat_distance=0.08, + release_retreat_lift=0.01, + final_approach_steps=12, + final_approach_preclose_width_margin=0.010, + grasp_hold_steps=PGI_GRASP_HOLD_STEPS, + use_grasp_width_qpos=True, + gripper_max_open_width=GRIPPER_MAX_OPEN_WIDTH, + max_grasp_open_length=BOTTLE_MAX_GRASP_OPEN_LENGTH, + max_grasp_axis_approach_dot=BOTTLE_MAX_GRASP_AXIS_APPROACH_DOT, + max_grasp_axis_upright_axis_dot=BOTTLE_MAX_GRASP_AXIS_UPRIGHT_AXIS_DOT, + grasp_squeeze_width=BOTTLE_GRASP_SQUEEZE_WIDTH, + min_dynamic_hand_close_qpos=torch.full_like( + hand_close, BOTTLE_MIN_DYNAMIC_HAND_CLOSE_QPOS + ), + ), + ) + + if args.diagnose_grasp: + diagnose_upright_plan(robot, upright_action, semantics) + return + + if not args.auto_play: + input("Inspect the fallen bottle, then press Enter to plan upright...") + + start_time = time.time() + is_success, traj, joint_ids = upright_action.execute(semantics) + cost_time = time.time() - start_time + logger.log_info(f"Plan upright trajectory cost time: {cost_time:.2f} seconds") + if not is_success: + logger.log_warning("Failed to plan upright trajectory.") + return + + arm_dof = len(robot.get_joint_ids(name="arm")) + total_steps = traj.shape[1] + segments = get_upright_segment_lengths(upright_action) + post_grasp_clear_step = segments["pre"] + segments["final"] + segments["close"] + should_clear_object_dynamics = True + if args.debug_hand_state: + joint_names = [robot.joint_names[joint_id] for joint_id in joint_ids] + hand_traj = traj[:, :, arm_dof:] + logger.log_info(f"Action joint ids: {joint_ids}") + logger.log_info(f"Action joint names: {joint_names}") + logger.log_info( + f"Post-grasp object dynamics clear step: {post_grasp_clear_step}" + ) + logger.log_info( + f"Planned hand qpos min: {format_tensor(hand_traj.min(dim=1).values[0])}" + ) + logger.log_info( + f"Planned hand qpos max: {format_tensor(hand_traj.max(dim=1).values[0])}" + ) + + if not args.auto_play: + input("Press Enter to start the upright demo...") + last_logged_hand_target: torch.Tensor | None = None + log_stride = max(1, total_steps // 10) + for i in range(traj.shape[1]): + robot.set_qpos(traj[:, i, :], joint_ids=joint_ids) + sim.update(step=4) + if should_clear_object_dynamics and i + 1 >= post_grasp_clear_step: + release_object_after_grasp(obj) + should_clear_object_dynamics = False + if args.debug_hand_state: + logger.log_info( + f"Object dynamics cleared at step={i}/{total_steps - 1}" + ) + if args.debug_hand_state: + action_hand_target = traj[:, i, arm_dof:] + target_changed = last_logged_hand_target is None or not torch.allclose( + action_hand_target, last_logged_hand_target, atol=1e-4 + ) + should_log = target_changed or i % log_stride == 0 or i == total_steps - 1 + if should_log: + log_hand_execution_state( + robot, + step_idx=i, + total_steps=total_steps, + action_hand_target=action_hand_target, + ) + last_logged_hand_target = action_hand_target.detach().clone() + if i % log_stride == 0 or i == total_steps - 1: + log_object_execution_state( + obj, + step_idx=i, + total_steps=total_steps, + ) + time.sleep(1e-2) + if not args.auto_play: + input("Press Enter to exit the simulation...") + + +def main() -> None: + args = parse_arguments() + sim = initialize_simulation(args) + robot = create_robot(sim, position=[0.0, 0.0, 0.0]) + run_upright_demo(args, sim, robot) + + +if __name__ == "__main__": + main()