Skip to content

Commit dfa8aad

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Fix bugs in JSON and SQA storage modules (#4893)
Summary: Pull Request resolved: #4893 Fix multiple bugs in ax/storage: - `int(None)` crash in SQA encoder: `metric_registry.get()` can return None, but the result was immediately passed to `int()` before the None check could run (sqa_store/encoder.py:393) - `AttributeError` on None kwargs in JSON decoder: `kwargs.pop()` was called before the `if kwargs is not None` guard (json_store/decoder.py:1092) - `decode_args_list` length mismatch in SQA save: used `len(trials)` instead of `len(trials_to_reduce_state)` and `[...] * len(trials)` instead of a single-element list (sqa_store/save.py:245,254) - SQLAlchemy filter no-op: `SQAExperiment.id is not None` uses Python identity check instead of SQLAlchemy's `.isnot(None)` (sqa_store/load.py:760) - Inconsistent threshold in error message: says "> 10" but the actual check is `> 15` (sqa_store/utils.py:77) - Garbled error message in with_db_settings_base.py: sentence fragments in wrong order (sqa_store/with_db_settings_base.py:106) - Doubled path `ax/ax/storage` in error message (sqa_store/encoder.py:131) - Backslash line-continuation inside string literal causes 16 extra spaces in error message (sqa_store/encoder.py:654) - Missing space between concatenated strings (sqa_store/db.py) - Typo "SCalarized" → "Scalarized" (sqa_store/encoder.py) - Missing space in SKIP_ATTRS_ERROR_SUFFIX (sqa_store/utils.py) Reviewed By: bernardbeckerman Differential Revision: D92879499
1 parent 7402a57 commit dfa8aad

7 files changed

Lines changed: 24 additions & 18 deletions

File tree

ax/storage/json_store/decoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,10 +1098,10 @@ def generator_spec_from_json(
10981098
kwargs = generator_spec_json.pop("model_kwargs", None)
10991099
else:
11001100
kwargs = generator_spec_json.pop("generator_kwargs", None)
1101-
for k in _DEPRECATED_GENERATOR_KWARGS:
1102-
# Remove deprecated model kwargs.
1103-
kwargs.pop(k, None)
11041101
if kwargs is not None:
1102+
for k in _DEPRECATED_GENERATOR_KWARGS:
1103+
# Remove deprecated model kwargs.
1104+
kwargs.pop(k, None)
11051105
kwargs = _sanitize_surrogate_spec_input(object_json=kwargs)
11061106
if "model_gen_kwargs" in generator_spec_json:
11071107
gen_kwargs = generator_spec_json.pop("model_gen_kwargs", None)

ax/storage/sqa_store/db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ def create_all_tables(engine: Engine) -> None:
218218
"""
219219
if engine.dialect.name == "mysql" and engine.dialect.default_schema_name == "ax":
220220
raise ValueError(
221-
"The open-source Ax table creation is likely not applicable in this case,"
222-
+ "please contact the Adaptive Experimentation team if you need help."
221+
"The open-source Ax table creation is likely not applicable in this case, "
222+
"please contact the Adaptive Experimentation team if you need help."
223223
)
224224
Base.metadata.create_all(engine)
225225

ax/storage/sqa_store/encoder.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def validate_experiment_metadata(
128128
raise ValueError(
129129
f"An experiment already exists with the name {experiment.name}. "
130130
"If you need to override this existing experiment, first delete it "
131-
"via `delete_experiment` in ax/ax/storage/sqa_store/delete.py, "
131+
"via `delete_experiment` in ax/storage/sqa_store/delete.py, "
132132
"and then resave."
133133
)
134134

@@ -391,14 +391,15 @@ def get_metric_type_and_properties(
391391
json blob.
392392
"""
393393
metric_class = type(metric)
394-
metric_type = int(self.config.metric_registry.get(metric_class))
395-
if metric_type is None:
394+
metric_type_or_none = self.config.metric_registry.get(metric_class)
395+
if metric_type_or_none is None:
396396
raise SQAEncodeError(
397397
"Cannot encode metric to SQLAlchemy because metric's "
398398
f"subclass ({metric_class}) is missing from the registry. "
399399
"The metric registry currently contains the following: "
400400
f"{','.join(map(str, self.config.metric_registry.keys()))} "
401401
)
402+
metric_type = int(metric_type_or_none)
402403

403404
properties = metric_class.serialize_init_args(obj=metric)
404405
return metric_type, object_to_json(
@@ -646,13 +647,13 @@ def outcome_constraint_to_sqa(
646647
def scalarized_outcome_constraint_to_sqa(
647648
self, outcome_constraint: ScalarizedOutcomeConstraint
648649
) -> SQAMetric:
649-
"""Convert Ax SCalarized OutcomeConstraint to SQLAlchemy."""
650+
"""Convert Ax Scalarized OutcomeConstraint to SQLAlchemy."""
650651
metrics, weights = outcome_constraint.metrics, outcome_constraint.weights
651652

652653
if metrics is None or weights is None or len(metrics) != len(weights):
653654
raise SQAEncodeError(
654-
"Metrics and weights in scalarized OutcomeConstraint \
655-
must be lists of equal length."
655+
"Metrics and weights in scalarized OutcomeConstraint "
656+
"must be lists of equal length."
656657
)
657658

658659
metrics_by_name = self.get_children_metrics_by_name(

ax/storage/sqa_store/load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ def _query_historical_experiments_given_parameters(
768768
)
769769
)
770770
.filter(SQAExperiment.is_test == False) # noqa E712 `is` won't work for SQA
771-
.filter(SQAExperiment.id is not None)
771+
.filter(SQAExperiment.id.isnot(None))
772772
# Experiments with some data
773773
.join(SQAData, SQAParameter.experiment_id == SQAData.experiment_id)
774774
)

ax/storage/sqa_store/save.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,9 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
260260
objs=trials_to_reduce_state,
261261
encode_func=trial_to_reduced_state_sqa_encoder,
262262
decode_func=decoder.trial_from_sqa,
263-
decode_args_list=[{"experiment": experiment} for _ in range(len(trials))],
263+
decode_args_list=[
264+
{"experiment": experiment} for _ in range(len(trials_to_reduce_state))
265+
],
264266
modify_sqa=add_experiment_id,
265267
batch_size=batch_size,
266268
)
@@ -269,7 +271,7 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
269271
objs=[latest_trial],
270272
encode_func=encoder.trial_to_sqa,
271273
decode_func=decoder.trial_from_sqa,
272-
decode_args_list=[{"experiment": experiment} for _ in range(len(trials))],
274+
decode_args_list=[{"experiment": experiment}],
273275
modify_sqa=add_experiment_id,
274276
batch_size=batch_size,
275277
)

ax/storage/sqa_store/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747
"_metric_fetching_errors",
4848
"_data_rows",
4949
}
50-
SKIP_ATTRS_ERROR_SUFFIX = "Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."
50+
SKIP_ATTRS_ERROR_SUFFIX = (
51+
" Consider adding to COPY_DB_IDS_ATTRS_TO_SKIP if appropriate."
52+
)
5153

