From 8775fd7753106b2792545cad48e6e5dd1584496d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 12 Mar 2021 23:52:41 +0800 Subject: [PATCH 1/4] [DLMED] update ignite to 0.4.4 Signed-off-by: Nic Ma --- docs/requirements.txt | 2 +- monai/engines/evaluator.py | 4 ++-- monai/engines/multi_gpu_supervised_trainer.py | 10 +++++----- monai/engines/trainer.py | 4 ++-- monai/engines/utils.py | 2 +- monai/engines/workflow.py | 10 +++++----- monai/handlers/checkpoint_loader.py | 6 +++--- monai/handlers/checkpoint_saver.py | 10 +++++----- monai/handlers/classification_saver.py | 6 +++--- monai/handlers/iteration_metric.py | 8 ++++---- monai/handlers/lr_schedule_handler.py | 4 ++-- monai/handlers/metric_logger.py | 4 ++-- monai/handlers/metrics_saver.py | 6 +++--- monai/handlers/roc_auc.py | 4 ++-- monai/handlers/segmentation_saver.py | 4 ++-- monai/handlers/smartcache_handler.py | 4 ++-- monai/handlers/stats_handler.py | 4 ++-- monai/handlers/tensorboard_handlers.py | 4 ++-- monai/handlers/utils.py | 4 ++-- monai/handlers/validation_handler.py | 4 ++-- requirements-dev.txt | 2 +- setup.cfg | 4 ++-- 22 files changed, 55 insertions(+), 55 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index cd06166359..22fd2589f0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ -f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl torch>=1.5 -pytorch-ignite==0.4.2 +pytorch-ignite==0.4.4 numpy>=1.17 itk>=5.0 nibabel diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index b8977a3652..0afa3747a4 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -26,8 +26,8 @@ from ignite.engine import Engine from ignite.metrics import Metric else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") - Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") __all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"] diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index d12e012a56..d0e09443fa 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -19,15 +19,15 @@ from monai.engines.utils import get_devices_spec from monai.utils import exact_version, optional_import -create_supervised_trainer, _ = optional_import("ignite.engine", "0.4.2", exact_version, "create_supervised_trainer") -create_supervised_evaluator, _ = optional_import("ignite.engine", "0.4.2", exact_version, "create_supervised_evaluator") -_prepare_batch, _ = optional_import("ignite.engine", "0.4.2", exact_version, "_prepare_batch") +create_supervised_trainer, _ = optional_import("ignite.engine", "0.4.4", exact_version, "create_supervised_trainer") +create_supervised_evaluator, _ = optional_import("ignite.engine", "0.4.4", exact_version, "create_supervised_evaluator") +_prepare_batch, _ = optional_import("ignite.engine", "0.4.4", exact_version, "_prepare_batch") if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") - Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") __all__ = [ "create_multigpu_supervised_trainer", diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c3d471e261..5b996eafe1 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -26,8 +26,8 @@ from ignite.engine import Engine from ignite.metrics import Metric else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") - Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") __all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 8f5899f2a5..b0b1e44f71 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from ignite.engine import EventEnum else: - EventEnum, _ = optional_import("ignite.engine", "0.4.2", exact_version, "EventEnum") + EventEnum, _ = optional_import("ignite.engine", "0.4.4", exact_version, "EventEnum") __all__ = [ "IterationEvents", diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index b50d58f1a2..61b92ac5dd 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -20,15 +20,15 @@ from monai.transforms import apply_transform from monai.utils import ensure_tuple, exact_version, optional_import -IgniteEngine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") -State, _ = optional_import("ignite.engine", "0.4.2", exact_version, "State") -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") +State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine from ignite.metrics import Metric else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") - Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 648cc8360a..e6319a3c64 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -16,12 +16,12 @@ from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") -Checkpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "Checkpoint") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +Checkpoint, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "Checkpoint") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class CheckpointLoader: diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index 1808e6b251..0c65b8cd4b 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -15,16 +15,16 @@ from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") -Checkpoint, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "Checkpoint") -BaseSaveHandler, _ = optional_import("ignite.handlers.checkpoint", "0.4.2", exact_version, "BaseSaveHandler") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +Checkpoint, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "Checkpoint") +BaseSaveHandler, _ = optional_import("ignite.handlers.checkpoint", "0.4.4", exact_version, "BaseSaveHandler") if TYPE_CHECKING: from ignite.engine import Engine from ignite.handlers import DiskSaver else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") - DiskSaver, _ = optional_import("ignite.handlers", "0.4.2", exact_version, "DiskSaver") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + DiskSaver, _ = optional_import("ignite.handlers", "0.4.4", exact_version, "DiskSaver") class CheckpointSaver: diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 33ce7c7ec8..98f917330f 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -17,12 +17,12 @@ from monai.utils import ImageMetaKey as Key from monai.utils import exact_version, optional_import -idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class ClassificationSaver: diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 641efad243..31c5c0498a 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -17,13 +17,13 @@ from monai.metrics import do_metric_reduction from monai.utils import MetricReduction, exact_version, optional_import -idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") -Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") -reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") +idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") +Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.4", exact_version, "reinit__is_reduced") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index e5593f07ff..3b300537b2 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -16,11 +16,11 @@ from monai.utils import ensure_tuple, exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class LrScheduleHandler: diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index fdd60da57c..758276d03d 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -14,11 +14,11 @@ from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class MetricLogger: diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 87d7223c96..082c370e48 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -15,12 +15,12 @@ from monai.utils import ImageMetaKey as Key from monai.utils import ensure_tuple, exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") -idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class MetricsSaver: diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 2273b9ee89..3e7b8a4a1e 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -17,8 +17,8 @@ from monai.metrics import compute_roc_auc from monai.utils import Average, exact_version, optional_import -idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") -EpochMetric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "EpochMetric") +idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") +EpochMetric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "EpochMetric") class ROCAUC(EpochMetric): # type: ignore[valid-type, misc] # due to optional_import diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 56370fd41c..25238ea442 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -18,11 +18,11 @@ from monai.transforms import SaveImage from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class SegmentationSaver: diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py index 423d87c22a..821f883d91 100644 --- a/monai/handlers/smartcache_handler.py +++ b/monai/handlers/smartcache_handler.py @@ -14,11 +14,11 @@ from monai.data import SmartCacheDataset from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class SmartCacheHandler: diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 24d844569f..6d4a4e958b 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -17,11 +17,11 @@ from monai.utils import exact_version, is_scalar, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} " DEFAULT_TAG = "Loss" diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 4ee88bcfc9..9ad1fe6353 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -18,12 +18,12 @@ from monai.utils import exact_version, is_scalar, optional_import from monai.visualize import plot_2d_or_3d_image -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine from torch.utils.tensorboard import SummaryWriter else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") DEFAULT_TAG = "Loss" diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 9ed13d292c..d551ae1ce6 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -18,11 +18,11 @@ from monai.utils import ensure_tuple, exact_version, get_torch_version_tuple, optional_import -idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") +idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") __all__ = [ "stopping_fn_from_metric", diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 9cc2e926f4..4458a17380 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -14,11 +14,11 @@ from monai.engines.evaluator import Evaluator from monai.utils import exact_version, optional_import -Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") class ValidationHandler: diff --git a/requirements-dev.txt b/requirements-dev.txt index 3eeab474b6..1508eae4fe 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # Full requirements for developments -r requirements-min.txt -pytorch-ignite==0.4.2 +pytorch-ignite==0.4.4 gdown>=3.6.4 scipy itk>=5.0 diff --git a/setup.cfg b/setup.cfg index bbdcdf805d..9dd9fa106b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ all = scikit-image>=0.14.2 pillow tensorboard - pytorch-ignite==0.4.2 + pytorch-ignite==0.4.4 gdown>=3.6.4 torchvision itk>=5.0 @@ -44,7 +44,7 @@ tensorboard = gdown = gdown>=3.6.4 ignite = - pytorch-ignite==0.4.2 + pytorch-ignite==0.4.4 torchvision = torchvision itk = From c00f10a3be52f886cdbe2b1418d1c66c5c8934f6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 15 Mar 2021 15:37:05 +0800 Subject: [PATCH 2/4] [DLMED] fix ignite compatible issues Signed-off-by: Nic Ma --- monai/handlers/confusion_matrix.py | 4 ++-- monai/handlers/hausdorff_distance.py | 4 ++-- monai/handlers/iteration_metric.py | 6 +++--- monai/handlers/mean_dice.py | 4 ++-- monai/handlers/roc_auc.py | 2 +- monai/handlers/surface_distance.py | 4 ++-- monai/handlers/utils.py | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index 1741aa305a..e50bff9123 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Union import torch @@ -28,7 +28,7 @@ def __init__( include_background: bool = True, metric_name: str = "hit_rate", output_transform: Callable = lambda x: x, - device: Optional[torch.device] = None, + device: Union[str, torch.device] = torch.device("cpu"), save_details: bool = True, ) -> None: """ diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index 7ac52d642a..f6d75aa638 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch @@ -30,7 +30,7 @@ def __init__( percentile: Optional[float] = None, directed: bool = False, output_transform: Callable = lambda x: x, - device: Optional[torch.device] = None, + device: Union[str, torch.device] = torch.device("cpu"), save_details: bool = True, ) -> None: """ diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 31c5c0498a..b3717329b8 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union import torch @@ -46,7 +46,7 @@ def __init__( self, metric_fn: Callable, output_transform: Callable = lambda x: x, - device: Optional[torch.device] = None, + device: Union[str, torch.device] = torch.device("cpu"), save_details: bool = True, ) -> None: self._is_reduced: bool = False @@ -77,7 +77,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: score = self.metric_fn(y_pred, y) if isinstance(score, (tuple, list)): score = score[0] - self._scores.append(score) + self._scores.append(score.clone().to(self._device)) def compute(self) -> Any: """ diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 7decc3ab9b..088f849ef9 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, Union import torch @@ -27,7 +27,7 @@ def __init__( self, include_background: bool = True, output_transform: Callable = lambda x: x, - device: Optional[torch.device] = None, + device: Union[str, torch.device] = torch.device("cpu"), save_details: bool = True, ) -> None: """ diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 3e7b8a4a1e..1c29453e89 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -61,7 +61,7 @@ def __init__( other_act: Optional[Callable] = None, average: Union[Average, str] = Average.MACRO, output_transform: Callable = lambda x: x, - device: Optional[torch.device] = None, + device: Union[str, torch.device] = torch.device("cpu"), ) -> None: def _compute_fn(pred, label): return compute_roc_auc( diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index d3fa69bfce..af18d4c29c 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, Union import torch @@ -29,7 +29,7 @@ def __init__( symmetric: bool = False, distance_metric: str = "euclidean", output_transform: Callable = lambda x: x, - device: Optional[torch.device] = None, + device: Union[str, torch.device] = torch.device("cpu"), save_details: bool = True, ) -> None: """ diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index d551ae1ce6..2eaf3ab932 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -75,7 +75,7 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: # make sure the data is evenly-divisible on multi-GPUs length = data.shape[0] all_lens = idist.all_gather(length) - max_len = max(all_lens).item() + max_len = max(all_lens) if length < max_len: size = [max_len - length] + list(data.shape[1:]) data = torch.cat([data, data.new_full(size, 0)], dim=0) @@ -103,7 +103,7 @@ def string_list_all_gather(strings: List[str]) -> List[str]: # get length of strings length = len(strings) all_lens = idist.all_gather(length) - max_len = max(all_lens).item() + max_len = max(all_lens) # pad the item to make sure the same length if length < max_len: strings = strings + ["" for _ in range(max_len - length)] From 203ce741cc4e6f1ca34a066d6073282f4cd032ec Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 15 Mar 2021 15:47:07 +0800 Subject: [PATCH 3/4] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/handlers/confusion_matrix.py | 2 +- monai/handlers/hausdorff_distance.py | 2 +- monai/handlers/iteration_metric.py | 2 +- monai/handlers/mean_dice.py | 2 +- monai/handlers/roc_auc.py | 2 +- monai/handlers/surface_distance.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index e50bff9123..551fd29199 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -28,7 +28,7 @@ def __init__( include_background: bool = True, metric_name: str = "hit_rate", output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: Union[str, torch.device] = "cpu", save_details: bool = True, ) -> None: """ diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index f6d75aa638..042a587852 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -30,7 +30,7 @@ def __init__( percentile: Optional[float] = None, directed: bool = False, output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: Union[str, torch.device] = "cpu", save_details: bool = True, ) -> None: """ diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index b3717329b8..3686174b5a 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -46,7 +46,7 @@ def __init__( self, metric_fn: Callable, output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: Union[str, torch.device] = "cpu", save_details: bool = True, ) -> None: self._is_reduced: bool = False diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 088f849ef9..6d51c534cf 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -27,7 +27,7 @@ def __init__( self, include_background: bool = True, output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: Union[str, torch.device] = "cpu", save_details: bool = True, ) -> None: """ diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 1c29453e89..9a9af601f9 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -61,7 +61,7 @@ def __init__( other_act: Optional[Callable] = None, average: Union[Average, str] = Average.MACRO, output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: Union[str, torch.device] = "cpu", ) -> None: def _compute_fn(pred, label): return compute_roc_auc( diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index af18d4c29c..7c2322354a 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -29,7 +29,7 @@ def __init__( symmetric: bool = False, distance_metric: str = "euclidean", output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: Union[str, torch.device] = "cpu", save_details: bool = True, ) -> None: """ From ba43dae0fa585148edd2cf96d9d0b458f95e60ac Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 15 Mar 2021 23:13:00 +0800 Subject: [PATCH 4/4] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/iteration_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 3686174b5a..f49c799a21 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -77,7 +77,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: score = self.metric_fn(y_pred, y) if isinstance(score, (tuple, list)): score = score[0] - self._scores.append(score.clone().to(self._device)) + self._scores.append(score.to(self._device)) def compute(self) -> Any: """