diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 047693ba55..d41b779a1e 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -20,7 +20,7 @@ from monai.config import DtypeLike, KeysCollection from monai.data.utils import correct_nifti_header_if_necessary from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import ensure_tuple, optional_import +from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import from .utils import is_supported_format @@ -253,7 +253,7 @@ def _get_meta_dict(self, img) -> Dict: meta_dict["direction"] = itk.array_from_matrix(img.GetDirection()) return meta_dict - def _get_affine(self, img) -> np.ndarray: + def _get_affine(self, img): """ Get or construct the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -274,7 +274,7 @@ def _get_affine(self, img) -> np.ndarray: affine[(slice(-1), -1)] = origin return affine - def _get_spatial_shape(self, img) -> np.ndarray: + def _get_spatial_shape(self, img): """ Get the spatial shape of image data, it doesn't contain the channel dim. @@ -406,7 +406,7 @@ def _get_meta_dict(self, img) -> Dict: """ return dict(img.header) - def _get_affine(self, img) -> np.ndarray: + def _get_affine(self, img): """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -417,7 +417,7 @@ def _get_affine(self, img) -> np.ndarray: """ return np.array(img.affine, copy=True) - def _get_spatial_shape(self, img) -> np.ndarray: + def _get_spatial_shape(self, img): """ Get the spatial shape of image data, it doesn't contain the channel dim. @@ -430,7 +430,7 @@ def _get_spatial_shape(self, img) -> np.ndarray: # the img data should have no channel dim or the last dim is channel return np.asarray(img.header["dim"][1 : spatial_rank + 1]) - def _get_array_data(self, img) -> np.ndarray: + def _get_array_data(self, img): """ Get the raw array data of the image, converted to Numpy array. @@ -623,7 +623,7 @@ def _get_meta_dict(self, img) -> Dict: "height": img.height, } - def _get_spatial_shape(self, img) -> np.ndarray: + def _get_spatial_shape(self, img): """ Get the spatial shape of image data, it doesn't contain the channel dim. Args: @@ -697,7 +697,7 @@ def get_data( level: int = 0, dtype: DtypeLike = np.uint8, grid_shape: Tuple[int, int] = (1, 1), - patch_size: Optional[int] = None, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, ): """ Extract regions as numpy array from WSI image and return them. @@ -711,15 +711,15 @@ def get_data( level: the level number, or list of level numbers (default=0) dtype: the data type of output image grid_shape: (row, columns) tuple define a grid to extract patches on that - patch_size: (heigsht, width) the size of extracted patches at the given level + patch_size: (height, width) the size of extracted patches at the given level """ - if size is None: - if location == (0, 0): - # the maximum size is set to WxH - size = (img.shape[0] // (2 ** level), img.shape[1] // (2 ** level)) - print(f"Reading the whole image at level={level} with shape={size}") - else: - raise ValueError("Size need to be provided to extract the region!") + + if self.reader_lib == "openslide" and size is None: + # the maximum size is set to WxH + size = ( + img.shape[0] // (2 ** level) - location[0], + img.shape[1] // (2 ** level) - location[1], + ) region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype) @@ -731,8 +731,12 @@ def get_data( if patch_size is None: patches = region else: + tuple_patch_size = ensure_tuple_rep(patch_size, 2) patches = self._extract_patches( - region, patch_size=(patch_size, patch_size), grid_shape=grid_shape, dtype=dtype + region, + patch_size=tuple_patch_size, # type: ignore + grid_shape=grid_shape, + dtype=dtype, ) return patches, metadata @@ -740,22 +744,43 @@ def get_data( def _extract_region( self, img_obj, - size: Tuple[int, int], + size: Optional[Tuple[int, int]], location: Tuple[int, int] = (0, 0), level: int = 0, dtype: DtypeLike = np.uint8, ): # reverse the order of dimensions for size and location to be compatible with image shape - size = size[::-1] location = location[::-1] - region = img_obj.read_region(location=location, size=size, level=level) - if self.reader_lib == "openslide": - region = region.convert("RGB") - # convert to numpy - region = np.asarray(region, dtype=dtype) + if size is None: + region = img_obj.read_region(location=location, level=level) + else: + size = size[::-1] + region = img_obj.read_region(location=location, size=size, level=level) + region = self.convert_to_rgb_array(region, dtype) return region + def convert_to_rgb_array( + self, + raw_region, + dtype: DtypeLike = np.uint8, + ): + """Convert to RGB mode and numpy array""" + if self.reader_lib == "openslide": + # convert to RGB + raw_region = raw_region.convert("RGB") + # convert to numpy + raw_region = np.asarray(raw_region, dtype=dtype) + else: + num_channels = len(raw_region.channel_names) + # convert to numpy + raw_region = np.asarray(raw_region, dtype=dtype) + # remove alpha channel if exist (RGBA) + if num_channels > 3: + raw_region = raw_region[:, :, :3] + + return raw_region + def _extract_patches( self, region: np.ndarray, diff --git a/tests/test_cuimage_reader.py b/tests/test_cuimage_reader.py index 036d5ad1ae..1b0293f159 100644 --- a/tests/test_cuimage_reader.py +++ b/tests/test_cuimage_reader.py @@ -11,7 +11,7 @@ from monai.utils import optional_import _, has_cim = optional_import("cucim") - +PILImage, has_pil = optional_import("PIL.Image") FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff" FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", os.path.basename(FILE_URL)) @@ -62,6 +62,14 @@ np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]), ] +TEST_CASE_RGB_0 = [ + np.ones((3, 2, 2), dtype=np.uint8), # CHW +] + +TEST_CASE_RGB_1 = [ + np.ones((3, 100, 100), dtype=np.uint8), # CHW +] + class TestCuCIMReader(unittest.TestCase): @skipUnless(has_cim, "Requires CuCIM") @@ -91,6 +99,32 @@ def test_read_patches(self, file_path, patch_info, expected_img): self.assertTupleEqual(img.shape, expected_img.shape) self.assertIsNone(assert_array_equal(img, expected_img)) + @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1]) + @skipUnless(has_pil, "Requires PIL") + def test_read_rgba(self, img_expected): + image = {} + reader = WSIReader("cuCIM") + for mode in ["RGB", "RGBA"]: + file_path = self.create_rgba_image(img_expected, "test_cu_tiff_image", mode=mode) + img_obj = reader.read(file_path) + image[mode], _ = reader.get_data(img_obj) + + self.assertIsNone(assert_array_equal(image["RGB"], img_expected)) + self.assertIsNone(assert_array_equal(image["RGBA"], img_expected)) + + def create_rgba_image(self, array: np.ndarray, filename_prefix: str, mode: str): + file_path = os.path.join(os.path.dirname(__file__), "testing_data", f"{filename_prefix}_{mode}.tiff") + + if mode == "RGBA": + array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8) + + img_rgb = array.transpose(1, 2, 0) + + image = PILImage.fromarray(img_rgb, mode=mode) + image.save(file_path) + + return file_path + if __name__ == "__main__": unittest.main()