Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion examples/reach_plan_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

from isaaclab.app import AppLauncher

parser = argparse.ArgumentParser(description="Sample dx/dy/dz/yaw around base reach pose and batch-plan with cuRobo.")
parser = argparse.ArgumentParser(
description=(
"Uniformly sample different object yaw rotations, select the runtime reach candidate, and sweep locally with"
" cuRobo."
)
)
parser.add_argument("--pipeline_id", required=True, type=str)
parser.add_argument(
"--reach_skill_index",
Expand All @@ -19,6 +24,12 @@
parser.add_argument("--yaw_deg", default=10.0, type=float, help="yaw range is [-yaw_deg, yaw_deg] degrees")
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--top_k", default=10, type=int)
parser.add_argument(
"--num_object_rotations",
default=4,
type=int,
help="Number of uniformly sampled object yaw rotations in [0, 360) degrees.",
)
parser.add_argument(
"--ik_only",
action="store_true",
Expand Down Expand Up @@ -57,6 +68,7 @@ def main() -> None:
),
top_k=args_cli.top_k,
ik_only=args_cli.ik_only,
num_object_rotations=args_cli.num_object_rotations,
),
)

Expand Down
263 changes: 207 additions & 56 deletions source/autosim/autosim/calibration/plan_sweep.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import time
from dataclasses import dataclass, field
from typing import Any
Expand All @@ -22,9 +23,12 @@ class ReachPlanSweepCfg:
"""Number of top poses to print."""
ik_only: bool = False
"""If True, use IK-only solving instead of full motion planning. Much faster for reachability checking; does not produce trajectories."""
num_object_rotations: int = 4
"""Number of uniformly sampled object yaw rotations in [0, 2π)."""


def _tensor_to_list(x: torch.Tensor) -> list[float]:
    """Return every element of ``x`` as a flat list of Python floats.

    The tensor is detached from autograd and copied to the CPU before being
    flattened, so the result is plain data suitable for printing or
    serialization regardless of the input's shape or device.
    """
    flat_values = x.detach().cpu().flatten().tolist()
    # ``float`` normalises integer/bool tensor elements to floats as well.
    return list(map(float, flat_values))


Expand All @@ -33,6 +37,42 @@ def _fmt_pose(vals: list[float]) -> str:
return "[" + ", ".join(f"{v:.4f}" for v in vals) + "]"


