From 61eb550463d3a4417b311df309b1df45eff6e4c1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 14:41:46 +0800 Subject: [PATCH 01/10] [DLMED] fix classification issue Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 46 +++++++++++++++---- .../test_handler_classification_saver_dist.py | 12 ++++- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 98f917330f..c8f1492965 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -11,6 +11,7 @@ import logging from typing import TYPE_CHECKING, Callable, Optional +import torch from monai.data import CSVSaver from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather @@ -61,11 +62,16 @@ def __init__( """ self._expected_rank: bool = idist.get_rank() == save_rank self.saver = CSVSaver(output_dir, filename, overwrite) + self.output_dir = output_dir + self.filename = filename + self.overwrite = overwrite self.batch_transform = batch_transform self.output_transform = output_transform self.logger = logging.getLogger(name) self._name = name + self._outputs = [] + self._filenames = [] def attach(self, engine: Engine) -> None: """ @@ -76,8 +82,8 @@ def attach(self, engine: Engine) -> None: self.logger = engine.logger if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, self) - if self._expected_rank and not engine.has_event_handler(self.saver.finalize, Events.COMPLETED): - engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize()) + if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED): + engine.add_event_handler(Events.COMPLETED, self._finalize) def __call__(self, engine: Engine) -> None: """ @@ -86,12 +92,34 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - _meta_data = self.batch_transform(engine.state.batch) - if Key.FILENAME_OR_OBJ in _meta_data: - # all gather filenames across ranks, only filenames are necessary - _meta_data = {Key.FILENAME_OR_OBJ: string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])} - # all gather predictions across ranks - _engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output)) + filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ, None) + if filenames is not None: + self._filenames.extend(filenames) + outputs = self.output_transform(engine.state.output) + if outputs is not None: + self._outputs.append(outputs) + def _finalize(self, engine: Engine) -> None: + """ + All gather classification results from ranks and save to CSV file. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + outputs = evenly_divisible_all_gather(torch.cat(self._outputs, dim=0)) + filenames = string_list_all_gather(self._filenames) + if len(filenames) == 0: + meta_dict = None + elif len(filenames) != len(outputs): + raise RuntimeError(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.") + else: + meta_dict = {Key.FILENAME_OR_OBJ: filenames} + + # save to CSV file only in the expected rank if self._expected_rank: - self.saver.save_batch(_engine_output, _meta_data) + saver = CSVSaver(self.output_dir, self.filename, self.overwrite) + saver.save_batch(outputs, meta_dict) + saver.finalize() + # reset cache + self._outputs = [] + self._filenames = [] diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index a33cba923a..8fd80d4993 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -47,6 +47,16 @@ def _train_func(engine, batch): "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))], } ] + # rank 1 has more iterations + if rank == 1: + data.append( + { + "filename_or_obj": ["testfile" + str(i) for i in range(18, 28)], + "data_shape": [(1, 1) for _ in range(18, 28)], + + } + ) + engine.run(data, max_epochs=1) filepath = os.path.join(tempdir, "predictions.csv") if rank == 1: @@ -58,7 +68,7 @@ def _train_func(engine, batch): self.assertEqual(row[0], "testfile" + str(i)) self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) i += 1 - self.assertEqual(i, 18) + self.assertEqual(i, 28) if __name__ == "__main__": From 5425da953bfd1762f655749561ac1b498f63ad87 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 27 Apr 2021 06:46:57 +0000 Subject: [PATCH 02/10] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/classification_saver.py | 1 + tests/test_handler_classification_saver_dist.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index c8f1492965..70805b8a28 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -11,6 +11,7 @@ import logging from typing import TYPE_CHECKING, Callable, Optional + import torch from monai.data import CSVSaver diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 8fd80d4993..40b2ed87de 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -53,7 +53,6 @@ def _train_func(engine, batch): { "filename_or_obj": ["testfile" + str(i) for i in range(18, 28)], "data_shape": [(1, 1) for _ in range(18, 28)], - } ) From 2737dd2a04c19907c3834bc99d00f0731ba2b7cc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 15:03:42 +0800 Subject: [PATCH 03/10] [DLMED] add typehints Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 70805b8a28..4072d8cc9c 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -10,7 +10,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, List import torch @@ -71,8 +71,8 @@ def __init__( self.logger = logging.getLogger(name) self._name = name - self._outputs = [] - self._filenames = [] + self._outputs: List[torch.Tensor] = [] + self._filenames: List[str] = [] def attach(self, engine: Engine) -> None: """ From 35b6972185cd4b6ca931f0987cb21164b17c22dc Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 27 Apr 2021 07:16:23 +0000 Subject: [PATCH 04/10] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/classification_saver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 4072d8cc9c..86692a4da8 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -10,7 +10,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Callable, Optional, List +from typing import TYPE_CHECKING, Callable, List, Optional import torch From f63a5b502293e8c0285a499963dde123bdd0bd0e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 19:40:19 +0800 Subject: [PATCH 05/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 22 ++++++++++++---------- monai/handlers/metrics_saver.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 86692a4da8..ae8e479570 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -11,13 +11,13 @@ import logging from typing import TYPE_CHECKING, Callable, List, Optional - +import warnings import torch from monai.data import CSVSaver from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather from monai.utils import ImageMetaKey as Key -from monai.utils import exact_version, optional_import +from monai.utils import exact_version, optional_import, issequenceiterable idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") @@ -62,7 +62,6 @@ def __init__( """ self._expected_rank: bool = idist.get_rank() == save_rank - self.saver = CSVSaver(output_dir, filename, overwrite) self.output_dir = output_dir self.filename = filename self.overwrite = overwrite @@ -81,10 +80,16 @@ def attach(self, engine: Engine) -> None: """ if self._name is None: self.logger = engine.logger + if not engine.has_event_handler(self._started, Events.EPOCH_STARTED): + engine.add_event_handler(Events.EPOCH_STARTED, self._started) if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, self) - if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED): - engine.add_event_handler(Events.COMPLETED, self._finalize) + if not engine.has_event_handler(self._finalize, Events.COMPLETED): + engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize) + + def _started(self, engine: Engine) -> None: + self._outputs = [] + self._filenames = [] def __call__(self, engine: Engine) -> None: """ @@ -94,7 +99,7 @@ def __call__(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ, None) - if filenames is not None: + if issequenceiterable(filenames): self._filenames.extend(filenames) outputs = self.output_transform(engine.state.output) if outputs is not None: @@ -112,7 +117,7 @@ def _finalize(self, engine: Engine) -> None: if len(filenames) == 0: meta_dict = None elif len(filenames) != len(outputs): - raise RuntimeError(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.") + warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.") else: meta_dict = {Key.FILENAME_OR_OBJ: filenames} @@ -121,6 +126,3 @@ def _finalize(self, engine: Engine) -> None: saver = CSVSaver(self.output_dir, self.filename, self.overwrite) saver.save_batch(outputs, meta_dict) saver.finalize() - # reset cache - self._outputs = [] - self._filenames = [] diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 082c370e48..884ad72061 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -86,7 +86,7 @@ def attach(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - engine.add_event_handler(Events.STARTED, self._started) + engine.add_event_handler(Events.EPOCH_STARTED, self._started) engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames) engine.add_event_handler(Events.EPOCH_COMPLETED, self) From 85ee6158361697935bdc39097002c5ab0cf9b510 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 27 Apr 2021 11:44:06 +0000 Subject: [PATCH 06/10] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/classification_saver.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index ae8e479570..1332225255 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -10,14 +10,15 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Callable, List, Optional import warnings +from typing import TYPE_CHECKING, Callable, List, Optional + import torch from monai.data import CSVSaver from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather from monai.utils import ImageMetaKey as Key -from monai.utils import exact_version, optional_import, issequenceiterable +from monai.utils import exact_version, issequenceiterable, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") From a70a0cb7cf1ace2fcffb6b7ec3b738baa4c8811a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 20:13:59 +0800 Subject: [PATCH 07/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 16 ++++++++++++---- monai/handlers/metrics_saver.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 1332225255..5f182586ab 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -62,7 +62,7 @@ def __init__( default to 0. """ - self._expected_rank: bool = idist.get_rank() == save_rank + self.save_rank = save_rank self.output_dir = output_dir self.filename = filename self.overwrite = overwrite @@ -113,8 +113,16 @@ def _finalize(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - outputs = evenly_divisible_all_gather(torch.cat(self._outputs, dim=0)) - filenames = string_list_all_gather(self._filenames) + ws = idist.get_world_size() + if self.save_rank >= ws: + raise ValueError("target save rank is greater than the distributed group size.") + + outputs = torch.cat(self._outputs, dim=0) + filenames = self._filenames + if ws > 1: + outputs = evenly_divisible_all_gather(outputs) + filenames = string_list_all_gather(filenames) + if len(filenames) == 0: meta_dict = None elif len(filenames) != len(outputs): @@ -123,7 +131,7 @@ def _finalize(self, engine: Engine) -> None: meta_dict = {Key.FILENAME_OR_OBJ: filenames} # save to CSV file only in the expected rank - if self._expected_rank: + if idist.get_rank() == self.save_rank: saver = CSVSaver(self.output_dir, self.filename, self.overwrite) saver.save_batch(outputs, meta_dict) saver.finalize() diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 884ad72061..c1d44deca9 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -105,7 +105,7 @@ def __call__(self, engine: Engine) -> None: """ ws = idist.get_world_size() if self.save_rank >= ws: - raise ValueError("target rank is greater than the distributed group size.") + raise ValueError("target save rank is greater than the distributed group size.") # all gather file names across ranks _images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames From 2955f98f36aa9aa8564aabb984fcf9bf8bf5f272 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 21:08:11 +0800 Subject: [PATCH 08/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 2 +- monai/handlers/metrics_saver.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 5f182586ab..3bf59713b8 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -99,7 +99,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ, None) + filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ) if issequenceiterable(filenames): self._filenames.extend(filenames) outputs = self.output_transform(engine.state.output) diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index c1d44deca9..d40d2a5a02 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -13,7 +13,7 @@ from monai.handlers.utils import string_list_all_gather, write_metrics_reports from monai.utils import ImageMetaKey as Key -from monai.utils import ensure_tuple, exact_version, optional_import +from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") @@ -95,8 +95,9 @@ def _started(self, engine: Engine) -> None: def _get_filenames(self, engine: Engine) -> None: if self.metric_details is not None: - _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)[Key.FILENAME_OR_OBJ])) - self._filenames += _filenames + filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ) + if issequenceiterable(filenames): + self._filenames.extend(filenames) def __call__(self, engine: Engine) -> None: """ @@ -123,7 +124,7 @@ def __call__(self, engine: Engine) -> None: write_metrics_reports( save_dir=self.save_dir, - images=_images, + images=None if len(_images) == 0 else _images, metrics=_metrics, metric_details=_metric_details, summary_ops=self.summary_ops, From 49e238ad06d75062f4aacb63b729a5c6e9096d72 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 21:12:28 +0800 Subject: [PATCH 09/10] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 3bf59713b8..280431e472 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -125,9 +125,9 @@ def _finalize(self, engine: Engine) -> None: if len(filenames) == 0: meta_dict = None - elif len(filenames) != len(outputs): - warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.") else: + if len(filenames) != len(outputs): + warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.") meta_dict = {Key.FILENAME_OR_OBJ: filenames} # save to CSV file only in the expected rank From 2f82c2b7302e82e2ea0da0fd5c73fdef98e5b73d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 27 Apr 2021 21:15:11 +0800 Subject: [PATCH 10/10] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 280431e472..d77dabde2f 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -85,7 +85,7 @@ def attach(self, engine: Engine) -> None: engine.add_event_handler(Events.EPOCH_STARTED, self._started) if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, self) - if not engine.has_event_handler(self._finalize, Events.COMPLETED): + if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED): engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize) def _started(self, engine: Engine) -> None: