# Patch summary (commentary only; everything from the first "diff --git" line
# onward is the patch itself, re-broken into one statement per line).
#
# monai/handlers/classification_saver.py
#   * __init__ no longer builds a CSVSaver or caches the rank check; it stores
#     save_rank, output_dir, filename and overwrite, and creates two
#     per-epoch buffers: self._outputs and self._filenames.
#   * attach() registers _started on EPOCH_STARTED (clears both buffers),
#     self on ITERATION_COMPLETED, and _finalize on EPOCH_COMPLETED.
#   * __call__ only accumulates: filenames via self._filenames.extend(...)
#     and outputs via self._outputs.append(...); it no longer all-gathers
#     on every iteration.
#   * _finalize concatenates the buffered outputs, all-gathers outputs and
#     filenames when the world size is > 1, warns when the two lengths
#     differ, and writes the CSV only on the rank equal to save_rank.
#
# monai/handlers/metrics_saver.py
#   * _started is attached to EPOCH_STARTED instead of STARTED.
#   * _get_filenames uses .get(Key.FILENAME_OR_OBJ) plus an
#     issequenceiterable() check instead of indexing + ensure_tuple().
#   * write_metrics_reports receives images=None when no filenames exist.
#   * The ValueError text becomes "target save rank is greater than ...".
#
# tests/test_handler_classification_saver_dist.py
#   * Rank 1 gets a second data item, and the expected row count changes
#     from 18 to 28.
#
# NOTE(review): _finalize calls torch.cat(self._outputs, dim=0)
#   unconditionally, while __call__ skips appending when the transformed
#   output is None, so the list can be empty at epoch end. torch.cat is
#   presumably not defined for an empty list -- confirm and guard if needed.
# NOTE(review): filenames are only kept when issequenceiterable(filenames)
#   is true; that helper is not shown here. If it returns False for a plain
#   str, a single un-collated filename that ensure_tuple() used to accept in
#   metrics_saver would now be skipped -- confirm this is intended.
# NOTE(review): _finalize constructs a new CSVSaver with self.overwrite on
#   every EPOCH_COMPLETED; behaviour across multiple epochs depends on
#   CSVSaver (not shown) -- confirm.
# NOTE(review): indentation and blank context lines below were reconstructed
#   from a whitespace-collapsed copy; line counts match every @@ header, but
#   verify against the target tree (or apply ignoring whitespace).
#
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
index 98f917330f..d77dabde2f 100644
--- a/monai/handlers/classification_saver.py
+++ b/monai/handlers/classification_saver.py
@@ -10,12 +10,15 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Callable, Optional
+import warnings
+from typing import TYPE_CHECKING, Callable, List, Optional
+
+import torch
 
 from monai.data import CSVSaver
 from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather
 from monai.utils import ImageMetaKey as Key
-from monai.utils import exact_version, optional_import
+from monai.utils import exact_version, issequenceiterable, optional_import
 
 idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
 Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
