Skip to content

Commit 89fba1c

Browse files
CristianLarafacebook-github-bot
authored andcommitted
Method to consolidate Experiment.status from generator runs (#4900)
Summary: Add a new static method `experiment_status_from_generator_runs()` to `GenerationStrategy` that extracts and validates a suggested ExperimentStatus from a list of GeneratorRun objects. It collects all unique suggested_experiment_status values from the runs and: - Returns None with a warning if there are conflicting statuses across runs - Returns None with an info log if no statuses are found - Returns the single agreed-upon status otherwise Differential Revision: D92985915
1 parent 7bc0302 commit 89fba1c

2 files changed

Lines changed: 107 additions & 0 deletions

File tree

ax/generation_strategy/generation_strategy.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ax.adapter.base import Adapter
1717
from ax.core.data import Data
1818
from ax.core.experiment import Experiment
19+
from ax.core.experiment_status import ExperimentStatus
1920
from ax.core.generator_run import GeneratorRun
2021
from ax.core.observation import ObservationFeatures
2122
from ax.core.utils import extend_pending_observations, extract_pending_observations
@@ -312,6 +313,43 @@ def gen(
312313
)
313314
return grs_for_multiple_trials
314315

316+
@staticmethod
317+
def experiment_status_from_generator_runs(
318+
generator_runs: list[GeneratorRun],
319+
) -> ExperimentStatus | None:
320+
"""Extract and validate suggested experiment status from generator runs.
321+
322+
Collects the suggested_experiment_status directly from the GeneratorRun
323+
objects, validates that all runs suggest the same status, and returns
324+
that status.
325+
326+
Args:
327+
generator_runs: List of generator runs to extract statuses from.
328+
329+
Returns:
330+
The suggested experiment status that all generator runs agree on,
331+
or None if no statuses were found or if there are conflicting statuses.
332+
"""
333+
suggested_statuses: set[ExperimentStatus] = set()
334+
for gr in generator_runs:
335+
if gr.suggested_experiment_status is not None:
336+
suggested_statuses.add(gr.suggested_experiment_status)
337+
338+
if len(suggested_statuses) > 1:
339+
logger.warning(
340+
"Multiple different suggested experiment statuses found: "
341+
f"{suggested_statuses}. "
342+
"All generator runs used in a single gen() call should suggest the "
343+
"same experiment status. Skipping updating experiment status."
344+
)
345+
return None
346+
347+
if len(suggested_statuses) == 0:
348+
logger.info("No suggested_experiment_status found on any generator runs.")
349+
return None
350+
351+
return suggested_statuses.pop()
352+
315353
def current_generator_run_limit(
316354
self,
317355
) -> tuple[int, bool]:

ax/generation_strategy/tests/test_generation_strategy.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ax.adapter.torch import TorchAdapter
2525
from ax.core.arm import Arm
2626
from ax.core.experiment import Experiment
27+
from ax.core.experiment_status import ExperimentStatus
2728
from ax.core.generator_run import GeneratorRun
2829
from ax.core.observation import ObservationFeatures
2930
from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, ParameterType
@@ -2022,6 +2023,74 @@ def test_optimization_complete_single_node_no_criteria(self) -> None:
20222023

20232024
self.assertFalse(gs.optimization_complete)
20242025

2026+
def test_experiment_status_from_generation_strategy(self) -> None:
2027+
"""Test that experiment status is correctly propagated through
2028+
generator runs and extracted via experiment_status_from_generator_runs."""
2029+
2030+
with self.subTest("gen returns GRs with correct suggested_experiment_status"):
2031+
for status in [
2032+
ExperimentStatus.INITIALIZATION,
2033+
ExperimentStatus.OPTIMIZATION,
2034+
]:
2035+
with self.subTest(status=status):
2036+
exp = get_branin_experiment()
2037+
node_with_status = GenerationNode(
2038+
name="test_node",
2039+
generator_specs=[self.sobol_generator_spec],
2040+
suggested_experiment_status=status,
2041+
)
2042+
gs = GenerationStrategy(nodes=[node_with_status])
2043+
gs.experiment = exp
2044+
2045+
grs = gs.gen(experiment=exp, num_trials=1)
2046+
flat_grs = [gr for trial_grs in grs for gr in trial_grs]
2047+
2048+
extracted_status = (
2049+
GenerationStrategy.experiment_status_from_generator_runs(
2050+
flat_grs
2051+
)
2052+
)
2053+
self.assertEqual(extracted_status, status)
2054+
2055+
with self.subTest("conflicting statuses return None"):
2056+
gr1 = GeneratorRun(
2057+
arms=[Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0})],
2058+
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
2059+
)
2060+
gr2 = GeneratorRun(
2061+
arms=[Arm(name="0_1", parameters={"x1": 1.0, "x2": 1.0})],
2062+
suggested_experiment_status=ExperimentStatus.OPTIMIZATION,
2063+
)
2064+
mixed_grs = [gr1, gr2]
2065+
2066+
result = GenerationStrategy.experiment_status_from_generator_runs(mixed_grs)
2067+
self.assertIsNone(result)
2068+
2069+
with self.subTest("multiple trials all carry experiment status"):
2070+
exp = get_branin_experiment()
2071+
node_with_status = GenerationNode(
2072+
name="multi_trial_node",
2073+
generator_specs=[self.sobol_generator_spec],
2074+
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
2075+
)
2076+
gs = GenerationStrategy(nodes=[node_with_status])
2077+
gs.experiment = exp
2078+
2079+
grs = gs.gen(experiment=exp, num_trials=3)
2080+
2081+
self.assertEqual(len(grs), 3)
2082+
for gr_list in grs:
2083+
self.assertEqual(len(gr_list), 1)
2084+
self.assertEqual(gr_list[0]._generation_node_name, "multi_trial_node")
2085+
self.assertEqual(
2086+
gr_list[0].suggested_experiment_status,
2087+
ExperimentStatus.INITIALIZATION,
2088+
)
2089+
extracted_status = GenerationStrategy.experiment_status_from_generator_runs(
2090+
[gr for trial_grs in grs for gr in trial_grs]
2091+
)
2092+
self.assertEqual(extracted_status, ExperimentStatus.INITIALIZATION)
2093+
20252094
# ------------- Testing helpers (put tests above this line) -------------
20262095

20272096
def _run_GS_for_N_rounds(

0 commit comments

Comments
 (0)