From 10721397fa5a0b1ab52b350950b48d07875d3cc6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 23:14:26 +0800 Subject: [PATCH 1/2] [DLMED] add detach for Tensor Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 2 +- monai/handlers/iteration_metric.py | 2 +- monai/handlers/transform_inverter.py | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index d77dabde2f..1aeff164cf 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -104,7 +104,7 @@ def __call__(self, engine: Engine) -> None: self._filenames.extend(filenames) outputs = self.output_transform(engine.state.output) if outputs is not None: - self._outputs.append(outputs) + self._outputs.append(outputs.detach()) def _finalize(self, engine: Engine) -> None: """ diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index d5a4d30699..0ab7adac05 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -77,7 +77,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output def _compute(y_pred, y): - score = self.metric_fn(y_pred, y) + score = self.metric_fn(y_pred.detach(), y.detach()) return score[0] if isinstance(score, (tuple, list)) else score if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)): diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 651a47c39c..d1e79c389b 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -130,8 +130,11 @@ def __call__(self, engine: Engine) -> None: align_corners=None, ) + output = engine.state.output[output_key] + if isinstance(output, torch.Tensor): + output = output.detach() segs_dict = { - batch_key: engine.state.output[output_key], + batch_key: output, transform_key: transform_info, } meta_dict_key = f"{batch_key}_{self.meta_key_postfix}" From 8e8bc80b9fef6b81ce2c028a408ce8916d3594d5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 27 Apr 2021 17:02:50 +0100 Subject: [PATCH 2/2] fixes test Signed-off-by: Wenqi Li --- monai/handlers/classification_saver.py | 4 +++- monai/handlers/iteration_metric.py | 6 +++++- tests/test_lmdbdataset.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 1aeff164cf..40177725fd 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -104,7 +104,9 @@ def __call__(self, engine: Engine) -> None: self._filenames.extend(filenames) outputs = self.output_transform(engine.state.output) if outputs is not None: - self._outputs.append(outputs.detach()) + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach() + self._outputs.append(outputs) def _finalize(self, engine: Engine) -> None: """ diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 0ab7adac05..1578723582 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -77,7 +77,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output def _compute(y_pred, y): - score = self.metric_fn(y_pred.detach(), y.detach()) + if isinstance(y_pred, torch.Tensor): + y_pred = y_pred.detach() + if isinstance(y, torch.Tensor): + y = y.detach() + score = self.metric_fn(y_pred, y) return score[0] if isinstance(score, (tuple, list)) else score if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)): diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index 96a23327fb..7ae8e57e7a 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -187,7 +187,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tempdir) - @DistCall(nnodes=1, nproc_per_node=2) + @DistCall(nnodes=1, nproc_per_node=1) def test_mp_cache(self): items = [[list(range(i))] for i in range(5)]