From 129c7eecdf23aa526918691db12f3afffd966c00 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 14 Nov 2021 11:42:22 +0000 Subject: [PATCH 1/8] update pathlike obj Signed-off-by: Wenqi Li --- monai/apps/utils.py | 46 ++++++++++++++++-------------- monai/config/__init__.py | 10 ++++++- monai/config/type_definitions.py | 14 ++++++++- tests/test_download_and_extract.py | 5 ++-- 4 files changed, 50 insertions(+), 25 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 2961f39720..ab884d1c52 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -18,10 +18,12 @@ import tempfile import warnings import zipfile +from pathlib import Path from typing import TYPE_CHECKING, Optional from urllib.error import ContentTooShortError, HTTPError, URLError from urllib.request import urlretrieve +from monai.config.type_definitions import PathLike from monai.utils import min_version, optional_import gdown, has_gdown = optional_import("gdown", "3.6") @@ -69,10 +71,10 @@ def get_logger( __all__.append("logger") -def _basename(p): +def _basename(p: PathLike) -> str: """get the last part of the path (removing the trailing slash if it exists)""" sep = os.path.sep + (os.path.altsep or "") + "/ " - return os.path.basename(p.rstrip(sep)) + return Path(f"{p}".rstrip(sep)).name def _download_with_progress(url, filepath, progress: bool = True): @@ -110,7 +112,7 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None): raise e -def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool: +def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = "md5") -> bool: """ Verify hash signature of specified file. @@ -145,7 +147,7 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") def download_url( - url: str, filepath: str = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True + url: str, filepath: PathLike = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True ) -> None: """ Download file from specified URL link, support process bar and hash check. @@ -171,9 +173,10 @@ def download_url( """ if not filepath: - filepath = os.path.abspath(os.path.join(".", _basename(url))) + filepath = Path(".", _basename(url)).resolve() logger.info(f"Default downloading to '{filepath}'") - if os.path.exists(filepath): + filepath = Path(filepath) + if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." @@ -182,21 +185,21 @@ def download_url( return with tempfile.TemporaryDirectory() as tmp_dir: - tmp_name = os.path.join(tmp_dir, f"{_basename(filepath)}") + tmp_name = Path(tmp_dir, _basename(filepath)) if url.startswith("https://drive.google.com"): if not has_gdown: raise RuntimeError("To download files from Google Drive, please install the gdown dependency.") - gdown.download(url, tmp_name, quiet=not progress) + gdown.download(url, f"{tmp_name}", quiet=not progress) else: _download_with_progress(url, tmp_name, progress=progress) - if not os.path.exists(tmp_name): + if not tmp_name.exists(): raise RuntimeError( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." ) - file_dir = os.path.dirname(filepath) + file_dir = filepath.parent if file_dir: os.makedirs(file_dir, exist_ok=True) - shutil.move(tmp_name, filepath) # copy the downloaded to a user-specified cache. + shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache. logger.info(f"Downloaded: {filepath}") if not check_hash(filepath, hash_val, hash_type): raise RuntimeError( @@ -206,8 +209,8 @@ def download_url( def extractall( - filepath: str, - output_dir: str = ".", + filepath: PathLike, + output_dir: PathLike = ".", hash_val: Optional[str] = None, hash_type: str = "md5", file_type: str = "", @@ -236,24 +239,25 @@ def extractall( """ if has_base: # the extracted files will be in this folder - cache_dir = os.path.join(output_dir, _basename(filepath).split(".")[0]) + cache_dir = Path(output_dir, _basename(filepath).split(".")[0]) else: - cache_dir = output_dir - if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) > 0: + cache_dir = Path(output_dir) + if cache_dir.exists() and len(list(cache_dir.iterdir())) > 0: logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.") return + filepath = Path(filepath) if hash_val and not check_hash(filepath, hash_val, hash_type): raise RuntimeError( f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}." ) logger.info(f"Writing into directory: {output_dir}.") _file_type = file_type.lower().strip() - if filepath.endswith("zip") or _file_type == "zip": + if filepath.name.endswith("zip") or _file_type == "zip": zip_file = zipfile.ZipFile(filepath) zip_file.extractall(output_dir) zip_file.close() return - if filepath.endswith("tar") or filepath.endswith("tar.gz") or "tar" in _file_type: + if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type: tar_file = tarfile.open(filepath) tar_file.extractall(output_dir) tar_file.close() @@ -265,8 +269,8 @@ def extractall( def download_and_extract( url: str, - filepath: str = "", - output_dir: str = ".", + filepath: PathLike = "", + output_dir: PathLike = ".", hash_val: Optional[str] = None, hash_type: str = "md5", file_type: str = "", @@ -293,6 +297,6 @@ def download_and_extract( progress: whether to display progress bar. """ with tempfile.TemporaryDirectory() as tmp_dir: - filename = filepath or os.path.join(tmp_dir, f"{_basename(url)}") + filename = filepath or Path(tmp_dir, _basename(url)).resolve() download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) diff --git a/monai/config/__init__.py b/monai/config/__init__.py index c929cb2362..e3f623823c 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -19,4 +19,12 @@ print_gpu_info, print_system_info, ) -from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayOrTensor, NdarrayTensor, TensorOrList +from .type_definitions import ( + DtypeLike, + IndexSelection, + KeysCollection, + NdarrayOrTensor, + NdarrayTensor, + PathLike, + TensorOrList, +) diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index 91ac74961b..f5f5fc2626 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union import numpy as np @@ -29,7 +30,15 @@ # may be implemented). Consistent use of the concept and recorded documentation of # the rationale and convention behind it lowers the learning curve for new # developers. For readability, short names are preferred. -__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor", "NdarrayOrTensor", "TensorOrList"] +__all__ = [ + "KeysCollection", + "IndexSelection", + "DtypeLike", + "NdarrayTensor", + "NdarrayOrTensor", + "TensorOrList", + "PathLike", +] #: KeysCollection @@ -66,3 +75,6 @@ #: TensorOrList: The TensorOrList type is used for defining `batch-first Tensor` or `list of channel-first Tensor`. TensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]] + +#: PathLike: The PathLike type is used for defining a file path. +PathLike = Union[str, os.PathLike] diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index 54192a1310..164fcd723d 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from urllib.error import ContentTooShortError, HTTPError from monai.apps import download_and_extract, download_url, extractall @@ -23,8 +24,8 @@ class TestDownloadAndExtract(unittest.TestCase): def test_actions(self): testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") url = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" - filepath = os.path.join(testing_dir, "MedNIST.tar.gz") - output_dir = testing_dir + filepath = Path(testing_dir) / "MedNIST.tar.gz" + output_dir = Path(testing_dir) md5_value = "0bc7306e7427e00ad1c5526a6677552d" try: download_and_extract(url, filepath, output_dir, md5_value) From 508c6a1efd274f476d9274eee9a8a8e1c72b054a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Nov 2021 11:55:22 +0000 Subject: [PATCH 2/8] support of pathlike obj Signed-off-by: Wenqi Li --- monai/apps/datasets.py | 40 +++++++++++++++------------- monai/apps/mmars/mmars.py | 29 +++++++++++--------- monai/data/csv_saver.py | 18 ++++++++----- monai/data/decathlon_datalist.py | 18 +++++++------ monai/data/image_reader.py | 41 +++++++++++++++-------------- monai/data/utils.py | 13 ++++----- monai/handlers/utils.py | 4 +-- monai/transforms/post/dictionary.py | 5 ++-- monai/utils/state_cacher.py | 4 ++- tests/test_decathlondataset.py | 3 ++- tests/test_file_basename.py | 5 ++++ tests/test_mednistdataset.py | 3 ++- tests/test_save_classificationd.py | 5 ++-- tests/test_state_cacher.py | 4 ++- tests/test_write_metrics_reports.py | 3 ++- 15 files changed, 111 insertions(+), 84 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index a36bcf09b0..f1f75d251c 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import sys +from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Union import numpy as np from monai.apps.utils import download_and_extract +from monai.config.type_definitions import PathLike from monai.data import ( CacheDataset, load_decathlon_datalist, @@ -64,7 +65,7 @@ class MedNISTDataset(Randomizable, CacheDataset): def __init__( self, - root_dir: str, + root_dir: PathLike, section: str, transform: Union[Sequence[Callable], Callable] = (), download: bool = False, @@ -75,19 +76,20 @@ def __init__( cache_rate: float = 1.0, num_workers: int = 0, ) -> None: - if not os.path.isdir(root_dir): + root_dir = Path(root_dir) + if not root_dir.is_dir(): raise ValueError("Root directory root_dir must be a directory.") self.section = section self.val_frac = val_frac self.test_frac = test_frac self.set_random_state(seed=seed) - tarfile_name = os.path.join(root_dir, self.compressed_file_name) - dataset_dir = os.path.join(root_dir, self.dataset_folder_name) + tarfile_name = root_dir / self.compressed_file_name + dataset_dir = root_dir / self.dataset_folder_name self.num_class = 0 if download: download_and_extract(self.resource, tarfile_name, root_dir, self.md5) - if not os.path.exists(dataset_dir): + if not dataset_dir.is_dir(): raise RuntimeError( f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it." ) @@ -105,19 +107,17 @@ def get_num_classes(self) -> int: """Get number of classes.""" return self.num_class - def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: """ Raises: ValueError: When ``section`` is not one of ["training", "validation", "test"]. """ - class_names = sorted(x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x))) + dataset_dir = Path(dataset_dir) + class_names = sorted(x for x in dataset_dir.iterdir() if (dataset_dir / x).is_dir()) self.num_class = len(class_names) image_files = [ - [ - os.path.join(dataset_dir, class_names[i], x) - for x in os.listdir(os.path.join(dataset_dir, class_names[i])) - ] + [dataset_dir.joinpath(class_names[i], x) for x in (dataset_dir / class_names[i]).iterdir()] for i in range(self.num_class) ] num_each = [len(image_files[i]) for i in range(self.num_class)] @@ -234,7 +234,7 @@ class DecathlonDataset(Randomizable, CacheDataset): def __init__( self, - root_dir: str, + root_dir: PathLike, task: str, section: str, transform: Union[Sequence[Callable], Callable] = (), @@ -245,19 +245,20 @@ def __init__( cache_rate: float = 1.0, num_workers: int = 0, ) -> None: - if not os.path.isdir(root_dir): + root_dir = Path(root_dir) + if not root_dir.is_dir(): raise ValueError("Root directory root_dir must be a directory.") self.section = section self.val_frac = val_frac self.set_random_state(seed=seed) if task not in self.resource: raise ValueError(f"Unsupported task: {task}, available options are: {list(self.resource.keys())}.") - dataset_dir = os.path.join(root_dir, task) + dataset_dir = root_dir / task tarfile_name = f"{dataset_dir}.tar" if download: download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task]) - if not os.path.exists(dataset_dir): + if not dataset_dir.exists(): raise RuntimeError( f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it." ) @@ -275,7 +276,7 @@ def __init__( "numTraining", "numTest", ] - self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys) + self._properties = load_decathlon_properties(dataset_dir / "dataset.json", property_keys) if transform == (): transform = LoadImaged(["image", "label"]) CacheDataset.__init__( @@ -304,9 +305,10 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): return {key: self._properties[key] for key in ensure_tuple(keys)} return {} - def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: + dataset_dir = Path(dataset_dir) section = "training" if self.section in ["training", "validation"] else "test" - datalist = load_decathlon_datalist(os.path.join(dataset_dir, "dataset.json"), True, section) + datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section) return self._split_datalist(datalist) def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 7923508376..024581838a 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -17,14 +17,15 @@ """ import json -import os import warnings -from typing import Mapping, Union +from pathlib import Path +from typing import Mapping, Optional, Union import torch import monai.networks.nets as monai_nets from monai.apps.utils import download_and_extract, logger +from monai.config.type_definitions import PathLike from monai.utils.module import optional_import from .model_desc import MODEL_DESC @@ -98,7 +99,9 @@ def _get_ngc_doc_url(model_name: str, model_prefix=""): return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}" -def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, version: int = -1): +def download_mmar( + item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = False, version: int = -1 +) -> Path: """ Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train. @@ -128,10 +131,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, if not mmar_dir: get_dir, has_home = optional_import("torch.hub", name="get_dir") if has_home: - mmar_dir = os.path.join(get_dir(), "mmars") + mmar_dir = Path(get_dir()) / "mmars" else: raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") - + mmar_dir = Path(mmar_dir) if api: model_dict = _get_all_ngc_models(item) if len(model_dict) == 0: @@ -140,10 +143,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, for k, v in model_dict.items(): ver = v["latest"] if version == -1 else str(version) download_url = _get_ngc_url(k, ver) - model_dir = os.path.join(mmar_dir, v["name"]) + model_dir = mmar_dir / v["name"] download_and_extract( url=download_url, - filepath=os.path.join(mmar_dir, f'{v["name"]}_{ver}.zip'), + filepath=mmar_dir / f'{v["name"]}_{ver}.zip', output_dir=model_dir, hash_val=None, hash_type="md5", @@ -161,11 +164,11 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, if version > 0: ver = str(version) model_fullname = f"{item[Keys.NAME]}_{ver}" - model_dir = os.path.join(mmar_dir, model_fullname) + model_dir = mmar_dir / model_fullname model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/") download_and_extract( url=model_url, - filepath=os.path.join(mmar_dir, f"{model_fullname}.{item[Keys.FILE_TYPE]}"), + filepath=mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}", output_dir=model_dir, hash_val=item[Keys.HASH_VAL], hash_type=item[Keys.HASH_TYPE], @@ -178,7 +181,7 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, def load_from_mmar( item, - mmar_dir=None, + mmar_dir: Optional[PathLike] = None, progress: bool = True, version: int = -1, map_location=None, @@ -212,11 +215,11 @@ def load_from_mmar( if not isinstance(item, Mapping): item = get_model_spec(item) model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version) - model_file = os.path.join(model_dir, item[Keys.MODEL_FILE]) + model_file = model_dir / item[Keys.MODEL_FILE] logger.info(f'\n*** "{item[Keys.ID]}" available at {model_dir}.') # loading with `torch.jit.load` - if f"{model_file}".endswith(".ts"): + if model_file.name.endswith(".ts"): if not pretrained: warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.") if weights_only: @@ -232,7 +235,7 @@ def load_from_mmar( model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={}) if not model_config: # 2. search json CONFIG_FILE for model config spec. - json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json")) + json_path = model_dir / item.get(Keys.CONFIG_FILE, "config_train.json") with open(json_path) as f: conf_dict = json.load(f) conf_dict = dict(conf_dict) diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index f9c814679d..d4005b528e 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -12,11 +12,13 @@ import os import warnings from collections import OrderedDict +from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch +from monai.config.type_definitions import PathLike from monai.utils import ImageMetaKey as Key @@ -32,7 +34,11 @@ class CSVSaver: """ def __init__( - self, output_dir: str = "./", filename: str = "predictions.csv", overwrite: bool = True, flush: bool = False + self, + output_dir: PathLike = "./", + filename: str = "predictions.csv", + overwrite: bool = True, + flush: bool = False, ) -> None: """ Args: @@ -44,12 +50,12 @@ def __init__( default to False. """ - self.output_dir = output_dir + self.output_dir = Path(output_dir) self._cache_dict: OrderedDict = OrderedDict() if not (isinstance(filename, str) and filename[-4:] == ".csv"): warnings.warn("CSV filename is not a string ends with '.csv'.") - self._filepath = os.path.join(output_dir, filename) - if os.path.exists(self._filepath) and overwrite: + self._filepath = self.output_dir / filename + if self._filepath.exists() and overwrite: os.remove(self._filepath) self.flush = flush @@ -60,8 +66,8 @@ def finalize(self) -> None: Writes the cached dict to a csv """ - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) + if not self.output_dir.exists(): + self.output_dir.mkdir(parents=True, exist_ok=True) with open(self._filepath, "a") as f: for k, v in self._cache_dict.items(): f.write(k) diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index dd0a8003be..39312805a7 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -15,7 +15,7 @@ from pathlib import Path from typing import Dict, List, Optional, Sequence, Union, overload -from monai.config import KeysCollection +from monai.config import KeysCollection, PathLike from monai.data.utils import partition_dataset, select_cross_validation_folds from monai.utils import ensure_tuple @@ -84,7 +84,7 @@ def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> Li def load_decathlon_datalist( - data_list_file_path: str, + data_list_file_path: PathLike, is_segmentation: bool = True, data_list_key: str = "training", base_dir: Optional[str] = None, @@ -114,7 +114,8 @@ def load_decathlon_datalist( ] """ - if not os.path.isfile(data_list_file_path): + data_list_file_path = Path(data_list_file_path) + if not data_list_file_path.is_file(): raise ValueError(f"Data list file {data_list_file_path} does not exist.") with open(data_list_file_path) as json_file: json_data = json.load(json_file) @@ -125,12 +126,12 @@ def load_decathlon_datalist( expected_data = [{"image": i} for i in expected_data] if base_dir is None: - base_dir = os.path.dirname(data_list_file_path) + base_dir = data_list_file_path.parent return _append_paths(base_dir, is_segmentation, expected_data) -def load_decathlon_properties(data_property_file_path: str, property_keys: Union[Sequence[str], str]) -> Dict: +def load_decathlon_properties(data_property_file_path: PathLike, property_keys: Union[Sequence[str], str]) -> Dict: """Load the properties from the JSON file contains data property with specified `property_keys`. Args: @@ -141,7 +142,8 @@ def load_decathlon_properties(data_property_file_path: str, property_keys: Union `modality`, `labels`, `numTraining`, `numTest`, etc. """ - if not os.path.isfile(data_property_file_path): + data_property_file_path = Path(data_property_file_path) + if not data_property_file_path.is_file(): raise ValueError(f"Data property file {data_property_file_path} does not exist.") with open(data_property_file_path) as json_file: json_data = json.load(json_file) @@ -187,8 +189,8 @@ def check_missing_files( if not isinstance(f, (str, Path)): raise ValueError(f"filepath of key `{k}` must be a string or a list of strings, but got: {f}.") if isinstance(root_dir, (str, Path)): - f = os.path.join(root_dir, f) - if not os.path.exists(f): + f = Path(root_dir).joinpath(f) + if not f.exists(): missing_files.append(f) return missing_files diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 4830b56aa8..4108421ae9 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import warnings from abc import ABC, abstractmethod +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.config import DtypeLike, KeysCollection +from monai.config import DtypeLike, KeysCollection, PathLike from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg @@ -64,7 +64,7 @@ class ImageReader(ABC): """ @abstractmethod - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified `filename` is supported by the current reader. This method should return True if the reader is able to read the format suggested by the @@ -78,7 +78,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any], Any]: + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs) -> Union[Sequence[Any], Any]: """ Read image data from specified file or files. Note that it returns a data object or a sequence of data objects. @@ -166,7 +166,7 @@ def __init__(self, channel_dim: Optional[int] = None, series_name: str = "", **k self.channel_dim = channel_dim self.series_name = series_name - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by ITK reader. @@ -177,7 +177,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ return has_itk - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` images and stack them together as multi-channels data in `get_data()`. @@ -193,11 +193,12 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_ = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - if os.path.isdir(name): + name = Path(name) + if Path(name).is_dir(): # read DICOM series # https://itk.org/ITKExamples/src/IO/GDCM/ReadDICOMSeriesAndWrite3DImage names_generator = itk.GDCMSeriesFileNames.New() @@ -348,7 +349,7 @@ def __init__( self.dtype = dtype self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by Nibabel reader. @@ -360,7 +361,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["nii", "nii.gz"] return has_nib and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` images and stack them together as multi-channels data in `get_data()`. @@ -375,7 +376,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_: List[Nifti1Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -503,7 +504,7 @@ def __init__(self, npz_keys: Optional[KeysCollection] = None, channel_dim: Optio self.channel_dim = channel_dim self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by Numpy reader. @@ -514,7 +515,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["npz", "npy"] return is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` data files and stack them together as multi-channels data in `get_data()`. @@ -529,7 +530,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): """ img_: List[Nifti1Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -593,7 +594,7 @@ def __init__(self, converter: Optional[Callable] = None, **kwargs): self.converter = converter self.kwargs = kwargs - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by PIL reader. @@ -604,7 +605,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] return has_pil and is_supported_format(filename, suffixes) - def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ Read image data from specified file or files, it can read a list of `no-channel` images and stack them together as multi-channels data in `get_data()`. @@ -619,7 +620,7 @@ def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ img_: List[PILImage.Image] = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -708,7 +709,7 @@ def _set_reader(backend: str): return TiffFile raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.") - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -718,7 +719,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ return is_supported_format(filename, ["tif", "tiff"]) - def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): + def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs): """ Read image data from given file or list of files. @@ -731,7 +732,7 @@ def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ img_: List = [] - filenames: Sequence[str] = ensure_tuple(data) + filenames: Sequence[PathLike] = ensure_tuple(data) for name in filenames: img = self.wsi_reader(name) if self.backend == "openslide": diff --git a/monai/data/utils.py b/monai/data/utils.py index 39f2aeb137..2e5efce7cd 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -19,13 +19,14 @@ from copy import deepcopy from functools import reduce from itertools import product, starmap -from pathlib import Path, PurePath +from pathlib import PurePath from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from torch.utils.data._utils.collate import default_collate +from monai.config.type_definitions import PathLike from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -678,9 +679,9 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: def create_file_basename( postfix: str, - input_file_name: str, - folder_path: Union[Path, str], - data_root_dir: str = "", + input_file_name: PathLike, + folder_path: PathLike, + data_root_dir: PathLike = "", separate_folder: bool = True, patch_index: Optional[int] = None, ) -> str: @@ -795,7 +796,7 @@ def compute_importance_map( return importance_map -def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[str]) -> bool: +def is_supported_format(filename: Union[Sequence[PathLike], PathLike], suffixes: Sequence[str]) -> bool: """ Verify whether the specified file or files format match supported suffixes. If supported suffixes is None, skip the verification and return True. @@ -806,7 +807,7 @@ def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[ suffixes: all the supported image suffixes of current reader, must be a list of lower case suffixes. """ - filenames: Sequence[str] = ensure_tuple(filename) + filenames: Sequence[PathLike] = ensure_tuple(filename) for name in filenames: tokens: Sequence[str] = PurePath(name).suffixes if len(tokens) == 0 or all("." + s.lower() not in "".join(tokens) for s in suffixes): diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 3567dbac03..06766c3e14 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -16,7 +16,7 @@ import numpy as np import torch -from monai.config import IgniteInfo, KeysCollection +from monai.config import IgniteInfo, KeysCollection, PathLike from monai.utils import ensure_tuple, look_up_option, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") @@ -51,7 +51,7 @@ def stopping_fn(engine: Engine): def write_metrics_reports( - save_dir: str, + save_dir: PathLike, images: Optional[Sequence[str]], metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 9644dc4f32..5201d6a623 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -21,8 +21,7 @@ import torch -from monai.config import KeysCollection -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike from monai.data.csv_saver import CSVSaver from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.array import ( @@ -644,7 +643,7 @@ def __init__( meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", saver: Optional[CSVSaver] = None, - output_dir: str = "./", + output_dir: PathLike = "./", filename: str = "predictions.csv", overwrite: bool = True, flush: bool = True, diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index 35ac72916e..ce394372f7 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -16,6 +16,8 @@ import torch +from monai.config.type_definitions import PathLike + __all__ = ["StateCacher"] @@ -34,7 +36,7 @@ class StateCacher: >>> model.load_state_dict(state_cacher.retrieve("model")) """ - def __init__(self, in_memory: bool, cache_dir: Optional[str] = None, allow_overwrite: bool = True) -> None: + def __init__(self, in_memory: bool, cache_dir: Optional[PathLike] = None, allow_overwrite: bool = True) -> None: """Constructor. Args: diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index db07d361db..0cb0840736 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -13,6 +13,7 @@ import shutil import unittest from urllib.error import ContentTooShortError, HTTPError +from pathlib import Path from monai.apps import DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord @@ -68,7 +69,7 @@ def _test_dataset(dataset): self.assertEqual(len(data), 208) # test dataset properties - data = DecathlonDataset(root_dir=testing_dir, task="Task04_Hippocampus", section="validation", download=False) + data = DecathlonDataset(root_dir=Path(testing_dir), task="Task04_Hippocampus", section="validation", download=False) properties = data.get_properties(keys="labels") self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"}) diff --git a/tests/test_file_basename.py b/tests/test_file_basename.py index 77e77fabc5..cd1a08afb1 100644 --- a/tests/test_file_basename.py +++ b/tests/test_file_basename.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from monai.data.utils import create_file_basename @@ -69,6 +70,10 @@ def test_value(self): expected = os.path.join(output_tmp, "test", "test_post_8") self.assertEqual(result, expected) + result = create_file_basename("post", Path("test.tar.gz"), Path(output_tmp), Path("foo"), True, 8) + expected = os.path.join(output_tmp, "test", "test_post_8") + self.assertEqual(result, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 2e27f4ba95..9ba324e31d 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -13,6 +13,7 @@ import shutil import unittest from urllib.error import ContentTooShortError, HTTPError +from pathlib import Path from monai.apps import MedNISTDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord @@ -53,7 +54,7 @@ def _test_dataset(dataset): _test_dataset(data) # testing from - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) + data = MedNISTDataset(root_dir=Path(testing_dir), transform=transform, section="test", download=False) data.get_num_classes() _test_dataset(data) data = MedNISTDataset(root_dir=testing_dir, section="test", download=False) diff --git a/tests/test_save_classificationd.py b/tests/test_save_classificationd.py index 26ce3176e8..1418bf22ee 100644 --- a/tests/test_save_classificationd.py +++ b/tests/test_save_classificationd.py @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path import numpy as np import torch @@ -39,7 +40,7 @@ def test_saved_content(self): }, ] - saver = CSVSaver(output_dir=tempdir, filename="predictions2.csv", overwrite=False, flush=False) + saver = CSVSaver(output_dir=Path(tempdir), filename="predictions2.csv", overwrite=False, flush=False) # set up test transforms post_trans = Compose( [ @@ -49,7 +50,7 @@ def test_saved_content(self): keys="pred", saver=None, meta_keys=None, - output_dir=tempdir, + output_dir=Path(tempdir), filename="predictions1.csv", overwrite=True, ), diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py index 5835bfdb5c..a1fbf171c8 100644 --- a/tests/test_state_cacher.py +++ b/tests/test_state_cacher.py @@ -11,6 +11,7 @@ import unittest from os.path import exists, join +from pathlib import Path from tempfile import gettempdir import torch @@ -23,8 +24,9 @@ TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}] TEST_CASE_1 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": gettempdir()}] TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}] +TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2] +TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] class TestStateCacher(unittest.TestCase): diff --git a/tests/test_write_metrics_reports.py b/tests/test_write_metrics_reports.py index f736db9961..3c93f20144 100644 --- a/tests/test_write_metrics_reports.py +++ b/tests/test_write_metrics_reports.py @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path import torch @@ -23,7 +24,7 @@ class TestWriteMetricsReports(unittest.TestCase): def test_content(self): with tempfile.TemporaryDirectory() as tempdir: write_metrics_reports( - save_dir=tempdir, + save_dir=Path(tempdir), images=["filepath1", "filepath2"], metrics={"metric1": 1, "metric2": 2}, metric_details={"metric3": torch.tensor([[1, 2], [2, 3]]), "metric4": torch.tensor([[5, 6], [7, 8]])}, From f2658c38acaa9496056537d4bd6d4cc4a0ec2115 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Nov 2021 12:04:43 +0000 Subject: [PATCH 3/8] review path obj Signed-off-by: Wenqi Li --- monai/apps/mmars/mmars.py | 2 +- monai/data/decathlon_datalist.py | 14 +++++++------- monai/data/image_reader.py | 2 +- tests/test_decathlondataset.py | 6 ++++-- tests/test_load_decathlon_datalist.py | 3 ++- tests/test_mednistdataset.py | 2 +- tests/test_mmar_download.py | 3 ++- 7 files changed, 18 insertions(+), 14 deletions(-) diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 024581838a..6356e9a027 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -101,7 +101,7 @@ def _get_ngc_doc_url(model_name: str, model_prefix=""): def download_mmar( item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = False, version: int = -1 -) -> Path: +): """ Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train. diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 39312805a7..14b5f4f0f6 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -21,12 +21,12 @@ @overload -def _compute_path(base_dir: str, element: str, check_path: bool = False) -> str: +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: str, element: List[str], check_path: bool = False) -> List[str]: +def _compute_path(base_dir: PathLike, element: List[PathLike], check_path: bool = False) -> List[str]: ... @@ -43,24 +43,24 @@ def _compute_path(base_dir, element, check_path=False): """ - def _join_path(base_dir: str, item: str): + def _join_path(base_dir: PathLike, item: PathLike): result = os.path.normpath(os.path.join(base_dir, item)) if check_path and not os.path.exists(result): # if not an existing path, don't join with base dir return item return result - if isinstance(element, str): + if isinstance(element, (str, os.PathLike)): return _join_path(base_dir, element) if isinstance(element, list): for e in element: - if not isinstance(e, str): + if not isinstance(e, (str, os.PathLike)): return element return [_join_path(base_dir, e) for e in element] return element -def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> List[Dict]: +def _append_paths(base_dir: PathLike, is_segmentation: bool, items: List[Dict]) -> List[Dict]: """ Args: base_dir: the base directory of the dataset. @@ -87,7 +87,7 @@ def load_decathlon_datalist( data_list_file_path: PathLike, is_segmentation: bool = True, data_list_key: str = "training", - base_dir: Optional[str] = None, + base_dir: Optional[PathLike] = None, ) -> List[Dict]: """Load image/label paths of decathlon challenge from JSON file diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 4108421ae9..9125cd0a94 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -535,7 +535,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): kwargs_.update(kwargs) for name in filenames: img = np.load(name, allow_pickle=True, **kwargs_) - if name.endswith(".npz"): + if Path(name).name.endswith(".npz"): # load expected items from NPZ file npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys for k in npz_keys: diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index 0cb0840736..29ea3a3151 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -12,8 +12,8 @@ import os import shutil import unittest -from urllib.error import ContentTooShortError, HTTPError from pathlib import Path +from urllib.error import ContentTooShortError, HTTPError from monai.apps import DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord @@ -69,7 +69,9 @@ def _test_dataset(dataset): self.assertEqual(len(data), 208) # test dataset properties - data = DecathlonDataset(root_dir=Path(testing_dir), task="Task04_Hippocampus", section="validation", download=False) + data = DecathlonDataset( + root_dir=Path(testing_dir), task="Task04_Hippocampus", section="validation", download=False + ) properties = data.get_properties(keys="labels") self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"}) diff --git a/tests/test_load_decathlon_datalist.py b/tests/test_load_decathlon_datalist.py index fe7ff6f8a2..d2113fccfa 100644 --- a/tests/test_load_decathlon_datalist.py +++ b/tests/test_load_decathlon_datalist.py @@ -13,6 +13,7 @@ import os import tempfile import unittest +from pathlib import Path from monai.data import load_decathlon_datalist @@ -115,7 +116,7 @@ def test_additional_items(self): file_path = os.path.join(tempdir, "test_data.json") with open(file_path, "w") as json_file: json_file.write(json_str) - result = load_decathlon_datalist(file_path, True, "training", tempdir) + result = load_decathlon_datalist(file_path, True, "training", Path(tempdir)) self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) self.assertEqual(result[1]["mask"], os.path.join(tempdir, "mask31.txt")) diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 9ba324e31d..f8d01902a5 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -12,8 +12,8 @@ import os import shutil import unittest -from urllib.error import ContentTooShortError, HTTPError from pathlib import Path +from urllib.error import ContentTooShortError, HTTPError from monai.apps import MedNISTDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 49f0b77269..bc098d74c4 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from pathlib import Path from urllib.error import ContentTooShortError, HTTPError import numpy as np @@ -114,7 +115,7 @@ def test_download(self, idx): download_mmar(idx, progress=False) # repeated to check caching with tempfile.TemporaryDirectory() as tmp_dir: download_mmar(idx, mmar_dir=tmp_dir, progress=False) - download_mmar(idx, mmar_dir=tmp_dir, progress=False, version=1) # repeated to check caching + download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) From c1f225e43eccf082fe6ea424b72a2874b17834b7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Nov 2021 12:47:01 +0000 Subject: [PATCH 4/8] update tests Signed-off-by: Wenqi Li --- monai/data/decathlon_datalist.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 14b5f4f0f6..32b2b12994 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -159,7 +159,7 @@ def load_decathlon_properties(data_property_file_path: PathLike, property_keys: def check_missing_files( datalist: List[Dict], keys: KeysCollection, - root_dir: Optional[Union[Path, str]] = None, + root_dir: Optional[PathLike] = None, allow_missing_keys: bool = False, ): """Checks whether some files in the Decathlon datalist are missing. @@ -167,7 +167,7 @@ def check_missing_files( Args: datalist: a list of data items, every item is a dictionary. - ususally generated by `load_decathlon_datalist` API. + usually generated by `load_decathlon_datalist` API. keys: expected keys to check in the datalist. root_dir: if not None, provides the root dir for the relative file paths in `datalist`. allow_missing_keys: whether allow missing keys in the datalist items. @@ -188,6 +188,7 @@ def check_missing_files( for f in ensure_tuple(item[k]): if not isinstance(f, (str, Path)): raise ValueError(f"filepath of key `{k}` must be a string or a list of strings, but got: {f}.") + f = Path(f) if isinstance(root_dir, (str, Path)): f = Path(root_dir).joinpath(f) if not f.exists(): @@ -230,7 +231,7 @@ def create_cross_validation_datalist( root_dir: if not None, provides the root dir for the relative file paths in `datalist`. allow_missing_keys: if check_missing_files is `True`, whether allow missing keys in the datalist items. if False, raise exception if missing. default to False. - raise_error: when found missing files, if `True`, raise exception and stop, if `False`, print warining. + raise_error: when found missing files, if `True`, raise exception and stop, if `False`, print warning. """ if check_missing and keys is not None: From 4f8a214de46ab943c84d8179225df14f075d3184 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Nov 2021 13:23:41 +0000 Subject: [PATCH 5/8] autofix Signed-off-by: Wenqi Li --- monai/apps/utils.py | 1 - monai/data/decathlon_datalist.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 76d0c75625..c2873959bb 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -24,7 +24,6 @@ from urllib.request import urlretrieve from monai.config.type_definitions import PathLike -from monai.utils import min_version, optional_import from monai.utils import look_up_option, min_version, optional_import gdown, has_gdown = optional_import("gdown", "3.6") diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 32b2b12994..3e6c830a4f 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -157,10 +157,7 @@ def load_decathlon_properties(data_property_file_path: PathLike, property_keys: def check_missing_files( - datalist: List[Dict], - keys: KeysCollection, - root_dir: Optional[PathLike] = None, - allow_missing_keys: bool = False, + datalist: List[Dict], keys: KeysCollection, root_dir: Optional[PathLike] = None, allow_missing_keys: bool = False ): """Checks whether some files in the Decathlon datalist are missing. It would be helpful to check missing files before a heavy training run. From 4ea6af3eed9c7ac85aacbcc41f1880677e89f437 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Nov 2021 13:28:29 +0000 Subject: [PATCH 6/8] fixes unit test Signed-off-by: Wenqi Li --- monai/apps/datasets.py | 4 ++-- monai/data/decathlon_datalist.py | 8 ++++---- monai/data/image_reader.py | 2 +- monai/data/nifti_saver.py | 7 +++---- monai/data/png_saver.py | 6 +++--- monai/transforms/io/array.py | 12 ++++++------ tests/test_check_missing_files.py | 2 +- 7 files changed, 20 insertions(+), 21 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index f1f75d251c..36f38a970f 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -114,10 +114,10 @@ def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: """ dataset_dir = Path(dataset_dir) - class_names = sorted(x for x in dataset_dir.iterdir() if (dataset_dir / x).is_dir()) + class_names = sorted(f"{x}" for x in dataset_dir.iterdir() if (dataset_dir / x).is_dir()) self.num_class = len(class_names) image_files = [ - [dataset_dir.joinpath(class_names[i], x) for x in (dataset_dir / class_names[i]).iterdir()] + [f"{dataset_dir.joinpath(class_names[i], x)}" for x in (dataset_dir / class_names[i]).iterdir()] for i in range(self.num_class) ] num_each = [len(image_files[i]) for i in range(self.num_class)] diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 3e6c830a4f..e9a9451103 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -47,8 +47,8 @@ def _join_path(base_dir: PathLike, item: PathLike): result = os.path.normpath(os.path.join(base_dir, item)) if check_path and not os.path.exists(result): # if not an existing path, don't join with base dir - return item - return result + return f"{item}" + return f"{result}" if isinstance(element, (str, os.PathLike)): return _join_path(base_dir, element) @@ -183,10 +183,10 @@ def check_missing_files( continue for f in ensure_tuple(item[k]): - if not isinstance(f, (str, Path)): + if not isinstance(f, (str, os.PathLike)): raise ValueError(f"filepath of key `{k}` must be a string or a list of strings, but got: {f}.") f = Path(f) - if isinstance(root_dir, (str, Path)): + if isinstance(root_dir, (str, os.PathLike)): f = Path(root_dir).joinpath(f) if not f.exists(): missing_files.append(f) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9125cd0a94..6e9cbca809 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -197,7 +197,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs): kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - name = Path(name) + name = f"{name}" if Path(name).is_dir(): # read DICOM series # https://itk.org/ITKExamples/src/IO/GDCM/ReadDICOMSeriesAndWrite3DImage diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 427b2d29d5..75805479d7 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -9,13 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch -from monai.config import DtypeLike +from monai.config import DtypeLike, PathLike from monai.data.nifti_writer import write_nifti from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode @@ -37,7 +36,7 @@ class NiftiSaver: def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "seg", output_ext: str = ".nii.gz", resample: bool = True, @@ -47,7 +46,7 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 609cc8d7be..5154ac1ab4 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path from typing import Dict, Optional, Union import numpy as np import torch +from monai.config.type_definitions import PathLike from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key @@ -34,13 +34,13 @@ class PNGSaver: def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "seg", output_ext: str = ".png", resample: bool = True, mode: Union[InterpolateMode, str] = InterpolateMode.NEAREST, scale: Optional[int] = None, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index b0f4c02cf6..170482f504 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -23,7 +23,7 @@ import numpy as np import torch -from monai.config import DtypeLike +from monai.config import DtypeLike, PathLike from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.data.nifti_saver import NiftiSaver from monai.data.png_saver import PNGSaver @@ -167,7 +167,7 @@ def register(self, reader: ImageReader): warnings.warn(f"Preferably the reader should inherit ImageReader, but got {type(reader)}.") self.readers.append(reader) - def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], reader: Optional[ImageReader] = None): + def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Optional[ImageReader] = None): """ Load image file and meta data from the given filename(s). If `reader` is not specified, this class automatically chooses readers based on the @@ -183,7 +183,7 @@ def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], re reader: runtime reader to load image file and meta data. """ - filename = tuple(str(s) for s in ensure_tuple(filename)) # allow Path objects + filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename)) # allow Path objects img = None if reader is not None: img = reader.read(filename) # runtime specified reader @@ -216,7 +216,7 @@ def __call__(self, filename: Union[Sequence[str], str, Path, Sequence[Path]], re if self.image_only: return img_array - meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0] + meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader # make sure all elements in metadata are little endian meta_data = switch_endianness(meta_data, "<") @@ -292,7 +292,7 @@ class SaveImage(Transform): def __init__( self, - output_dir: Union[Path, str] = "./", + output_dir: PathLike = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", resample: bool = True, @@ -302,7 +302,7 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, - data_root_dir: str = "", + data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, ) -> None: diff --git a/tests/test_check_missing_files.py b/tests/test_check_missing_files.py index 759ec28bac..1134409a66 100644 --- a/tests/test_check_missing_files.py +++ b/tests/test_check_missing_files.py @@ -49,7 +49,7 @@ def test_content(self): missings = check_missing_files( datalist=datalist, keys=["image", "label", "test"], root_dir=tempdir, allow_missing_keys=True ) - self.assertEqual(missings[0], os.path.join(tempdir, "test_label_missing.nii.gz")) + self.assertEqual(f"{missings[0]}", os.path.join(tempdir, "test_label_missing.nii.gz")) if __name__ == "__main__": From 30258056ddaba76b5da7dd19567d9d6b641355ed Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 16 Nov 2021 10:50:16 +0000 Subject: [PATCH 7/8] update based on comments Signed-off-by: Wenqi Li --- monai/apps/datasets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 36f38a970f..deb76c9699 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -146,6 +146,7 @@ def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' ) + # the types of label and class name should be compatible with the pytorch dataloader return [ {"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]} for i in section_indices @@ -306,6 +307,7 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None): return {} def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: + # the types of the item in data list should be compatible with the dataloader dataset_dir = Path(dataset_dir) section = "training" if self.section in ["training", "validation"] else "test" datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section) From 22db84e4e86c1e10bcb1fb967c2b9604e00dd829 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 16 Nov 2021 11:51:23 +0000 Subject: [PATCH 8/8] fixes dep issue Signed-off-by: Wenqi Li --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1d9d52bca5..f47eb14bbd 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -40,5 +40,5 @@ requests einops transformers mlflow -matplotlib +matplotlib!=3.5.0 tensorboardX