5254

5355
def is_foreign_key_field(field: str) -> bool:
@@ -75,7 +77,7 @@ def copy_db_ids(source: Any, target: Any, path: list[str] | None = None) -> None
7577
if len(path) > 15:
7678
# This shouldn't happen, but is a precaution against accidentally
7779
# introducing infinite loops
78-
raise SQADecodeError(error_message_prefix + "Encountered path of length > 10.")
80+
raise SQADecodeError(error_message_prefix + "Encountered path of length > 15.")
7981

8082
if type(source) is not type(target):
8183
if not issubclass(type(target), type(source)):

ax/storage/sqa_store/with_db_settings_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,9 @@ def __init__(
106106
if db_settings and (not DBSettings or not isinstance(db_settings, DBSettings)):
107107
raise ValueError(
108108
"`db_settings` argument should be of type ax.storage.sqa_store."
109-
f"(Got: {db_settings} of type {type(db_settings)}. "
110-
"structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy "
109+
"structs.DBSettings. "
110+
f"(Got: {db_settings} of type {type(db_settings)}). "
111+
"To use `DBSettings`, you will need SQLAlchemy "
111112
"installed in your environment (can be installed through pip)."
112113
)
113114
self._db_settings = db_settings or self._get_default_db_settings()

0 commit comments

Comments
 (0)