Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 49 additions & 24 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.config import DtypeLike, KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import ensure_tuple, optional_import
from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import

from .utils import is_supported_format

Expand Down Expand Up @@ -253,7 +253,7 @@ def _get_meta_dict(self, img) -> Dict:
meta_dict["direction"] = itk.array_from_matrix(img.GetDirection())
return meta_dict

def _get_affine(self, img) -> np.ndarray:
def _get_affine(self, img):
"""
Get or construct the affine matrix of the image, it can be used to correct
spacing, orientation or execute spatial transforms.
Expand All @@ -274,7 +274,7 @@ def _get_affine(self, img) -> np.ndarray:
affine[(slice(-1), -1)] = origin
return affine

def _get_spatial_shape(self, img) -> np.ndarray:
def _get_spatial_shape(self, img):
"""
Get the spatial shape of image data, it doesn't contain the channel dim.

Expand Down Expand Up @@ -406,7 +406,7 @@ def _get_meta_dict(self, img) -> Dict:
"""
return dict(img.header)

def _get_affine(self, img) -> np.ndarray:
def _get_affine(self, img):
"""
Get the affine matrix of the image, it can be used to correct
spacing, orientation or execute spatial transforms.
Expand All @@ -417,7 +417,7 @@ def _get_affine(self, img) -> np.ndarray:
"""
return np.array(img.affine, copy=True)

def _get_spatial_shape(self, img) -> np.ndarray:
def _get_spatial_shape(self, img):
"""
Get the spatial shape of image data, it doesn't contain the channel dim.

Expand All @@ -430,7 +430,7 @@ def _get_spatial_shape(self, img) -> np.ndarray:
# the img data should have no channel dim or the last dim is channel
return np.asarray(img.header["dim"][1 : spatial_rank + 1])

def _get_array_data(self, img) -> np.ndarray:
def _get_array_data(self, img):
"""
Get the raw array data of the image, converted to Numpy array.

Expand Down Expand Up @@ -623,7 +623,7 @@ def _get_meta_dict(self, img) -> Dict:
"height": img.height,
}

def _get_spatial_shape(self, img) -> np.ndarray:
def _get_spatial_shape(self, img):
"""
Get the spatial shape of image data, it doesn't contain the channel dim.
Args:
Expand Down Expand Up @@ -697,7 +697,7 @@ def get_data(
level: int = 0,
dtype: DtypeLike = np.uint8,
grid_shape: Tuple[int, int] = (1, 1),
patch_size: Optional[int] = None,
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
):
"""
Extract regions as numpy array from WSI image and return them.
Expand All @@ -711,15 +711,15 @@ def get_data(
level: the level number, or list of level numbers (default=0)
dtype: the data type of output image
grid_shape: (row, columns) tuple define a grid to extract patches on that
patch_size: (heigsht, width) the size of extracted patches at the given level
patch_size: (height, width) the size of extracted patches at the given level
"""
if size is None:
if location == (0, 0):
# the maximum size is set to WxH
size = (img.shape[0] // (2 ** level), img.shape[1] // (2 ** level))
print(f"Reading the whole image at level={level} with shape={size}")
else:
raise ValueError("Size need to be provided to extract the region!")

if self.reader_lib == "openslide" and size is None:
# the maximum size is set to WxH
size = (
img.shape[0] // (2 ** level) - location[0],
img.shape[1] // (2 ** level) - location[1],
)

region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

Expand All @@ -731,31 +731,56 @@ def get_data(
if patch_size is None:
patches = region
else:
tuple_patch_size = ensure_tuple_rep(patch_size, 2)
patches = self._extract_patches(
region, patch_size=(patch_size, patch_size), grid_shape=grid_shape, dtype=dtype
region,
patch_size=tuple_patch_size, # type: ignore
grid_shape=grid_shape,
dtype=dtype,
)

return patches, metadata

def _extract_region(
self,
img_obj,
size: Tuple[int, int],
size: Optional[Tuple[int, int]],
location: Tuple[int, int] = (0, 0),
level: int = 0,
dtype: DtypeLike = np.uint8,
):
# reverse the order of dimensions for size and location to be compatible with image shape
size = size[::-1]
location = location[::-1]
region = img_obj.read_region(location=location, size=size, level=level)
if self.reader_lib == "openslide":
region = region.convert("RGB")
# convert to numpy
region = np.asarray(region, dtype=dtype)
if size is None:
region = img_obj.read_region(location=location, level=level)
Comment thread
wyli marked this conversation as resolved.
else:
size = size[::-1]
region = img_obj.read_region(location=location, size=size, level=level)

region = self.convert_to_rgb_array(region, dtype)
return region

def convert_to_rgb_array(
self,
raw_region,
dtype: DtypeLike = np.uint8,
):
"""Convert to RGB mode and numpy array"""
if self.reader_lib == "openslide":
# convert to RGB
raw_region = raw_region.convert("RGB")
# convert to numpy
raw_region = np.asarray(raw_region, dtype=dtype)
else:
num_channels = len(raw_region.channel_names)
# convert to numpy
raw_region = np.asarray(raw_region, dtype=dtype)
# remove alpha channel if exist (RGBA)
if num_channels > 3:
raw_region = raw_region[:, :, :3]

return raw_region

def _extract_patches(
self,
region: np.ndarray,
Expand Down
36 changes: 35 additions & 1 deletion tests/test_cuimage_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from monai.utils import optional_import

_, has_cim = optional_import("cucim")

PILImage, has_pil = optional_import("PIL.Image")

FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff"
FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", os.path.basename(FILE_URL))
Expand Down Expand Up @@ -62,6 +62,14 @@
np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]),
]

TEST_CASE_RGB_0 = [
np.ones((3, 2, 2), dtype=np.uint8), # CHW
]

TEST_CASE_RGB_1 = [
np.ones((3, 100, 100), dtype=np.uint8), # CHW
]


class TestCuCIMReader(unittest.TestCase):
@skipUnless(has_cim, "Requires CuCIM")
Expand Down Expand Up @@ -91,6 +99,32 @@ def test_read_patches(self, file_path, patch_info, expected_img):
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))

@parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
@skipUnless(has_pil, "Requires PIL")
def test_read_rgba(self, img_expected):
image = {}
reader = WSIReader("cuCIM")
for mode in ["RGB", "RGBA"]:
file_path = self.create_rgba_image(img_expected, "test_cu_tiff_image", mode=mode)
img_obj = reader.read(file_path)
image[mode], _ = reader.get_data(img_obj)

self.assertIsNone(assert_array_equal(image["RGB"], img_expected))
self.assertIsNone(assert_array_equal(image["RGBA"], img_expected))

def create_rgba_image(self, array: np.ndarray, filename_prefix: str, mode: str):
file_path = os.path.join(os.path.dirname(__file__), "testing_data", f"{filename_prefix}_{mode}.tiff")

if mode == "RGBA":
array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8)

img_rgb = array.transpose(1, 2, 0)

image = PILImage.fromarray(img_rgb, mode=mode)
image.save(file_path)

return file_path


if __name__ == "__main__":
unittest.main()