diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index d77dabde2f..40177725fd 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -104,6 +104,8 @@ def __call__(self, engine: Engine) -> None: self._filenames.extend(filenames) outputs = self.output_transform(engine.state.output) if outputs is not None: + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach() self._outputs.append(outputs) def _finalize(self, engine: Engine) -> None: diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index d5a4d30699..1578723582 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -77,6 +77,10 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output def _compute(y_pred, y): + if isinstance(y_pred, torch.Tensor): + y_pred = y_pred.detach() + if isinstance(y, torch.Tensor): + y = y.detach() score = self.metric_fn(y_pred, y) return score[0] if isinstance(score, (tuple, list)) else score diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 651a47c39c..d1e79c389b 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -130,8 +130,11 @@ def __call__(self, engine: Engine) -> None: align_corners=None, ) + output = engine.state.output[output_key] + if isinstance(output, torch.Tensor): + output = output.detach() segs_dict = { - batch_key: engine.state.output[output_key], + batch_key: output, transform_key: transform_info, } meta_dict_key = f"{batch_key}_{self.meta_key_postfix}" diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index 96a23327fb..7ae8e57e7a 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -187,7 +187,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tempdir) - @DistCall(nnodes=1, nproc_per_node=2) + @DistCall(nnodes=1, nproc_per_node=1) def test_mp_cache(self): items = [[list(range(i))] for i in range(5)]