Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions robodriver/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,54 @@
from dataclasses import asdict, dataclass
from pprint import pformat
from typing import List

import time
import cv2
import logging_mp

from lerobot.robots import RobotConfig

from robodriver.robots.daemon import Daemon
from robodriver.utils.utils import git_branch_log
from robodriver.utils.import_utils import register_third_party_devices
from robodriver.utils import parser
from robodriver.utils.constants import DEFAULT_FPS
from robodriver.utils.import_utils import register_third_party_devices
from robodriver.utils.utils import git_branch_log

from robodriver.robots.utils import (
Robot,
busy_wait,
make_robot_from_config,
)


logging_mp.basic_config(level=logging_mp.INFO)
logger = logging_mp.get_logger(__name__)


@dataclass
class ControlPipelineConfig:
robot: RobotConfig

@classmethod
def __get_path_fields__(cls) -> List[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["control.policy"]


@parser.wrap()
async def async_main(cfg: ControlPipelineConfig):
robot = make_robot_from_config(cfg.robot)
robot.connect()
logger.info(pformat(asdict(cfg)))

while True:
start_loop_t = time.perf_counter()
robot.get_observation()

robot.send_action()
dt_s = time.perf_counter() - start_loop_t
try:
robot = make_robot_from_config(cfg.robot)
except Exception as e:
logger.critical(f"Failed to create robot: {e}")
raise

busy_wait(1 / 30 - dt_s)
logger.info("Make robot success")
logger.info(f"robot.type: {robot.robot_type}")

if not robot.is_connected:
robot.connect()
logger.info("Connect robot success")

# 在下面实现推理代码
observation = robot.get_observation() # 从臂和图像信息
action = None
if action:
robot.send_action(action)

def main():
git_branch_log()
Expand All @@ -59,3 +62,5 @@ def main():

if __name__ == "__main__":
main()

# 启动命令:python -m robodriver.scripts.evaluate --robot.type=galaxealite-aio-ros2