Skip to content
28 changes: 24 additions & 4 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
from monai.utils import min_version, optional_import
from monai.utils import PT_BEFORE_1_7, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys

if TYPE_CHECKING:
Expand Down Expand Up @@ -99,6 +99,8 @@ class SupervisedTrainer(Trainer):
decollate: whether to decollate the batch-first data to a list of data after model computation,
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
default to `True`.
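optim_set_to_none: when calling `optimizer.zero_grad()`, set the grads to None instead of setting them to zero.
this only takes effect with PyTorch 1.7.0 or later, with earlier versions the flag is ignored. default to `False`.
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.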

"""

Expand All @@ -124,6 +126,7 @@ def __init__(
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
optim_set_to_none: bool = False,
) -> None:
super().__init__(
device=device,
Expand All @@ -148,6 +151,7 @@ def __init__(
self.optimizer = optimizer
self.loss_function = loss_function
self.inferer = SimpleInferer() if inferer is None else inferer
self.optim_set_to_none = optim_set_to_none

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
Expand Down Expand Up @@ -185,7 +189,12 @@ def _compute_pred_loss():
engine.fire_event(IterationEvents.LOSS_COMPLETED)

self.network.train()
self.optimizer.zero_grad()
# `set_to_none` is only supported from PyTorch 1.7.0 onwards
if PT_BEFORE_1_7:
self.optimizer.zero_grad()
else:
self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)

if self.amp and self.scaler is not None:
with torch.cuda.amp.autocast():
_compute_pred_loss()
Expand Down Expand Up @@ -252,6 +261,8 @@ class GanTrainer(Trainer):
decollate: whether to decollate the batch-first data to a list of data after model computation,
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
default to `True`.
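optim_set_to_none: when calling `zero_grad()` on the discriminator and generator optimizers, set the grads
to None instead of setting them to zero. this only takes effect with PyTorch 1.7.0 or later, with earlier
versions the flag is ignored. default to `False`.
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.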

"""

Expand Down Expand Up @@ -282,6 +293,7 @@ def __init__(
metric_cmp_fn: Callable = default_metric_cmp_fn,
train_handlers: Optional[Sequence] = None,
decollate: bool = True,
optim_set_to_none: bool = False,
):
if not isinstance(train_data_loader, DataLoader):
raise ValueError("train_data_loader must be PyTorch DataLoader.")
Expand Down Expand Up @@ -314,6 +326,7 @@ def __init__(
self.latent_shape = latent_shape
self.g_prepare_batch = g_prepare_batch
self.g_update_latents = g_update_latents
self.optim_set_to_none = optim_set_to_none

def _iteration(
self, engine: Engine, batchdata: Union[Dict, Sequence]
Expand Down Expand Up @@ -342,7 +355,11 @@ def _iteration(
1,
)
for _ in range(self.d_train_steps):
self.d_optimizer.zero_grad()
# `set_to_none` is only supported from PyTorch 1.7.0 onwards
if PT_BEFORE_1_7:
self.d_optimizer.zero_grad()
else:
self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
dloss = self.d_loss_function(g_output, d_input)
dloss.backward()
self.d_optimizer.step()
Expand All @@ -352,7 +369,10 @@ def _iteration(
if self.g_update_latents:
g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking)
g_output = self.g_inferer(g_input, self.g_network)
self.g_optimizer.zero_grad()
if PT_BEFORE_1_7:
self.g_optimizer.zero_grad()
else:
self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
g_loss = self.g_loss_function(g_output)
g_loss.backward()
self.g_optimizer.step()
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def string_list_all_gather(strings: List[str]) -> List[str]:
if length < max_len:
strings += ["" for _ in range(max_len - length)]

if get_torch_version_tuple() <= (1, 6, 0):
if get_torch_version_tuple() <= (1, 6):
raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.")

for s in strings:
Expand Down
1 change: 1 addition & 0 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def _model_completed(self, engine):
key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
train_handlers=train_handlers,
amp=True if amp else False,
optim_set_to_none=True,
)
trainer.run()

Expand Down