def _quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Return the Hamilton product ``q1 * q2`` for quaternions stored as wxyz.

    The last dimension of each input holds the four components; all leading
    dimensions are processed elementwise. The result keeps the components in
    the last dimension, in the same wxyz order.
    """
    aw, ax, ay, az = q1.unbind(-1)
    bw, bx, by, bz = q2.unbind(-1)

    # Scalar part, then the three vector components of the product.
    out_w = aw * bw - ax * bx - ay * by - az * bz
    out_x = aw * bx + ax * bw + ay * bz - az * by
    out_y = aw * by - ax * bz + ay * bw + az * bx
    out_z = aw * bz + ax * by - ay * bx + az * bw

    return torch.stack((out_w, out_x, out_y, out_z), dim=-1)


def _uniform_yaw_rotations(device: torch.device, dtype: torch.dtype, num_rotations: int) -> list[torch.Tensor]:
    """Build ``num_rotations`` evenly spaced rotations about the z axis.

    Rotation ``i`` has angle ``2*pi*i / num_rotations``, so the angles start at
    0 and stop one step short of a full turn. Each rotation is returned as a
    4-element tensor ``[cos(angle/2), 0, 0, sin(angle/2)]`` created on
    ``device`` with ``dtype``. A non-positive ``num_rotations`` yields an
    empty list.
    """
    if num_rotations <= 0:
        return []

    def _yaw_quat(step: int) -> torch.Tensor:
        # A quaternion encodes a rotation with the half angle.
        half_angle = ((2.0 * math.pi * step) / num_rotations) * 0.5
        components = [math.cos(half_angle), 0.0, 0.0, math.sin(half_angle)]
        return torch.tensor(components, device=device, dtype=dtype)

    return [_yaw_quat(step) for step in range(num_rotations)]


def _row_sort_key(row: dict[str, Any]) -> tuple[float, float, float]:
    """Return an ascending sort key that ranks the best sampled pose first.

    Rows are ordered by, in turn:
      1. planning success (successful rows before failed ones),
      2. trajectory length (shorter first),
      3. position error (lower first).

    Args:
        row: Result row providing the keys ``"plan_success"``, ``"traj_len"``
            and ``"position_error"``; the latter two may be ``None``.

    Returns:
        A tuple of three floats usable as a ``key=`` for ``sorted``/``min``.
    """
    # A missing metric must never beat a real one, however large, so use an
    # infinite sentinel (and keep every tuple element a float, as annotated).
    missing = float("inf")
    traj_len = row["traj_len"]
    position_error = row["position_error"]
    return (
        0.0 if row["plan_success"] else 1.0,
        missing if traj_len is None else float(traj_len),
        missing if position_error is None else float(position_error),
    )


def _build_extra_target_link_goals(
reach_skill: ReachSkill,
activate_q: torch.Tensor,
Expand Down Expand Up @@ -63,21 +103,12 @@ def reach_plan_sweep(pipeline: AutoSimPipeline, cfg: ReachPlanSweepCfg) -> list[
All skills before the target reach skill are executed normally, so the
environment reflects the actual state at the point of interest.

Returns:
Top-k result rows sorted by plan quality. Each row contains:
"pose_oe": list[float] — main EE pose in object frame [x,y,z,qw,qx,qy,qz]
"plan_success": bool
"traj_len": int | None — trajectory length (full planning only)
"position_error": float | None — IK position error (IK-only mode)

Typical multi-reach workflow (each step is a separate script invocation):
# Step 1: sweep reach 0, note the best pose_oe from the printout
python reach_plan_sweep.py --reach_skill_index 0 ...

# Manually update object_reach_target_poses[obj][0] in the pipeline code
The offline tuning flow uniformly samples object yaw rotations, applies each
rotation to the target object, reuses the runtime reach candidate selection
logic to choose a base pose, and then runs the local sweep around that pose.

# Step 2: sweep reach 1; reach 0 now runs with the updated hard-coded pose
python reach_plan_sweep.py --reach_skill_index 1 ...
Returns:
One result block per sampled object rotation.
"""

