diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index f0416b8c4f..a193b7391e 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -11,7 +11,7 @@ import os import sys -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Callable, Dict, List, Optional, Sequence, Union import numpy as np @@ -98,8 +98,8 @@ def __init__( self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers ) - def randomize(self, data: Optional[Any] = None) -> None: - self.rann = self.R.random() + def randomize(self, data: List[int]) -> None: + self.R.shuffle(data) def get_num_classes(self) -> int: """Get number of classes.""" @@ -132,22 +132,26 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: data = [] - for i in range(num_total): - self.randomize() - if self.section == "training": - if self.rann < self.val_frac + self.test_frac: - continue - elif self.section == "validation": - if self.rann >= self.val_frac: - continue - elif self.section == "test": - if self.rann < self.val_frac or self.rann >= self.val_frac + self.test_frac: - continue - else: - raise ValueError( - f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' - ) + length = len(image_files_list) + indices = np.arange(length) + self.randomize(indices) + + test_length = int(length * self.test_frac) + val_length = int(length * self.val_frac) + if self.section == "test": + section_indices = indices[:test_length] + elif self.section == "validation": + section_indices = indices[test_length : test_length + val_length] + elif self.section == "training": + section_indices = indices[test_length + val_length :] + else: + raise ValueError( + f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' + ) + + for i in section_indices: data.append({"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]}) + return data diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 0887734a7c..2e27f4ba95 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -18,6 +18,8 @@ from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from tests.utils import skip_if_quick +MEDNIST_FULL_DATASET_LENGTH = 58954 + class TestMedNISTDataset(unittest.TestCase): @skip_if_quick @@ -33,7 +35,7 @@ def test_values(self): ) def _test_dataset(dataset): - self.assertEqual(len(dataset), 5986) + self.assertEqual(len(dataset), int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac)) self.assertTrue("image" in dataset[0]) self.assertTrue("label" in dataset[0]) self.assertTrue("image_meta_dict" in dataset[0]) @@ -56,6 +58,9 @@ def _test_dataset(dataset): _test_dataset(data) data = MedNISTDataset(root_dir=testing_dir, section="test", download=False) self.assertTupleEqual(data[0]["image"].shape, (64, 64)) + # test same dataset length with different random seed + data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False, seed=42) + _test_dataset(data) shutil.rmtree(os.path.join(testing_dir, "MedNIST")) try: data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False)