Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
GaussianSmooth,
GibbsNoise,
HistogramNormalize,
IntensityRemap,
KSpaceSpikeNoise,
MaskIntensity,
NormalizeIntensity,
Expand All @@ -98,6 +99,7 @@
RandGaussianSmooth,
RandGibbsNoise,
RandHistogramShift,
RandIntensityRemap,
RandKSpaceSpikeNoise,
RandRicianNoise,
RandScaleIntensity,
Expand Down
112 changes: 112 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
"RandCoarseDropout",
"RandCoarseShuffle",
"HistogramNormalize",
"IntensityRemap",
"RandIntensityRemap",
]


Expand Down Expand Up @@ -2053,3 +2055,113 @@ def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None)
out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype)

return out


class IntensityRemap(RandomizableTransform):
"""
Transform for intensity remapping of images. The intensity at each
pixel is replaced by a new values coming from an intensity remappping
curve.

The remapping curve is created by uniformly sampling values from the
possible intensities for the input image and then adding a linear
component. The curve is the rescaled to the input image intensity range.

Intended to be used as a means to data augmentation via:
:py:class:`monai.transforms.RandIntensityRemap`.

Implementation is described in the work:
`Intensity augmentation for domain transfer of whole breast segmentation
in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.

Args:
kernel_size: window size for averaging operation for the remapping
curve.
slope: slope of the linear component. Easiest to leave default value
and tune the kernel_size parameter instead.
return_map: set to True for the transform to return a dictionary version
of the lookup table used in the intensity remapping. The keys
correspond to the old intensities, and the values are the new
values.
"""

def __init__(self, kernel_size: int = 30, slope: float = 0.7):

super().__init__()

self.kernel_size = kernel_size
self.slope = slope

def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""
Args:
img: image to remap.
"""

img = img.clone()
# sample noise
vals_to_sample = torch.unique(img).tolist()
noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size))
# smooth
noise = torch.nn.AvgPool1d(self.kernel_size, stride=1)(noise.unsqueeze(0)).squeeze()
# add linear component
grid = torch.arange(len(noise)) / len(noise)
noise += self.slope * grid
# rescale
noise = (noise - noise.min()) / (noise.max() - noise.min()) * img.max() + img.min()

# intensity remapping function
index_img = torch.bucketize(img, torch.tensor(vals_to_sample))
img = noise[index_img]

return img


class RandIntensityRemap(RandomizableTransform):
"""
Transform for intensity remapping of images. The intensity at each
pixel is replaced by a new values coming from an intensity remappping
curve.

The remapping curve is created by uniformly sampling values from the
possible intensities for the input image and then adding a linear
component. The curve is the rescaled to the input image intensity range.

Implementation is described in the work:
`Intensity augmentation for domain transfer of whole breast segmentation
in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.

Args:
prob: probability of applying the transform.
kernel_size: window size for averaging operation for the remapping
curve.
slope: slope of the linear component. Easiest to leave default value
and tune the kernel_size parameter instead.
channel_wise: set to True to treat each channel independently.
"""

def __init__(self, prob: float = 0.1, kernel_size: int = 30, slope: float = 0.7, channel_wise: bool = True):

RandomizableTransform.__init__(self, prob=prob)
self.kernel_size = kernel_size
self.slope = slope
self.channel_wise = True

def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""
Args:
img: image to remap.
"""
super().randomize(None)
if self._do_transform:
if self.channel_wise:
img = torch.stack(
[
IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img[i])
for i in range(len(img))
]
)
else:
img = IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img)

return img