Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union

import numpy as np

from monai.apps.utils import download_and_extract
from monai.config.type_definitions import PathLike
from monai.data import (
CacheDataset,
load_decathlon_datalist,
Expand Down Expand Up @@ -64,7 +65,7 @@ class MedNISTDataset(Randomizable, CacheDataset):

def __init__(
self,
root_dir: str,
root_dir: PathLike,
section: str,
transform: Union[Sequence[Callable], Callable] = (),
download: bool = False,
Expand All @@ -75,19 +76,20 @@ def __init__(
cache_rate: float = 1.0,
num_workers: int = 0,
) -> None:
if not os.path.isdir(root_dir):
root_dir = Path(root_dir)
if not root_dir.is_dir():
raise ValueError("Root directory root_dir must be a directory.")
self.section = section
self.val_frac = val_frac
self.test_frac = test_frac
self.set_random_state(seed=seed)
tarfile_name = os.path.join(root_dir, self.compressed_file_name)
dataset_dir = os.path.join(root_dir, self.dataset_folder_name)
tarfile_name = root_dir / self.compressed_file_name
dataset_dir = root_dir / self.dataset_folder_name
self.num_class = 0
if download:
download_and_extract(self.resource, tarfile_name, root_dir, self.md5)

if not os.path.exists(dataset_dir):
if not dataset_dir.is_dir():
raise RuntimeError(
f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it."
)
Expand All @@ -105,19 +107,17 @@ def get_num_classes(self) -> int:
"""Get number of classes."""
return self.num_class

def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
"""
Raises:
ValueError: When ``section`` is not one of ["training", "validation", "test"].

"""
class_names = sorted(x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))
dataset_dir = Path(dataset_dir)
class_names = sorted(f"{x}" for x in dataset_dir.iterdir() if (dataset_dir / x).is_dir())
self.num_class = len(class_names)
image_files = [
[
os.path.join(dataset_dir, class_names[i], x)
for x in os.listdir(os.path.join(dataset_dir, class_names[i]))
]
[f"{dataset_dir.joinpath(class_names[i], x)}" for x in (dataset_dir / class_names[i]).iterdir()]
Comment thread
wyli marked this conversation as resolved.
for i in range(self.num_class)
]
num_each = [len(image_files[i]) for i in range(self.num_class)]
Expand Down Expand Up @@ -146,6 +146,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].'
)

