From 1b4e656ac4765b9e80cb1d870700f2cd1d950204 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 14 Jul 2021 11:15:53 +0800 Subject: [PATCH 1/5] [DLMED] add set_to_none arg Signed-off-by: Nic Ma --- monai/engines/trainer.py | 26 ++++++++++++++++++++++---- tests/test_integration_workflows.py | 1 + 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 582da10a25..bdc2c35377 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import min_version, optional_import +from monai.utils import get_torch_version_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -99,6 +99,8 @@ class SupervisedTrainer(Trainer): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. """ @@ -124,6 +126,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + optim_set_to_none: bool = False, ) -> None: super().__init__( device=device, @@ -148,6 +151,7 @@ def __init__( self.optimizer = optimizer self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer + self.optim_set_to_none = optim_set_to_none def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ @@ -185,7 +189,11 @@ def _compute_pred_loss(): engine.fire_event(IterationEvents.LOSS_COMPLETED) self.network.train() - self.optimizer.zero_grad() + if get_torch_version_tuple() < (1, 7, 0): + self.optimizer.zero_grad() + else: + self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) + if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): _compute_pred_loss() @@ -252,6 +260,8 @@ class GanTrainer(Trainer): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. """ @@ -282,6 +292,7 @@ def __init__( metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Optional[Sequence] = None, decollate: bool = True, + optim_set_to_none: bool = False, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -314,6 +325,7 @@ def __init__( self.latent_shape = latent_shape self.g_prepare_batch = g_prepare_batch self.g_update_latents = g_update_latents + self.optim_set_to_none = optim_set_to_none def _iteration( self, engine: Engine, batchdata: Union[Dict, Sequence] @@ -342,7 +354,10 @@ def _iteration( 1, ) for _ in range(self.d_train_steps): - self.d_optimizer.zero_grad() + if get_torch_version_tuple() < (1, 7, 0): + self.d_optimizer.zero_grad() + else: + self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) dloss = self.d_loss_function(g_output, d_input) dloss.backward() self.d_optimizer.step() @@ -352,7 +367,10 @@ def _iteration( if self.g_update_latents: g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) g_output = self.g_inferer(g_input, self.g_network) - self.g_optimizer.zero_grad() + if get_torch_version_tuple() < (1, 7, 0): + self.g_optimizer.zero_grad() + else: + self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) g_loss = self.g_loss_function(g_output) g_loss.backward() self.g_optimizer.step() diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 0610bf0324..7fcc0b4064 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -202,6 +202,7 @@ def _model_completed(self, engine): key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}, train_handlers=train_handlers, amp=True if amp else False, + optim_set_to_none=True, ) trainer.run() From f2782c7d92203e5cd0b00ca4981d2754097c4034 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 14 Jul 2021 16:33:23 +0800 Subject: [PATCH 2/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/trainer.py | 4 ++-- monai/handlers/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index bdc2c35377..0b7123a062 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -189,7 +189,7 @@ def _compute_pred_loss(): engine.fire_event(IterationEvents.LOSS_COMPLETED) self.network.train() - if get_torch_version_tuple() < (1, 7, 0): + if get_torch_version_tuple() < (1, 7): self.optimizer.zero_grad() else: self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -354,7 +354,7 @@ def _iteration( 1, ) for _ in range(self.d_train_steps): - if get_torch_version_tuple() < (1, 7, 0): + if get_torch_version_tuple() < (1, 7): self.d_optimizer.zero_grad() else: self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index f4d1a1da65..d9f9eeb008 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -112,7 +112,7 @@ def string_list_all_gather(strings: List[str]) -> List[str]: if length < max_len: strings = strings + ["" for _ in range(max_len - length)] - if get_torch_version_tuple() > (1, 6, 0): + if get_torch_version_tuple() > (1, 6): for s in strings: gathered = idist.all_gather(s) for i, g in enumerate(gathered): From 19cf515a33b0d2fd1e6b18a08a04807a24b8b412 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 14 Jul 2021 22:50:23 +0800 Subject: [PATCH 3/5] [DLMED] fix version issue Signed-off-by: Nic Ma --- monai/engines/trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 0b7123a062..7c0bb7b0d0 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import get_torch_version_tuple, min_version, optional_import +from monai.utils import min_version, optional_import, version_leq from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -189,7 +189,8 @@ def _compute_pred_loss(): engine.fire_event(IterationEvents.LOSS_COMPLETED) self.network.train() - if get_torch_version_tuple() < (1, 7): + # `set_to_none` only work from PyTorch 1.7.0 + if torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0"): self.optimizer.zero_grad() else: self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -354,7 +355,8 @@ def _iteration( 1, ) for _ in range(self.d_train_steps): - if get_torch_version_tuple() < (1, 7): + # `set_to_none` only work from PyTorch 1.7.0 + if torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0"): self.d_optimizer.zero_grad() else: self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -367,7 +369,7 @@ def _iteration( if self.g_update_latents: g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) g_output = self.g_inferer(g_input, self.g_network) - if get_torch_version_tuple() < (1, 7, 0): + if torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0"): self.g_optimizer.zero_grad() else: self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) From 66181a243c450e1e6429f7bc56f03721868a974e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 14 Jul 2021 22:58:43 +0800 Subject: [PATCH 4/5] [DLMED] change to PT_BEFORE_1_7 Signed-off-by: Nic Ma --- monai/engines/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 7c0bb7b0d0..6c6b335227 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import min_version, optional_import, version_leq +from monai.utils import min_version, optional_import, PT_BEFORE_1_7 from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: @@ -190,7 +190,7 @@ def _compute_pred_loss(): self.network.train() # `set_to_none` only work from PyTorch 1.7.0 - if torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0"): + if PT_BEFORE_1_7: self.optimizer.zero_grad() else: self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -356,7 +356,7 @@ def _iteration( ) for _ in range(self.d_train_steps): # `set_to_none` only work from PyTorch 1.7.0 - if torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0"): + if PT_BEFORE_1_7: self.d_optimizer.zero_grad() else: self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) @@ -369,7 +369,7 @@ def _iteration( if self.g_update_latents: g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) g_output = self.g_inferer(g_input, self.g_network) - if torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0"): + if PT_BEFORE_1_7: self.g_optimizer.zero_grad() else: self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) From b40e5fd03cc0d7f169ba899207ff9db4ed418d41 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 14 Jul 2021 15:03:33 +0000 Subject: [PATCH 5/5] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/engines/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 6c6b335227..44e265be1f 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,7 +26,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import min_version, optional_import, PT_BEFORE_1_7 +from monai.utils import PT_BEFORE_1_7, min_version, optional_import from monai.utils.enums import CommonKeys as Keys if TYPE_CHECKING: