1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Python 3.13 support ([#227](https://github.com/MobileTeleSystems/RecTools/pull/227))
- `fit_partial` implementation for transformer-based models ([#273](https://github.com/MobileTeleSystems/RecTools/pull/273))
- `map_location` and `model_params_update` arguments for the `load_from_checkpoint` method of transformer-based models. Use `map_location` to explicitly specify the target device and `model_params_update` to update the original model parameters (e.g. to remove training-specific parameters that are no longer needed) ([#281](https://github.com/MobileTeleSystems/RecTools/pull/281))

## [0.13.0] - 10.04.2025

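A minimal usage sketch of the new arguments, for context. The checkpoint path is a placeholder, and `SASRecModel` is used only as an example of a transformer-based model assumed to be importable from `rectools.models`; any transformer-based model saved through Lightning checkpointing would be loaded the same way:

```python
from rectools.models import SASRecModel  # example transformer-based model

# Hypothetical checkpoint produced by an earlier `fit` run with checkpointing enabled.
ckpt_path = "lightning_logs/version_0/checkpoints/last_epoch.ckpt"

model = SASRecModel.load_from_checkpoint(
    ckpt_path,
    map_location="cpu",  # load onto CPU even if the checkpoint was saved on a GPU
    model_params_update={"get_trainer_func": None},  # drop a training-only callable
)
```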
28 changes: 22 additions & 6 deletions rectools/models/nn/transformers/base.py
@@ -29,7 +29,7 @@
from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
from rectools.types import InternalIdsArray
from rectools.utils.misc import get_class_or_function_full_path, import_object
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict

from ..item_net import (
CatFeaturesItemNet,
@@ -623,20 +623,36 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
self.__dict__.update(loaded.__dict__)

@classmethod
def load_from_checkpoint(cls, checkpoint_path: tp.Union[str, Path]) -> tpe.Self:
"""
Load model from Lightning checkpoint path.
def load_from_checkpoint(
cls,
checkpoint_path: tp.Union[str, Path],
map_location: tp.Optional[tp.Union[str, torch.device]] = None,
model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None,
) -> tpe.Self:
"""Load model from Lightning checkpoint path.

Parameters
----------
checkpoint_path: Union[str, Path]
Path to checkpoint location.
map_location: Union[str, torch.device], optional
Target device to load the checkpoint on (e.g. 'cpu', 'cuda:0').
If None, the device the checkpoint was saved on is used.
model_params_update: Dict[str, Any], optional
Custom values for checkpoint['hyper_parameters']['model_config'].
Has to be flattened with a 'dot' reducer before being passed.
You can use this argument to remove training-specific parameters that are no longer needed
(e.g. 'get_trainer_func').

Returns
-------
Model instance.
"""
checkpoint = torch.load(checkpoint_path, weights_only=False)
checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
if model_params_update:
prev_model_config = checkpoint["hyper_parameters"]["model_config"]
prev_config_flatten = make_dict_flat(prev_model_config)
prev_config_flatten.update(model_params_update)
checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
loaded = cls._model_from_checkpoint(checkpoint)
return loaded
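Since `model_params_update` must be passed in flattened, dot-joined form, here is a plain-Python illustration of the flatten/unflatten round trip the method relies on. It only sketches the idea behind `make_dict_flat` and `unflatten_dict` (the real helpers live in `rectools.utils.misc` and may differ in detail), and the nested `data_preparator_kwargs.shuffle` key is hypothetical, used only to show the notation:

```python
import typing as tp


def flatten(d: tp.Dict[str, tp.Any], prefix: str = "") -> tp.Dict[str, tp.Any]:
    """Join nested keys with dots: {"a": {"b": 1}} -> {"a.b": 1}."""
    flat: tp.Dict[str, tp.Any] = {}
    for key, value in d.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, f"{full_key}."))
        else:
            flat[full_key] = value
    return flat


def unflatten(flat: tp.Dict[str, tp.Any]) -> tp.Dict[str, tp.Any]:
    """Inverse of flatten: split keys on dots and rebuild the nesting."""
    nested: tp.Dict[str, tp.Any] = {}
    for key, value in flat.items():
        parts = key.split(".")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested


# Top-level parameters keep their plain names, nested ones get dot-joined keys.
original = {
    "get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
    "data_preparator_kwargs": {"shuffle": True},  # hypothetical nested parameter
}
flat = flatten(original)
flat.update({"get_trainer_func": None, "data_preparator_kwargs.shuffle": False})
assert unflatten(flat) == {"get_trainer_func": None, "data_preparator_kwargs": {"shuffle": False}}
```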

31 changes: 30 additions & 1 deletion tests/models/nn/transformers/test_base.py
@@ -154,10 +154,37 @@ def test_save_load_for_fitted_model(

@pytest.mark.parametrize("test_dataset", ("dataset", "dataset_item_features"))
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
@pytest.mark.parametrize(
"map_location",
(
"cpu",
pytest.param(
"cuda:0",
marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available"),
),
None,
),
)
@pytest.mark.parametrize(
"model_params_update",
(
{
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
},
{
"get_val_mask_func": None,
"get_trainer_func": None,
},
None,
),
)
def test_load_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
test_dataset: str,
map_location: tp.Optional[tp.Union[str, torch.device]],
model_params_update: tp.Optional[tp.Dict[str, tp.Any]],
request: FixtureRequest,
) -> None:

@@ -175,7 +202,9 @@ def test_load_from_checkpoint(
raise ValueError("No log dir")
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
assert os.path.isfile(ckpt_path)
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
recovered_model = model_cls.load_from_checkpoint(
ckpt_path, map_location=map_location, model_params_update=model_params_update
)
assert isinstance(recovered_model, model_cls)

self._assert_same_reco(model, recovered_model, dataset)
1 change: 1 addition & 0 deletions tests/models/nn/transformers/test_bert4rec.py
@@ -451,6 +451,7 @@ def test_i2i(
whitelist: tp.Optional[np.ndarray],
expected: pd.DataFrame,
) -> None:

model = BERT4RecModel(
n_factors=32,
n_blocks=2,