diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 1c4f4c3dfb..4c45a5fb39 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -62,3 +62,12 @@ Applications :members: .. autoclass:: Fetch2DSliced :members: + +`Pathology` +----------- + +.. automodule:: monai.apps.pathology.datasets +.. autoclass:: PatchWSIDataset + :members: +.. autoclass:: SmartCachePatchWSIDataset + :members: diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py new file mode 100644 index 0000000000..bbdb812c03 --- /dev/null +++ b/monai/apps/pathology/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .datasets import PatchWSIDataset, SmartCachePatchWSIDataset diff --git a/monai/apps/pathology/datasets.py b/monai/apps/pathology/datasets.py new file mode 100644 index 0000000000..f9ce0bc62b --- /dev/null +++ b/monai/apps/pathology/datasets.py @@ -0,0 +1,158 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import numpy as np + +from monai.data import Dataset, SmartCacheDataset +from monai.data.image_reader import WSIReader + +__all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset"] + + +class PatchWSIDataset(Dataset): + """ + This dataset reads whole slide images, extracts regions, and creates patches. + It also reads labels for each patch and provides each patch with its associated class labels. + + Args: + data: the list of input samples including image, location, and label (see below for more details). + region_size: the region to be extracted from the whole slide image. + grid_shape: the grid shape on which the patches should be extracted. + patch_size: the patches extracted from the region on the grid. + image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide. + Defaults to CuCIM. + transform: transforms to be executed on input data. + + Note: + The input data has the following form as an example: + `[{"image": "path/to/image1.tiff", "location": [200, 500], "label": [0,0,0,1]}]`. + + This means from "image1.tiff" extract a region centered at the given location `location` + with the size of `region_size`, and then extract patches with the size of `patch_size` + from a square grid with the shape of `grid_shape`. + Be aware that the `grid_shape` should construct a grid with the same number of elements as `label`, + so for this example the `grid_shape` should be (2, 2).
+ + """ + + def __init__( + self, + data: List, + region_size: Union[int, Tuple[int, int]], + grid_shape: Union[int, Tuple[int, int]], + patch_size: int, + image_reader_name: str = "cuCIM", + transform: Optional[Callable] = None, + ): + super().__init__(data, transform) + + if isinstance(region_size, int): + self.region_size = (region_size, region_size) + else: + self.region_size = region_size + + if isinstance(grid_shape, int): + self.grid_shape = (grid_shape, grid_shape) + else: + self.grid_shape = grid_shape + + self.patch_size = patch_size + self.sub_region_size = (self.region_size[0] / self.grid_shape[0], self.region_size[1] / self.grid_shape[1]) + + self.image_path_list = list({x["image"] for x in self.data}) + + self.image_reader_name = image_reader_name.lower() + self.image_reader = WSIReader(image_reader_name) + self.wsi_object_dict = None + if self.image_reader_name != "openslide": + # OpenSlide causes memory issue if we prefetch image objects (name is lower-cased above, so "OpenSlide" matches) + self._fetch_wsi_objects() + + def _fetch_wsi_objects(self): + """Load all the image objects and reuse them when asked for an item.""" + self.wsi_object_dict = {} + for image_path in self.image_path_list: + self.wsi_object_dict[image_path] = self.image_reader.read(image_path) + + def __getitem__(self, index): + sample = self.data[index] + if self.image_reader_name == "openslide": + img_obj = self.image_reader.read(sample["image"]) + else: + img_obj = self.wsi_object_dict[sample["image"]] + location = [sample["location"][i] - self.region_size[i] // 2 for i in range(len(self.region_size))] + images, _ = self.image_reader.get_data( + img=img_obj, + location=location, + size=self.region_size, + grid_shape=self.grid_shape, + patch_size=self.patch_size, + ) + labels = np.array(sample["label"], dtype=np.float32)[:, np.newaxis, np.newaxis] + patches = [{"image": images[i], "label": labels[i]} for i in range(len(sample["label"]))] + if self.transform: + patches = self.transform(patches) + return patches + + +class
SmartCachePatchWSIDataset(SmartCacheDataset): + """Add SmartCache functionality to `PatchWSIDataset`. + + Args: + data: the list of input samples including image, location, and label (see `PatchWSIDataset` for more details) + region_size: the region to be extracted from the whole slide image. + grid_shape: the grid shape on which the patches should be extracted. + patch_size: the patches extracted from the region on the grid. + image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide. + Defaults to CuCIM. + transform: transforms to be executed on input data. + replace_rate: percentage of the cached items to be replaced in every epoch. + cache_num: number of items to be cached. Default is `sys.maxsize`. + will take the minimum of (cache_num, data_length x cache_rate, data_length). + cache_rate: percentage of cached data in total, default is 1.0 (cache all). + will take the minimum of (cache_num, data_length x cache_rate, data_length). + num_init_workers: the number of worker threads to initialize the cache for first epoch. + If num_init_workers is None then the number returned by os.cpu_count() is used. + num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. + If num_replace_workers is None then the number returned by os.cpu_count() is used. + progress: whether to display a progress bar when caching for the first epoch. 
+ + """ + + def __init__( + self, + data: List, + region_size: Union[int, Tuple[int, int]], + grid_shape: Union[int, Tuple[int, int]], + patch_size: int, + transform: Union[Sequence[Callable], Callable], + image_reader_name: str = "cuCIM", + replace_rate: float = 0.5, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_init_workers: Optional[int] = None, + num_replace_workers: Optional[int] = None, + progress: bool = True, + ): + patch_wsi_dataset = PatchWSIDataset(data, region_size, grid_shape, patch_size, image_reader_name) + super().__init__( + data=patch_wsi_dataset, # type: ignore + transform=transform, + replace_rate=replace_rate, + cache_num=cache_num, + cache_rate=cache_rate, + num_init_workers=num_init_workers, + num_replace_workers=num_replace_workers, + progress=progress, + ) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 813008e3a8..9a4e932160 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -582,6 +582,21 @@ class SmartCacheDataset(Randomizable, CacheDataset): This replacement will not work if setting the `multiprocessing_context` of DataLoader to `spawn` or on windows(the default multiprocessing method is `spawn`) and setting `num_workers` greater than 0. + Args: + data: input data to load and transform to generate dataset for model. + transform: transforms to execute operations on input data. + replace_rate: percentage of the cached items to be replaced in every epoch. + cache_num: number of items to be cached. Default is `sys.maxsize`. + will take the minimum of (cache_num, data_length x cache_rate, data_length). + cache_rate: percentage of cached data in total, default is 1.0 (cache all). + will take the minimum of (cache_num, data_length x cache_rate, data_length). + num_init_workers: the number of worker threads to initialize the cache for first epoch. + If num_init_workers is None then the number returned by os.cpu_count() is used. 
+ num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. + If num_replace_workers is None then the number returned by os.cpu_count() is used. + progress: whether to display a progress bar when caching for the first epoch. + shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch. + seed: random seed if shuffle is `True`, default to `0`. """ def __init__( @@ -597,24 +612,6 @@ def __init__( shuffle: bool = True, seed: int = 0, ) -> None: - """ - Args: - data: input data to load and transform to generate dataset for model. - transform: transforms to execute operations on input data. - replace_rate: percentage of the cached items to be replaced in every epoch. - cache_num: number of items to be cached. Default is `sys.maxsize`. - will take the minimum of (cache_num, data_length x cache_rate, data_length). - cache_rate: percentage of cached data in total, default is 1.0 (cache all). - will take the minimum of (cache_num, data_length x cache_rate, data_length). - num_init_workers: the number of worker threads to initialize the cache for first epoch. - If num_init_workers is None then the number returned by os.cpu_count() is used. - num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. - If num_replace_workers is None then the number returned by os.cpu_count() is used. - progress: whether to display a progress bar when caching for the first epoch. - shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch. - seed: random seed if shuffle is `True`, default to `0`. 
- - """ if shuffle: self.set_random_state(seed=seed) self.randomize(data) diff --git a/tests/test_patch_wsi_dataset.py b/tests/test_patch_wsi_dataset.py new file mode 100644 index 0000000000..730519ed52 --- /dev/null +++ b/tests/test_patch_wsi_dataset.py @@ -0,0 +1,136 @@ +import os +import unittest +from unittest import skipUnless +from urllib import request + +import numpy as np +from numpy.testing import assert_array_equal +from parameterized import parameterized + +from monai.apps.pathology.datasets import PatchWSIDataset +from monai.utils import optional_import + +_, has_cim = optional_import("cucim") +_, has_osl = optional_import("openslide") + +FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" + +TEST_CASE_0 = [ + FILE_URL, + { + "data": [ + {"image": "./CMU-1.tiff", "location": [0, 0], "label": [1]}, + ], + "region_size": (1, 1), + "grid_shape": (1, 1), + "patch_size": 1, + "image_reader_name": "cuCIM", + }, + [ + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[1]])}, + ], +] + +TEST_CASE_1 = [ + FILE_URL, + { + "data": [{"image": "./CMU-1.tiff", "location": [10004, 20004], "label": [0, 0, 0, 1]}], + "region_size": (8, 8), + "grid_shape": (2, 2), + "patch_size": 1, + "image_reader_name": "cuCIM", + }, + [ + {"image": np.array([[[247]], [[245]], [[248]]], dtype=np.uint8), "label": np.array([[0]])}, + {"image": np.array([[[245]], [[247]], [[244]]], dtype=np.uint8), "label": np.array([[0]])}, + {"image": np.array([[[246]], [[246]], [[246]]], dtype=np.uint8), "label": np.array([[0]])}, + {"image": np.array([[[246]], [[246]], [[246]]], dtype=np.uint8), "label": np.array([[1]])}, + ], +] + +TEST_CASE_2 = [ + FILE_URL, + { + "data": [ + {"image": "./CMU-1.tiff", "location": [0, 0], "label": [1]}, + ], + "region_size": 1, + "grid_shape": 1, + "patch_size": 1, + "image_reader_name": "cuCIM", + }, + [ + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": 
np.array([[1]])}, + ], +] + + +TEST_CASE_OPENSLIDE_0 = [ + FILE_URL, + { + "data": [ + {"image": "./CMU-1.tiff", "location": [0, 0], "label": [1]}, + ], + "region_size": (1, 1), + "grid_shape": (1, 1), + "patch_size": 1, + "image_reader_name": "OpenSlide", + }, + [ + {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[1]])}, + ], +] + +TEST_CASE_OPENSLIDE_1 = [ + FILE_URL, + { + "data": [{"image": "./CMU-1.tiff", "location": [10004, 20004], "label": [0, 0, 0, 1]}], + "region_size": (8, 8), + "grid_shape": (2, 2), + "patch_size": 1, + "image_reader_name": "OpenSlide", + }, + [ + {"image": np.array([[[247]], [[245]], [[248]]], dtype=np.uint8), "label": np.array([[0]])}, + {"image": np.array([[[245]], [[247]], [[244]]], dtype=np.uint8), "label": np.array([[0]])}, + {"image": np.array([[[246]], [[246]], [[246]]], dtype=np.uint8), "label": np.array([[0]])}, + {"image": np.array([[[246]], [[246]], [[246]]], dtype=np.uint8), "label": np.array([[1]])}, + ], +] + + +class TestPatchWSIDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @skipUnless(has_cim, "Requires CuCIM") + def test_read_patches_cucim(self, file_url, input_parameters, expected): + self.camelyon_data_download(file_url) + dataset = PatchWSIDataset(**input_parameters) + samples = dataset[0] + for i in range(len(samples)): + self.assertTupleEqual(samples[i]["label"].shape, expected[i]["label"].shape) + self.assertTupleEqual(samples[i]["image"].shape, expected[i]["image"].shape) + self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"])) + self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"])) + + @parameterized.expand([TEST_CASE_OPENSLIDE_0, TEST_CASE_OPENSLIDE_1]) + @skipUnless(has_osl, "Requires OpenSlide") + def test_read_patches_openslide(self, file_url, input_parameters, expected): + self.camelyon_data_download(file_url) + dataset = PatchWSIDataset(**input_parameters) + samples = 
dataset[0] + for i in range(len(samples)): + self.assertTupleEqual(samples[i]["label"].shape, expected[i]["label"].shape) + self.assertTupleEqual(samples[i]["image"].shape, expected[i]["image"].shape) + self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"])) + self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"])) + + def camelyon_data_download(self, file_url): + filename = os.path.basename(file_url) + if not os.path.exists(filename): + print(f"Test image [{filename}] does not exist. Downloading...") + request.urlretrieve(file_url, filename) + return filename + + +if __name__ == "__main__": + unittest.main()