# the types of label and class name should be compatible with the pytorch dataloader
return [
{"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]}
for i in section_indices
Expand Down Expand Up @@ -234,7 +235,7 @@ class DecathlonDataset(Randomizable, CacheDataset):

def __init__(
self,
root_dir: str,
root_dir: PathLike,
task: str,
section: str,
transform: Union[Sequence[Callable], Callable] = (),
Expand All @@ -245,19 +246,20 @@ def __init__(
cache_rate: float = 1.0,
num_workers: int = 0,
) -> None:
if not os.path.isdir(root_dir):
root_dir = Path(root_dir)
if not root_dir.is_dir():
raise ValueError("Root directory root_dir must be a directory.")
self.section = section
self.val_frac = val_frac
self.set_random_state(seed=seed)
if task not in self.resource:
raise ValueError(f"Unsupported task: {task}, available options are: {list(self.resource.keys())}.")
dataset_dir = os.path.join(root_dir, task)
dataset_dir = root_dir / task
tarfile_name = f"{dataset_dir}.tar"
if download:
download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task])

if not os.path.exists(dataset_dir):
if not dataset_dir.exists():
raise RuntimeError(
f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it."
)
Expand All @@ -275,7 +277,7 @@ def __init__(
"numTraining",
"numTest",
]
self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
self._properties = load_decathlon_properties(dataset_dir / "dataset.json", property_keys)
if transform == ():
transform = LoadImaged(["image", "label"])
CacheDataset.__init__(
Expand Down Expand Up @@ -304,9 +306,11 @@ def get_properties(self, keys: Optional[Union[Sequence[str], str]] = None):
return {key: self._properties[key] for key in ensure_tuple(keys)}
return {}

def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
# the types of the item in data list should be compatible with the dataloader
dataset_dir = Path(dataset_dir)
section = "training" if self.section in ["training", "validation"] else "test"
datalist = load_decathlon_datalist(os.path.join(dataset_dir, "dataset.json"), True, section)
datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section)
return self._split_datalist(datalist)

def _split_datalist(self, datalist: List[Dict]) -> List[Dict]:
Expand Down
29 changes: 16 additions & 13 deletions monai/apps/mmars/mmars.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
"""

import json
import os
import warnings
from typing import Mapping, Union
from pathlib import Path
from typing import Mapping, Optional, Union

import torch

import monai.networks.nets as monai_nets
from monai.apps.utils import download_and_extract, logger
from monai.config.type_definitions import PathLike
from monai.utils.module import optional_import

from .model_desc import MODEL_DESC
Expand Down Expand Up @@ -98,7 +99,9 @@ def _get_ngc_doc_url(model_name: str, model_prefix=""):
return f"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}"


def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, version: int = -1):
def download_mmar(
item, mmar_dir: Optional[PathLike] = None, progress: bool = True, api: bool = False, version: int = -1
):
"""
Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train.

Expand Down Expand Up @@ -128,10 +131,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,
if not mmar_dir:
get_dir, has_home = optional_import("torch.hub", name="get_dir")
if has_home:
mmar_dir = os.path.join(get_dir(), "mmars")
mmar_dir = Path(get_dir()) / "mmars"
else:
raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?")

mmar_dir = Path(mmar_dir)
if api:
model_dict = _get_all_ngc_models(item)
if len(model_dict) == 0:
Expand All @@ -140,10 +143,10 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,
for k, v in model_dict.items():
ver = v["latest"] if version == -1 else str(version)
download_url = _get_ngc_url(k, ver)
model_dir = os.path.join(mmar_dir, v["name"])
model_dir = mmar_dir / v["name"]
download_and_extract(
url=download_url,
filepath=os.path.join(mmar_dir, f'{v["name"]}_{ver}.zip'),
filepath=mmar_dir / f'{v["name"]}_{ver}.zip',
output_dir=model_dir,
hash_val=None,
hash_type="md5",
Expand All @@ -161,11 +164,11 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,
if version > 0:
ver = str(version)
model_fullname = f"{item[Keys.NAME]}_{ver}"
model_dir = os.path.join(mmar_dir, model_fullname)
model_dir = mmar_dir / model_fullname
model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix="nvidia/med/")
download_and_extract(
url=model_url,
filepath=os.path.join(mmar_dir, f"{model_fullname}.{item[Keys.FILE_TYPE]}"),
filepath=mmar_dir / f"{model_fullname}.{item[Keys.FILE_TYPE]}",
output_dir=model_dir,
hash_val=item[Keys.HASH_VAL],
hash_type=item[Keys.HASH_TYPE],
Expand All @@ -178,7 +181,7 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False,

def load_from_mmar(
item,
mmar_dir=None,
mmar_dir: Optional[PathLike] = None,
progress: bool = True,
version: int = -1,
map_location=None,
Expand Down Expand Up @@ -212,11 +215,11 @@ def load_from_mmar(
if not isinstance(item, Mapping):
item = get_model_spec(item)
model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version)
model_file = os.path.join(model_dir, item[Keys.MODEL_FILE])
model_file = model_dir / item[Keys.MODEL_FILE]
logger.info(f'\n*** "{item[Keys.ID]}" available at {model_dir}.')

# loading with `torch.jit.load`
if f"{model_file}".endswith(".ts"):
if model_file.name.endswith(".ts"):
if not pretrained:
warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.")
if weights_only:
Expand All @@ -232,7 +235,7 @@ def load_from_mmar(
model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
if not model_config:
# 2. search json CONFIG_FILE for model config spec.
json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json"))
json_path = model_dir / item.get(Keys.CONFIG_FILE, "config_train.json")
with open(json_path) as f:
conf_dict = json.load(f)
conf_dict = dict(conf_dict)
Expand Down
46 changes: 25 additions & 21 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from urllib.error import ContentTooShortError, HTTPError, URLError
from urllib.request import urlretrieve

from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import

gdown, has_gdown = optional_import("gdown", "3.6")
Expand Down Expand Up @@ -70,10 +72,10 @@ def get_logger(
__all__.append("logger")


def _basename(p):
def _basename(p: PathLike) -> str:
"""get the last part of the path (removing the trailing slash if it exists)"""
sep = os.path.sep + (os.path.altsep or "") + "/ "
return os.path.basename(p.rstrip(sep))
return Path(f"{p}".rstrip(sep)).name


def _download_with_progress(url, filepath, progress: bool = True):
Expand Down Expand Up @@ -111,7 +113,7 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: Optional[int] = None):
raise e


def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool:
def check_hash(filepath: PathLike, val: Optional[str] = None, hash_type: str = "md5") -> bool:
"""
Verify hash signature of specified file.

Expand Down Expand Up @@ -144,7 +146,7 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5")


def download_url(
url: str, filepath: str = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True
url: str, filepath: PathLike = "", hash_val: Optional[str] = None, hash_type: str = "md5", progress: bool = True
) -> None:
"""
Download file from specified URL link, support process bar and hash check.
Expand All @@ -170,9 +172,10 @@ def download_url(

"""
if not filepath:
filepath = os.path.abspath(os.path.join(".", _basename(url)))
filepath = Path(".", _basename(url)).resolve()
logger.info(f"Default downloading to '{filepath}'")
if os.path.exists(filepath):
filepath = Path(filepath)
if filepath.exists():
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}."
Expand All @@ -181,21 +184,21 @@ def download_url(
return

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_name = os.path.join(tmp_dir, f"{_basename(filepath)}")
tmp_name = Path(tmp_dir, _basename(filepath))
if url.startswith("https://drive.google.com"):
if not has_gdown:
raise RuntimeError("To download files from Google Drive, please install the gdown dependency.")
gdown.download(url, tmp_name, quiet=not progress)
gdown.download(url, f"{tmp_name}", quiet=not progress)
else:
_download_with_progress(url, tmp_name, progress=progress)
if not os.path.exists(tmp_name):
if not tmp_name.exists():
raise RuntimeError(
f"Download of file from {url} to {filepath} failed due to network issue or denied permission."
)
file_dir = os.path.dirname(filepath)
file_dir = filepath.parent
if file_dir:
os.makedirs(file_dir, exist_ok=True)
shutil.move(tmp_name, filepath) # copy the downloaded to a user-specified cache.
shutil.move(f"{tmp_name}", f"{filepath}") # copy the downloaded to a user-specified cache.
logger.info(f"Downloaded: {filepath}")
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
Expand All @@ -205,8 +208,8 @@ def download_url(


def extractall(
filepath: str,
output_dir: str = ".",
filepath: PathLike,
output_dir: PathLike = ".",
hash_val: Optional[str] = None,
hash_type: str = "md5",
file_type: str = "",
Expand Down Expand Up @@ -235,24 +238,25 @@ def extractall(
"""
if has_base:
# the extracted files will be in this folder
cache_dir = os.path.join(output_dir, _basename(filepath).split(".")[0])
cache_dir = Path(output_dir, _basename(filepath).split(".")[0])
else:
cache_dir = output_dir
if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) > 0:
cache_dir = Path(output_dir)
if cache_dir.exists() and len(list(cache_dir.iterdir())) > 0:
logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.")
return
filepath = Path(filepath)
if hash_val and not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}."
)
logger.info(f"Writing into directory: {output_dir}.")
_file_type = file_type.lower().strip()
if filepath.endswith("zip") or _file_type == "zip":
if filepath.name.endswith("zip") or _file_type == "zip":
zip_file = zipfile.ZipFile(filepath)
zip_file.extractall(output_dir)
zip_file.close()
return
if filepath.endswith("tar") or filepath.endswith("tar.gz") or "tar" in _file_type:
if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type:
tar_file = tarfile.open(filepath)
tar_file.extractall(output_dir)
tar_file.close()
Expand All @@ -264,8 +268,8 @@ def extractall(

def download_and_extract(
url: str,
filepath: str = "",
output_dir: str = ".",
filepath: PathLike = "",
output_dir: PathLike = ".",
hash_val: Optional[str] = None,
hash_type: str = "md5",
file_type: str = "",
Expand All @@ -292,6 +296,6 @@ def download_and_extract(
progress: whether to display progress bar.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
filename = filepath or os.path.join(tmp_dir, f"{_basename(url)}")
filename = filepath or Path(tmp_dir, _basename(url)).resolve()
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
10 changes: 9 additions & 1 deletion monai/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@
print_gpu_info,
print_system_info,
)
from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayOrTensor, NdarrayTensor, TensorOrList
from .type_definitions import (
DtypeLike,
IndexSelection,
KeysCollection,
NdarrayOrTensor,
NdarrayTensor,
PathLike,
TensorOrList,
)
Loading