@@ -59,13 +62,17 @@ def __init__(
                 default to 0.
 
         """
-        self._expected_rank: bool = idist.get_rank() == save_rank
-        self.saver = CSVSaver(output_dir, filename, overwrite)
+        self.save_rank = save_rank
+        self.output_dir = output_dir
+        self.filename = filename
+        self.overwrite = overwrite
         self.batch_transform = batch_transform
         self.output_transform = output_transform
 
         self.logger = logging.getLogger(name)
         self._name = name
+        self._outputs: List[torch.Tensor] = []
+        self._filenames: List[str] = []
 
     def attach(self, engine: Engine) -> None:
         """
@@ -74,10 +81,16 @@ def attach(self, engine: Engine) -> None:
         """
         if self._name is None:
             self.logger = engine.logger
+        if not engine.has_event_handler(self._started, Events.EPOCH_STARTED):
+            engine.add_event_handler(Events.EPOCH_STARTED, self._started)
         if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
             engine.add_event_handler(Events.ITERATION_COMPLETED, self)
-        if self._expected_rank and not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
-            engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
+        if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):
+            engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)
+
+    def _started(self, engine: Engine) -> None:
+        self._outputs = []
+        self._filenames = []
 
     def __call__(self, engine: Engine) -> None:
         """
@@ -86,12 +99,39 @@ def __call__(self, engine: Engine) -> None:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        _meta_data = self.batch_transform(engine.state.batch)
-        if Key.FILENAME_OR_OBJ in _meta_data:
-            # all gather filenames across ranks, only filenames are necessary
-            _meta_data = {Key.FILENAME_OR_OBJ: string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])}
-        # all gather predictions across ranks
-        _engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output))
-
-        if self._expected_rank:
-            self.saver.save_batch(_engine_output, _meta_data)
+        filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ)
+        if issequenceiterable(filenames):
+            self._filenames.extend(filenames)
+        outputs = self.output_transform(engine.state.output)
+        if outputs is not None:
+            self._outputs.append(outputs)
+
+    def _finalize(self, engine: Engine) -> None:
+        """
+        All gather classification results from ranks and save to CSV file.
+
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        ws = idist.get_world_size()
+        if self.save_rank >= ws:
+            raise ValueError("target save rank is greater than the distributed group size.")
+
+        outputs = torch.cat(self._outputs, dim=0)
+        filenames = self._filenames
+        if ws > 1:
+            outputs = evenly_divisible_all_gather(outputs)
+            filenames = string_list_all_gather(filenames)
+
+        if len(filenames) == 0:
+            meta_dict = None
+        else:
+            if len(filenames) != len(outputs):
+                warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.")
+            meta_dict = {Key.FILENAME_OR_OBJ: filenames}
+
+        # save to CSV file only in the expected rank
+        if idist.get_rank() == self.save_rank:
+            saver = CSVSaver(self.output_dir, self.filename, self.overwrite)
+            saver.save_batch(outputs, meta_dict)
+            saver.finalize()
diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py
index 082c370e48..d40d2a5a02 100644
--- a/monai/handlers/metrics_saver.py
+++ b/monai/handlers/metrics_saver.py
@@ -13,7 +13,7 @@
 
 from monai.handlers.utils import string_list_all_gather, write_metrics_reports
 from monai.utils import ImageMetaKey as Key
-from monai.utils import ensure_tuple, exact_version, optional_import
+from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import
 
 Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
 idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
@@ -86,7 +86,7 @@ def attach(self, engine: Engine) -> None:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        engine.add_event_handler(Events.STARTED, self._started)
+        engine.add_event_handler(Events.EPOCH_STARTED, self._started)
         engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
         engine.add_event_handler(Events.EPOCH_COMPLETED, self)
 
@@ -95,8 +95,9 @@ def _started(self, engine: Engine) -> None:
 
     def _get_filenames(self, engine: Engine) -> None:
         if self.metric_details is not None:
-            _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)[Key.FILENAME_OR_OBJ]))
-            self._filenames += _filenames
+            filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ)
+            if issequenceiterable(filenames):
+                self._filenames.extend(filenames)
 
     def __call__(self, engine: Engine) -> None:
         """
@@ -105,7 +106,7 @@ def __call__(self, engine: Engine) -> None:
         """
         ws = idist.get_world_size()
         if self.save_rank >= ws:
-            raise ValueError("target rank is greater than the distributed group size.")
+            raise ValueError("target save rank is greater than the distributed group size.")
 
         # all gather file names across ranks
         _images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames
@@ -123,7 +124,7 @@ def __call__(self, engine: Engine) -> None:
 
             write_metrics_reports(
                 save_dir=self.save_dir,
-                images=_images,
+                images=None if len(_images) == 0 else _images,
                 metrics=_metrics,
                 metric_details=_metric_details,
                 summary_ops=self.summary_ops,
diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py
index a33cba923a..40b2ed87de 100644
--- a/tests/test_handler_classification_saver_dist.py
+++ b/tests/test_handler_classification_saver_dist.py
@@ -47,6 +47,15 @@ def _train_func(engine, batch):
                     "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))],
                 }
             ]
+            # rank 1 has more iterations
+            if rank == 1:
+                data.append(
+                    {
+                        "filename_or_obj": ["testfile" + str(i) for i in range(18, 28)],
+                        "data_shape": [(1, 1) for _ in range(18, 28)],
+                    }
+                )
+
             engine.run(data, max_epochs=1)
             filepath = os.path.join(tempdir, "predictions.csv")
             if rank == 1:
@@ -58,7 +67,7 @@ def _train_func(engine, batch):
                         self.assertEqual(row[0], "testfile" + str(i))
                         self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)
                         i += 1
-                    self.assertEqual(i, 18)
+                    self.assertEqual(i, 28)
 
 
 if __name__ == "__main__":