Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions src/envs/fleet_env/task_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
4. Executes verifier for reward on episode completion
"""

import ast
import asyncio
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,6 +79,8 @@ class FleetTaskEnv:
modality: 1800s (30 min) for computer_use, 900s (15 min) for tool_use.
max_steps: Maximum steps per episode (default: 50)
request_timeout_s: HTTP request timeout in seconds (default: 60.0)
partial_reward: If True, compute partial scores from verifier
error/success accumulators instead of binary 0/1 (default: False)

Example:
>>> task_config = {
Expand All @@ -100,9 +104,11 @@ def __init__(
max_steps: int = 50,
request_timeout_s: float = 60.0,
reset_timeout_s: float = 10.0,
partial_reward: bool = False,
):
self.task = task_config
self.api_key = api_key or os.environ.get("FLEET_API_KEY")
self.partial_reward = partial_reward
# Auto-select TTL based on modality if not explicitly provided
if ttl_seconds is not None:
self.ttl_seconds = ttl_seconds
Expand Down Expand Up @@ -517,14 +523,44 @@ async def step_async(

return obs, reward, self._done, info

@staticmethod
def _parse_partial_reward(stdout: str) -> Optional[float]:
"""Parse partial reward from verifier accumulator output.

Verifiers print error/success accumulators to stdout. This parses
them to compute a fractional score (n_success / total_checks).

Returns:
Partial score in [0, 1], or None if accumulators not found.
"""
err_match = re.search(
r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<",
stdout,
re.DOTALL,
)
suc_match = re.search(
r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n<<< SUCCESS_ACCUMULATOR <<<",
stdout,
re.DOTALL,
)
if not err_match and not suc_match:
return None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing accumulator gives failed verifier a perfect score

High Severity

The guard not err_match and not suc_match only returns None when both accumulators are missing. When only SUCCESS_ACCUMULATOR is present (e.g., verifier crashed before printing ERROR_ACCUMULATOR), n_errors defaults to 0, so the partial score computes to n_success / n_success = 1.0. This silently overrides a genuinely failed verifier with a perfect reward of 1.0, corrupting the training signal. The condition likely needs to be not err_match or not suc_match to require both accumulators before computing a partial score.

Fix in Cursor Fix in Web

try:
n_errors = len(ast.literal_eval(err_match.group(1))) if err_match else 0
n_success = len(ast.literal_eval(suc_match.group(1))) if suc_match else 0
total = n_errors + n_success
return n_success / total if total > 0 else None
except Exception:
return None

async def _compute_reward(self) -> float:
"""Compute reward by executing the verifier using Fleet SDK.

Uses Fleet SDK's Task.verify_detailed() which properly sets up the
verifier namespace with Environment type, helper functions, etc.

Returns:
1.0 if verifier passes, 0.0 otherwise
1.0 if verifier passes, 0.0 otherwise (or partial if enabled)
"""
# Support both field names: verifier_code (OpenEnv) and verifier_func (Fleet SDK)
verifier_code = self.task.get("verifier_code") or self.task.get("verifier_func")
Expand Down Expand Up @@ -576,8 +612,23 @@ async def _compute_reward(self) -> float:
score = 0.0

verifier_success = response.success

# Partial reward: use accumulator counts instead of binary 0/1
partial_score = None
if (
self.partial_reward
and score == 0.0
and hasattr(response, "stdout")
and response.stdout
):
partial_score = self._parse_partial_reward(response.stdout)
if partial_score is not None:
score = partial_score

logger.info(
f"Task {self.task_key}: verifier returned success={response.success}, result={response.result}, score={score}"
f"Task {self.task_key}: verifier returned success={response.success}, "
f"result={response.result}, score={score}"
+ (f", partial={partial_score:.3f}" if partial_score is not None else "")
)

except ImportError as e:
Expand Down