pipeline.initialize()
Expand All @@ -87,7 +118,6 @@ def reach_plan_sweep(pipeline: AutoSimPipeline, cfg: ReachPlanSweepCfg) -> list[
pipeline.reset_env()

reach_skill_counter = 0
reach_count_per_object: dict[str, int] = {}

for subtask in decompose_result.subtasks:
for skill_info in subtask.skills:
Expand All @@ -103,14 +133,10 @@ def reach_plan_sweep(pipeline: AutoSimPipeline, cfg: ReachPlanSweepCfg) -> list[

if is_reach and reach_skill_counter == cfg.reach_skill_index:
obj_name = skill_info.target_object
obj_pose_idx = reach_count_per_object.get(obj_name, 0)
return _sweep(pipeline, cfg, obj_name, obj_pose_idx, skill)
return _sweep_all_rotations(pipeline, cfg, obj_name, skill)

goal = skill.extract_goal_from_info(skill_info, pipeline._env, pipeline._env_extra_info)
if is_reach:
reach_count_per_object[skill_info.target_object] = (
reach_count_per_object.get(skill_info.target_object, 0) + 1
)
reach_skill_counter += 1

success, _ = pipeline._execute_single_skill(skill, goal)
Expand All @@ -126,20 +152,156 @@ def reach_plan_sweep(pipeline: AutoSimPipeline, cfg: ReachPlanSweepCfg) -> list[
)


def _sweep(
def _sweep_all_rotations(
    pipeline: AutoSimPipeline,
    cfg: ReachPlanSweepCfg,
    obj_name: str,
    reach_skill: ReachSkill,
) -> list[dict[str, Any]]:
    """Evaluate the target reach step across uniformly sampled object yaw rotations.

    For each sampled rotation this function temporarily rewrites the object's
    root pose, reuses the runtime reach candidate selector to choose a base
    pose, runs the local sweep around that base pose, and collects the result.
    The object's original pose is written back afterwards, even if a sweep
    raises.

    Args:
        pipeline: Pipeline whose environment is already in the pre-reach state.
        cfg: Sweep configuration; ``cfg.num_object_rotations`` sets how many
            rotations are evaluated.
        obj_name: Scene key of the object the reach skill targets.
        reach_skill: Reach skill whose candidate selection is reused.

    Returns:
        One result block (as produced by ``_sweep``) per sampled rotation, in
        rotation order.

    Raises:
        ValueError: If the pose chosen by the candidate selector matches none
            of the candidates, so its index cannot be reported.
    """
    env = pipeline._env
    env_extra_info = pipeline._env_extra_info
    obj = env.scene[obj_name]

    # Snapshot the pose so it can be restored once all rotations are evaluated.
    original_pose_w = obj.data.root_pose_w[pipeline._env_id].clone()
    # Elements 3: are passed to `_quat_mul`, i.e. treated as a wxyz quaternion.
    base_quat_w = original_pose_w[3:].clone()
    env_ids = torch.tensor([pipeline._env_id], device=original_pose_w.device, dtype=torch.int32)
    results: list[dict[str, Any]] = []

    try:
        for rotation_idx, yaw_quat_w in enumerate(
            _uniform_yaw_rotations(original_pose_w.device, original_pose_w.dtype, cfg.num_object_rotations)
        ):
            rotated_pose_w = original_pose_w.clone().unsqueeze(0)
            # NOTE(review): the yaw is composed on the right (base * yaw), which
            # rotates about the object's own z axis; this matches a world-frame
            # yaw only when that axis is aligned with world z — confirm intent.
            rotated_pose_w[0, 3:] = _quat_mul(base_quat_w.unsqueeze(0), yaw_quat_w.unsqueeze(0)).squeeze(0)
            obj.write_root_pose_to_sim(rotated_pose_w, env_ids=env_ids)

            candidates = env_extra_info.get_reach_target_poses(obj_name)
            selected_pose_oe = reach_skill._select_best_candidate(  # noqa: SLF001
                env, obj_name, candidates, env_extra_info
            ).to(env.device)

            # Recover which candidate was picked. Candidates are normalised to
            # the selected pose's device/dtype so the comparison cannot fail on
            # a dtype mismatch, and a missing match gives a clear error rather
            # than a bare StopIteration.
            selected_idx = next(
                (
                    idx
                    for idx, pose in enumerate(candidates)
                    if torch.allclose(
                        torch.as_tensor(pose, device=env.device, dtype=selected_pose_oe.dtype),
                        selected_pose_oe,
                    )
                ),
                None,
            )
            if selected_idx is None:
                raise ValueError(
                    f"Selected reach pose for object '{obj_name}' (rotation {rotation_idx}) "
                    f"does not match any of the {len(candidates)} candidate poses."
                )

            result_block = _sweep(
                pipeline=pipeline,
                cfg=cfg,
                obj_name=obj_name,
                base_pose_oe=selected_pose_oe,
                reach_skill=reach_skill,
                rotation_idx=rotation_idx,
                object_pose_w=rotated_pose_w.squeeze(0),
                selected_candidate_idx=selected_idx,
            )
            results.append(result_block)
    finally:
        # Always put the object back where it was found.
        obj.write_root_pose_to_sim(original_pose_w.unsqueeze(0), env_ids=env_ids)

    _print_summary(obj_name, cfg, results)
    return results


def _print_summary(obj_name: str, cfg: ReachPlanSweepCfg, results: list[dict[str, Any]]) -> None:
    """Print the per-rotation report followed by a per-selected-candidate aggregate.

    Args:
        obj_name: Name of the swept object, shown in the report header.
        cfg: Sweep configuration; only ``cfg.reach_skill_index`` is read here.
        results: Result blocks, each providing ``"rotation_index"``,
            ``"success_count"``, ``"num_samples"``, ``"elapsed_ms"``,
            ``"selected_candidate_idx"``, ``"selected_base_pose_oe"`` and
            ``"top_k"`` (a list of rows with ``"plan_success"``,
            ``"traj_len"``, ``"position_error"`` and ``"pose_oe"``).

    Nothing is printed when ``results`` is empty.
    """
    if not results:
        return

    _SEP = "═" * 100
    print()
    print(_SEP)
    print(f" reach_plan_sweep summary │ object='{obj_name}' reach_skill_index={cfg.reach_skill_index}")
    print(_SEP)

    # Aggregates keyed by the index of the candidate chosen as the base pose.
    candidate_summary: dict[int, dict[str, Any]] = {}

    for block in results:
        rotation_idx = block["rotation_index"]
        success_count = block["success_count"]
        num_samples = block["num_samples"]
        elapsed_ms = block["elapsed_ms"]
        selected_candidate_idx = block["selected_candidate_idx"]
        selected_base_pose_oe = block["selected_base_pose_oe"]

        # Guard the ratio the same way the aggregate section below does, so an
        # empty sample set reports 0% instead of raising ZeroDivisionError.
        block_rate = success_count / num_samples if num_samples > 0 else 0.0
        print(
            f" rotation={rotation_idx:02d} selected_candidate={selected_candidate_idx} "
            f"success={success_count}/{num_samples} ({block_rate:.1%}) time={elapsed_ms:.0f} ms"
        )
        print(f" base_pose={_fmt_pose(selected_base_pose_oe)}")
        for rank, row in enumerate(block["top_k"]):
            mark = "✓" if row["plan_success"] else "✗"
            if row["traj_len"] is not None:
                metric = f"traj_len={row['traj_len']}"
            elif row["position_error"] is not None:
                metric = f"pos_err={row['position_error']:.4f}"
            else:
                # Neither metric was reported; formatting None with ':.4f'
                # would raise TypeError.
                metric = "pos_err=n/a"
            print(f" [{rank}] {mark} {_fmt_pose(row['pose_oe'])} # {metric}")
        print("─" * 100)

        summary = candidate_summary.setdefault(
            selected_candidate_idx,
            {
                "count": 0,
                "total_success": 0,
                "total_samples": 0,
                "total_time_ms": 0.0,
                "base_pose": selected_base_pose_oe,
                "recommended_pose": None,
                "recommended_row": None,
            },
        )
        summary["count"] += 1
        summary["total_success"] += success_count
        summary["total_samples"] += num_samples
        summary["total_time_ms"] += elapsed_ms

        # Keep the best-ranked row seen so far for this candidate across rotations.
        candidate_best_row = min(block["top_k"], key=_row_sort_key, default=None)
        if candidate_best_row is not None and (
            summary["recommended_row"] is None
            or _row_sort_key(candidate_best_row) < _row_sort_key(summary["recommended_row"])
        ):
            summary["recommended_row"] = candidate_best_row
            summary["recommended_pose"] = candidate_best_row["pose_oe"]

    print()
    print(_SEP)
    print(" selected_candidate aggregate")
    print(_SEP)
    for candidate_idx in sorted(candidate_summary):
        summary = candidate_summary[candidate_idx]
        total_samples = summary["total_samples"]
        success_rate = summary["total_success"] / total_samples if total_samples > 0 else 0.0
        avg_time_ms = summary["total_time_ms"] / summary["count"] if summary["count"] > 0 else 0.0
        print(
            f" candidate={candidate_idx} selected_in={summary['count']} rotation(s) "
            f"success={summary['total_success']}/{total_samples} ({success_rate:.1%}) avg_time={avg_time_ms:.0f} ms"
        )
        print(f" base_pose={_fmt_pose(summary['base_pose'])}")
        if summary["recommended_pose"] is not None:
            print(f" recommended_pose={_fmt_pose(summary['recommended_pose'])}")
    print(_SEP)


def _sweep(
pipeline: AutoSimPipeline,
cfg: ReachPlanSweepCfg,
obj_name: str,
base_pose_oe: torch.Tensor,
reach_skill: ReachSkill,
rotation_idx: int,
object_pose_w: torch.Tensor,
selected_candidate_idx: int,
) -> dict[str, Any]:
"""
Core sweep logic. Called once the environment is in the correct pre-reach state.

Samples K candidate poses around the base reach target (object frame), transforms them to
robot root frame, then batch-plans with cuRobo. When configured, extra link goals are
generated from the live joint state using the same extra-target strategy as
`ReachSkill.extract_goal_from_info()`.
Samples K candidate poses around the selected base reach target (object frame),
transforms them to robot root frame, then batch-plans with cuRobo. When
configured, extra link goals are generated from the live joint state using the
same extra-target strategy as `ReachSkill.extract_goal_from_info()`.
"""

env = pipeline._env
Expand All @@ -148,11 +310,7 @@ def _sweep(
planner = pipeline._motion_planner
robot = pipeline._robot

pose_list = env_extra_info.object_reach_target_poses[obj_name]
if not (0 <= obj_pose_idx < len(pose_list)):
raise ValueError(f"pose index {obj_pose_idx} out of range for object '{obj_name}' ({len(pose_list)} poses).")

base_pose_oe = torch.as_tensor(pose_list[obj_pose_idx], device=env.device, dtype=torch.float32).view(7)
base_pose_oe = torch.as_tensor(base_pose_oe, device=env.device, dtype=torch.float32).view(7)
poses_oe = cfg.sampling.sample(base_pose_oe)
k = int(poses_oe.shape[0])

Expand Down Expand Up @@ -186,11 +344,15 @@ def _sweep(
dt_ms = (time.time() - t0) * 1000.0

success = (
result.success.detach().cpu().bool() if result.success is not None else torch.zeros((k,), dtype=torch.bool)
result.success.detach().cpu().bool().reshape(-1)
if result.success is not None
else torch.zeros((k,), dtype=torch.bool)
)
pos_err = result.position_error.detach().cpu() if result.position_error is not None else None
pos_err = result.position_error.detach().cpu().reshape(-1) if result.position_error is not None else None
traj_last = (
result.path_buffer_last_tstep if (not cfg.ik_only and result.path_buffer_last_tstep is not None) else None
torch.as_tensor(result.path_buffer_last_tstep).reshape(-1)
if (not cfg.ik_only and result.path_buffer_last_tstep is not None)
else None
)

rows = []
Expand All @@ -202,25 +364,14 @@ def _sweep(
"position_error": float(pos_err[i].item()) if pos_err is not None else None,
})

def _sort_key(r):
if not r["plan_success"]:
return (10**9, 10**9)
return (r["traj_len"] or 10**8, r["position_error"] or 10**8)

top_k = sorted(rows, key=_sort_key)[: cfg.top_k]

success_count = int(success.sum().item())
_SEP = "─" * 80
print(_SEP)
print(f" reach_plan_sweep │ object='{obj_name}' reach_skill_index={cfg.reach_skill_index}")
print(f" success {success_count}/{k} ({success_count / k:.1%}) │ time {dt_ms:.0f} ms")
print(_SEP)
print(f" top {len(top_k)} poses (object frame [x, y, z, qw, qx, qy, qz])")
print()
for rank, r in enumerate(top_k):
mark = "✓" if r["plan_success"] else "✗"
metric = f"traj_len={r['traj_len']}" if r["traj_len"] is not None else f"pos_err={r['position_error']:.4f}"
print(f" [{rank}] {mark} {_fmt_pose(r['pose_oe'])} # {metric}")
print(_SEP)

return top_k
top_k = sorted(rows, key=_row_sort_key)[: cfg.top_k]
return {
"rotation_index": rotation_idx,
"object_pose_w": _tensor_to_list(object_pose_w),
"selected_candidate_idx": selected_candidate_idx,
"selected_base_pose_oe": _tensor_to_list(base_pose_oe),
"success_count": int(success.sum().item()),
"num_samples": k,
"elapsed_ms": dt_ms,
"top_k": top_k,
}
Loading
Loading