Skip to content

Commit 3418b59

Browse files
committed
fix(gepa): enable arg description optimization for ReAct tools
1 parent 91331d0 commit 3418b59

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

dspy/teleprompt/gepa/gepa.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
2+
import json
23
import logging
34
import random
4-
import json
55
from dataclasses import dataclass
66
from typing import Any, Literal, Optional, Protocol, Union
77

@@ -540,26 +540,27 @@ def feedback_fn(
540540
if not isinstance(module, ReAct):
541541
continue
542542
prefix = module_path.removeprefix("self.") if module_path != "self" else ""
543-
543+
544544
# Get first predictor name as module identifier
545545
for pred_name, _ in module.named_predictors():
546546
comp_name = pred_name if not prefix else f"{prefix}.{pred_name}"
547547
module_key = f"react_module:{comp_name.split('.')[0]}" if prefix else "react_module"
548-
549-
# Build JSON config
548+
549+
# Build JSON config with tool args for reflection
550550
config = {
551551
"react": module.react.signature.instructions,
552552
"extract": module.extract.predict.signature.instructions,
553553
"tools": {
554554
tool_name: {
555555
"desc": tool.desc,
556+
"args": tool.args,
556557
"arg_desc": tool.arg_desc or {}
557558
}
558559
for tool_name, tool in module.tools.items()
559560
if tool_name != "finish"
560561
}
561562
}
562-
563+
563564
# Replace predictor keys with module key and extract key to prevent duplicates
564565
base_program.pop(comp_name, None)
565566
extract_key = f"{prefix}.extract.predict" if prefix else "extract.predict"

dspy/teleprompt/gepa/instruction_proposal.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,9 @@ class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature):
322322
"""Improve a ReAct agent based on execution examples and feedback.
323323
324324
Analyze the trajectories to identify successful patterns and failure causes.
325-
Generate improved instructions and/or improved tool descriptions to help the agent succeed on similar tasks."""
325+
Generate improved texts to help the agent succeed on similar tasks.
326+
Place improved texts at their appropriate level of abstraction and specificity.
327+
"""
326328

327329
current_react_instruction = dspy.InputField(
328330
desc="Current ReAct module instruction guiding the ReAct agent's reasoning and tool selection"
@@ -331,7 +333,8 @@ class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature):
331333
desc="Current Extract module instruction for extracting final answers from trajectories"
332334
)
333335
current_tools = dspy.InputField(
334-
desc="Available tools with current descriptions"
336+
annotation=list[dspy.Tool],
337+
desc="Available tools with their complete schemas"
335338
)
336339
examples_with_feedback = dspy.InputField(
337340
desc="Execution examples with feedback showing successes and failures"
@@ -410,37 +413,46 @@ def __call__(
410413
logger.error(f"Failed to deserialize config for {module_key}: {e}")
411414
continue
412415

413-
# Build dynamic signature by extending base signature
414-
# Extract current tools from config
415-
current_tools = current_react_config.get("tools", {})
416-
logger.info(f"Found {len(current_tools)} tools: {list(current_tools.keys())}")
416+
# Reconstruct Tool objects from serialized schema
417+
current_tools_dict = current_react_config.get("tools", {})
418+
logger.info(f"Found {len(current_tools_dict)} tools: {list(current_tools_dict.keys())}")
419+
tools_list = []
420+
for tool_name, tool_info in current_tools_dict.items():
421+
tool = dspy.Tool(
422+
func=lambda: None,
423+
name=tool_name,
424+
desc=tool_info.get("desc", ""),
425+
)
426+
tool.args = tool_info.get("args", {})
427+
tool.arg_desc = tool_info.get("arg_desc", {})
428+
tools_list.append(tool)
417429

418430
# Build dynamic signature by extending base signature
419431
signature = GenerateImprovedReActDescriptionsFromFeedback
420432

421-
logger.debug(f"Building dynamic signature with {len(current_tools)} tools...")
433+
logger.debug(f"Building dynamic signature with {len(tools_list)} tools...")
422434

423435
# Add dynamic tool description and arg descriptions output fields
424-
for tool_name, tool_info in current_tools.items():
436+
for tool in tools_list:
437+
tool_name = tool.name
438+
tool_info = current_tools_dict[tool_name]
425439
sanitized_tool_name = self._sanitize_name(tool_name)
426440

427-
# Tool description (optional)
428441
signature = signature.append(
429442
f"improved_tool_{sanitized_tool_name}_desc",
430443
dspy.OutputField(
431-
desc=f"Improved description for tool '{tool_name}' (optional - leave empty to keep current)",
432-
default="" # Make optional
444+
desc=f"Improved description for tool '{tool_name}'",
445+
default=""
433446
)
434447
)
435448

436-
# Tool arg descriptions (always available if tool has args, optional)
437449
if tool_info.get("args"):
438450
for arg_name in tool_info["args"].keys():
439451
signature = signature.append(
440452
f"improved_tool_{sanitized_tool_name}_arg_{arg_name}_desc",
441453
dspy.OutputField(
442-
desc=f"Improved description for parameter '{arg_name}' (optional)",
443-
default="" # Optional - enables cold start
454+
desc=f"Improved description for parameter '{arg_name}'",
455+
default=""
444456
)
445457
)
446458

@@ -449,13 +461,12 @@ def __call__(
449461
logger.info(f"Formatted {len(reflective_dataset[module_key])} reflective examples")
450462
logger.debug(f"Examples preview: {formatted_examples[:200]}...")
451463

452-
# Call reflection LM
453464
logger.info("Calling reflection LM with dynamic signature...")
454465
propose_descriptions = dspy.Predict(signature)
455466
result = propose_descriptions(
456467
current_react_instruction=current_react_config.get("react", ""),
457468
current_extract_instruction=current_react_config.get("extract", ""),
458-
current_tools=list(current_tools.items()), # List of (name, info) tuples
469+
current_tools=tools_list, # List of Tool objects for adapter formatting
459470
examples_with_feedback=formatted_examples,
460471
)
461472

@@ -476,7 +487,7 @@ def __call__(
476487

477488
# Extract improved tool descriptions (only include if improved)
478489
improved_react_config["tools"] = {}
479-
for tool_name, tool_info in current_tools.items():
490+
for tool_name, tool_info in current_tools_dict.items():
480491
sanitized_tool_name = self._sanitize_name(tool_name)
481492

482493
# Get improved description

0 commit comments

Comments
 (0)