Skip to content

Commit 36ff1f4

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Migrate State Dict Files to account for GPyTorch state dict change (facebook#4916)
Summary: Pull Request resolved: facebook#4916 Re-named benchmarking state dicts to account for state dict attribute changes in GPyTorch. "_transformed" attribute of priors of TransformedDistributions have been replaced by a constant BUFFERED_PREFIX = "__buffered_". The changes in .pt files correspond to changing the attribute names for the priors, from `"_transformed" `-->` BUFFERED_PREFIX = "__buffered_"` Reviewed By: Balandat Differential Revision: D93032568
1 parent 5b9c82d commit 36ff1f4

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

ax/storage/json_store/decoders.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
5555
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
5656
from botorch.utils.types import _DefaultType, DEFAULT
57+
from gpytorch.priors.utils import BUFFERED_PREFIX
5758
from pyre_extensions import assert_is_instance
5859
from torch.distributions.transformed_distribution import TransformedDistribution
5960

@@ -369,10 +370,10 @@ def botorch_component_from_json(botorch_class: type[T], json: dict[str, Any]) ->
369370
}
370371
)
371372
if issubclass(botorch_class, TransformedDistribution):
372-
# Extract the transformed attributes for transformed priors.
373+
# Extract the buffered attributes for transformed priors.
373374
for k in list(state_dict.keys()):
374-
if k.startswith("_transformed_"):
375-
state_dict[k[13:]] = state_dict.pop(k)
375+
if k.startswith(BUFFERED_PREFIX):
376+
state_dict[k[len(BUFFERED_PREFIX) :]] = state_dict.pop(k)
376377
class_path = json.pop("class")
377378
init_args = inspect.signature(botorch_class).parameters
378379
required_args = {

0 commit comments

Comments
 (0)