diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 582da10a25..44e265be1f 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import min_version, optional_import +from monai.utils import PT_BEFORE_1_7, min_version, optional_import from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -99,6 +99,8 @@ class SupervisedTrainer(Trainer): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. """ @@ -124,6 +126,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + optim_set_to_none: bool = False, ) -> None: super().__init__( device=device, @@ -148,6 +151,7 @@ def __init__( self.optimizer = optimizer self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer + self.optim_set_to_none = optim_set_to_none def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ @@ -185,7 +189,12 @@ def _compute_pred_loss(): engine.fire_event(IterationEvents.LOSS_COMPLETED) self.network.train() - self.optimizer.zero_grad() + # `set_to_none` only works from PyTorch 1.7.0 + if PT_BEFORE_1_7: + self.optimizer.zero_grad() + else: + self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) + if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): _compute_pred_loss() @@ -252,6 +261,8 @@ class GanTrainer(Trainer): decollate: whether to decollate the batch-first data to 
a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. """ @@ -282,6 +293,7 @@ def __init__( metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Optional[Sequence] = None, decollate: bool = True, + optim_set_to_none: bool = False, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -314,6 +326,7 @@ def __init__( self.latent_shape = latent_shape self.g_prepare_batch = g_prepare_batch self.g_update_latents = g_update_latents + self.optim_set_to_none = optim_set_to_none def _iteration( self, engine: Engine, batchdata: Union[Dict, Sequence] @@ -342,7 +355,11 @@ def _iteration( 1, ) for _ in range(self.d_train_steps): - self.d_optimizer.zero_grad() + # `set_to_none` only works from PyTorch 1.7.0 + if PT_BEFORE_1_7: + self.d_optimizer.zero_grad() + else: + self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) dloss = self.d_loss_function(g_output, d_input) dloss.backward() self.d_optimizer.step() @@ -352,7 +369,10 @@ def _iteration( if self.g_update_latents: g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) g_output = self.g_inferer(g_input, self.g_network) - self.g_optimizer.zero_grad() + if PT_BEFORE_1_7: + self.g_optimizer.zero_grad() + else: + self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) g_loss = self.g_loss_function(g_output) g_loss.backward() self.g_optimizer.step() diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 821a694cb4..302f8ab7a7 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -112,7 +112,7 @@ def string_list_all_gather(strings: List[str]) -> List[str]: if length 
< max_len: strings += ["" for _ in range(max_len - length)] - if get_torch_version_tuple() <= (1, 6, 0): + if get_torch_version_tuple() <= (1, 6): raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") for s in strings: diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 0610bf0324..7fcc0b4064 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -202,6 +202,7 @@ def _model_completed(self, engine): key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}, train_handlers=train_handlers, amp=True if amp else False, + optim_set_to_none=True, ) trainer.run()