Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions monai/handlers/classification_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Callable, Optional
import warnings
from typing import TYPE_CHECKING, Callable, List, Optional

import torch

from monai.data import CSVSaver
from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather
from monai.utils import ImageMetaKey as Key
from monai.utils import exact_version, optional_import
from monai.utils import exact_version, issequenceiterable, optional_import

idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
Expand Down Expand Up @@ -59,13 +62,17 @@ def __init__(
default to 0.

"""
self._expected_rank: bool = idist.get_rank() == save_rank
self.saver = CSVSaver(output_dir, filename, overwrite)
self.save_rank = save_rank
self.output_dir = output_dir
self.filename = filename
self.overwrite = overwrite
self.batch_transform = batch_transform
self.output_transform = output_transform

self.logger = logging.getLogger(name)
self._name = name
self._outputs: List[torch.Tensor] = []
self._filenames: List[str] = []

def attach(self, engine: Engine) -> None:
"""
Expand All @@ -74,10 +81,16 @@ def attach(self, engine: Engine) -> None:
"""
if self._name is None:
self.logger = engine.logger
if not engine.has_event_handler(self._started, Events.EPOCH_STARTED):
engine.add_event_handler(Events.EPOCH_STARTED, self._started)
if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
if self._expected_rank and not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)

def _started(self, engine: Engine) -> None:
    """
    Reset the cached outputs and filenames by rebinding both buffers to new empty lists.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.
    """
    self._outputs, self._filenames = [], []

def __call__(self, engine: Engine) -> None:
"""
Expand All @@ -86,12 +99,39 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
_meta_data = self.batch_transform(engine.state.batch)
if Key.FILENAME_OR_OBJ in _meta_data:
# all gather filenames across ranks, only filenames are necessary
_meta_data = {Key.FILENAME_OR_OBJ: string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])}
# all gather predictions across ranks
_engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output))

if self._expected_rank:
self.saver.save_batch(_engine_output, _meta_data)
filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ)
if issequenceiterable(filenames):
self._filenames.extend(filenames)
Comment thread
Nic-Ma marked this conversation as resolved.
outputs = self.output_transform(engine.state.output)
if outputs is not None:
self._outputs.append(outputs)

def _finalize(self, engine: Engine) -> None:
    """
    Combine the classification results cached on each rank and write them to a CSV file.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.

    Raises:
        ValueError: When ``self.save_rank`` is not smaller than the distributed world size.
    """
    world_size = idist.get_world_size()
    if world_size <= self.save_rank:
        raise ValueError("target save rank is greater than the distributed group size.")

    # join the cached per-iteration outputs of this rank along dim 0
    predictions = torch.cat(self._outputs, dim=0)
    names = self._filenames
    if world_size > 1:
        # both gathers run before the save-rank check below: outputs first, then filenames
        predictions = evenly_divisible_all_gather(predictions)
        names = string_list_all_gather(names)

    # without any filenames there is no meta data to hand to the saver
    meta_dict = None
    if len(names) != 0:
        if len(names) != len(predictions):
            warnings.warn(f"filenames length: {len(names)} doesn't match outputs length: {len(predictions)}.")
        meta_dict = {Key.FILENAME_OR_OBJ: names}

    # save to CSV file only in the expected rank
    if idist.get_rank() != self.save_rank:
        return
    csv_saver = CSVSaver(self.output_dir, self.filename, self.overwrite)
    csv_saver.save_batch(predictions, meta_dict)
    csv_saver.finalize()
13 changes: 7 additions & 6 deletions monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from monai.handlers.utils import string_list_all_gather, write_metrics_reports
from monai.utils import ImageMetaKey as Key
from monai.utils import ensure_tuple, exact_version, optional_import
from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
Expand Down Expand Up @@ -86,7 +86,7 @@ def attach(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
engine.add_event_handler(Events.STARTED, self._started)
engine.add_event_handler(Events.EPOCH_STARTED, self._started)
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
engine.add_event_handler(Events.EPOCH_COMPLETED, self)

Expand All @@ -95,8 +95,9 @@ def _started(self, engine: Engine) -> None:

def _get_filenames(self, engine: Engine) -> None:
if self.metric_details is not None:
_filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)[Key.FILENAME_OR_OBJ]))
self._filenames += _filenames
filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ)
if issequenceiterable(filenames):
self._filenames.extend(filenames)

def __call__(self, engine: Engine) -> None:
"""
Expand All @@ -105,7 +106,7 @@ def __call__(self, engine: Engine) -> None:
"""
ws = idist.get_world_size()
if self.save_rank >= ws:
raise ValueError("target rank is greater than the distributed group size.")
raise ValueError("target save rank is greater than the distributed group size.")

# all gather file names across ranks
_images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames
Expand All @@ -123,7 +124,7 @@ def __call__(self, engine: Engine) -> None:

write_metrics_reports(
save_dir=self.save_dir,
images=_images,
images=None if len(_images) == 0 else _images,
metrics=_metrics,
metric_details=_metric_details,
summary_ops=self.summary_ops,
Expand Down
11 changes: 10 additions & 1 deletion tests/test_handler_classification_saver_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def _train_func(engine, batch):
"data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))],
}
]
# rank 1 has more iterations
if rank == 1:
data.append(
{
"filename_or_obj": ["testfile" + str(i) for i in range(18, 28)],
"data_shape": [(1, 1) for _ in range(18, 28)],
}
)

engine.run(data, max_epochs=1)
filepath = os.path.join(tempdir, "predictions.csv")
if rank == 1:
Expand All @@ -58,7 +67,7 @@ def _train_func(engine, batch):
self.assertEqual(row[0], "testfile" + str(i))
self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)
i += 1
self.assertEqual(i, 18)
self.assertEqual(i, 28)


if __name__ == "__main__":
Expand Down