diff --git a/src/envs/fleet_env/task_env.py b/src/envs/fleet_env/task_env.py index a36a8d24b..1d82fe140 100644 --- a/src/envs/fleet_env/task_env.py +++ b/src/envs/fleet_env/task_env.py @@ -8,9 +8,11 @@ 4. Executes verifier for reward on episode completion """ +import ast import asyncio import logging import os +import re from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) @@ -77,6 +79,8 @@ class FleetTaskEnv: modality: 1800s (30 min) for computer_use, 900s (15 min) for tool_use. max_steps: Maximum steps per episode (default: 50) request_timeout_s: HTTP request timeout in seconds (default: 60.0) + partial_reward: If True, compute partial scores from verifier + error/success accumulators instead of binary 0/1 (default: False) Example: >>> task_config = { @@ -100,9 +104,11 @@ def __init__( max_steps: int = 50, request_timeout_s: float = 60.0, reset_timeout_s: float = 10.0, + partial_reward: bool = False, ): self.task = task_config self.api_key = api_key or os.environ.get("FLEET_API_KEY") + self.partial_reward = partial_reward # Auto-select TTL based on modality if not explicitly provided if ttl_seconds is not None: self.ttl_seconds = ttl_seconds @@ -517,6 +523,36 @@ async def step_async( return obs, reward, self._done, info + @staticmethod + def _parse_partial_reward(stdout: str) -> Optional[float]: + """Parse partial reward from verifier accumulator output. + + Verifiers print error/success accumulators to stdout. This parses + them to compute a fractional score (n_success / total_checks). + + Returns: + Partial score in [0, 1], or None if accumulators not found. + """ + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n<<< SUCCESS_ACCUMULATOR <<<", + stdout, + re.DOTALL, + ) + if not err_match and not suc_match: + return None + try: + n_errors = len(ast.literal_eval(err_match.group(1))) if err_match else 0 + n_success = len(ast.literal_eval(suc_match.group(1))) if suc_match else 0 + total = n_errors + n_success + return n_success / total if total > 0 else None + except Exception: + return None + async def _compute_reward(self) -> float: """Compute reward by executing the verifier using Fleet SDK. @@ -524,7 +560,7 @@ async def _compute_reward(self) -> float: verifier namespace with Environment type, helper functions, etc. Returns: - 1.0 if verifier passes, 0.0 otherwise + 1.0 if verifier passes, 0.0 otherwise (or partial if enabled) """ # Support both field names: verifier_code (OpenEnv) and verifier_func (Fleet SDK) verifier_code = self.task.get("verifier_code") or self.task.get("verifier_func") @@ -576,8 +612,23 @@ async def _compute_reward(self) -> float: score = 0.0 verifier_success = response.success + + # Partial reward: use accumulator counts instead of binary 0/1 + partial_score = None + if ( + self.partial_reward + and score == 0.0 + and hasattr(response, "stdout") + and response.stdout + ): + partial_score = self._parse_partial_reward(response.stdout) + if partial_score is not None: + score = partial_score + logger.info( - f"Task {self.task_key}: verifier returned success={response.success}, result={response.result}, score={score}" + f"Task {self.task_key}: verifier returned success={response.success}, " + f"result={response.result}, score={score}" + + (f", partial={partial_score:.3f}" if partial_score is not None else "") ) except ImportError as e: