diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 473ad5bac2..11b56086df 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -12,6 +12,7 @@ import torch +from monai.data import decollate_batch, list_data_collate from monai.engines import SupervisedEvaluator, SupervisedTrainer from monai.engines.utils import IterationEvents from monai.transforms import Compose @@ -74,6 +75,9 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd batchdata[self.key_probability] = torch.as_tensor( ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs) ) - batchdata = self.transforms(batchdata) + # decollate batch data to execute click transforms + batchdata_list = [self.transforms(i) for i in decollate_batch(batchdata, detach=True)] + # collate list into a batch for next round interaction + batchdata = list_data_collate(batchdata_list) return engine._iteration(engine, batchdata) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index cfdeb5c87f..db450792b0 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -8,6 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import Callable, Dict, Optional, Sequence, Union import numpy as np @@ -144,7 +145,7 @@ def _apply(self, label, sid): def __call__(self, data): d = dict(data) self.randomize(data) - d[self.guidance] = self._apply(d[self.label], self.sid) + d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int).tolist()) return d @@ -159,7 +160,7 @@ class AddGuidanceSignald(Transform): guidance: key to store guidance. sigma: standard deviation for Gaussian kernel. number_intensity_ch: channel index. - batched: whether input is batched or not. + """ def __init__( @@ -168,17 +169,16 @@ def __init__( guidance: str = "guidance", sigma: int = 2, number_intensity_ch: int = 1, - batched: bool = False, ): self.image = image self.guidance = guidance self.sigma = sigma self.number_intensity_ch = number_intensity_ch - self.batched = batched def _get_signal(self, image, guidance): dimensions = 3 if len(image.shape) > 3 else 2 guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + guidance = json.loads(guidance) if isinstance(guidance, str) else guidance if dimensions == 3: signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32) else: @@ -210,16 +210,9 @@ def _get_signal(self, image, guidance): return signal def _apply(self, image, guidance): - if not self.batched: - signal = self._get_signal(image, guidance) - return np.concatenate([image, signal], axis=0) - - images = [] - for i, g in zip(image, guidance): - i = i[0 : 0 + self.number_intensity_ch, ...] - signal = self._get_signal(i, g) - images.append(np.concatenate([i, signal], axis=0)) - return images + signal = self._get_signal(image, guidance) + image = image[0 : 0 + self.number_intensity_ch, ...] + return np.concatenate([image, signal], axis=0) def __call__(self, data): d = dict(data) @@ -234,26 +227,17 @@ class FindDiscrepancyRegionsd(Transform): """ Find discrepancy between prediction and actual during click interactions during training. - If batched is true: - - label is in shape (B, C, D, H, W) or (B, C, H, W) - pred has same shape as label - discrepancy will have shape (B, 2, C, D, H, W) or (B, 2, C, H, W) - Args: label: key to label source. pred: key to prediction source. discrepancy: key to store discrepancies found between label and prediction. - batched: whether input is batched or not. + """ - def __init__( - self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy", batched: bool = True - ): + def __init__(self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy"): self.label = label self.pred = pred self.discrepancy = discrepancy - self.batched = batched @staticmethod def disparity(label, pred): @@ -266,13 +250,7 @@ def disparity(label, pred): return [pos_disparity, neg_disparity] def _apply(self, label, pred): - if not self.batched: - return self.disparity(label, pred) - - disparity = [] - for la, pr in zip(label, pred): - disparity.append(self.disparity(la, pr)) - return disparity + return self.disparity(label, pred) def __call__(self, data): d = dict(data) @@ -286,30 +264,16 @@ def __call__(self, data): class AddRandomGuidanced(Randomizable, Transform): """ Add random guidance based on discrepancies that were found between label and prediction. - - If batched is True, input shape is as below: - - Guidance is of shape (B, 2, N, # of dim) where B is batch size, 2 means positive and negative, - N means how many guidance points, # of dim is the total number of dimensions of the image - (for example if the image is CDHW, then # of dim would be 4). - - Discrepancy is of shape (B, 2, C, D, H, W) or (B, 2, C, H, W) - - Probability is of shape (B, 1) - - else: - - Guidance is of shape (2, N, # of dim) - - Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W) - - Probability is of shape (1) + input shape is as below: + Guidance is of shape (2, N, # of dim) + Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W) + Probability is of shape (1) Args: guidance: key to guidance source. discrepancy: key that represents discrepancies found between label and prediction. probability: key that represents click/interaction probability. - batched: whether input is batched or not. + """ def __init__( @@ -317,22 +281,15 @@ def __init__( guidance: str = "guidance", discrepancy: str = "discrepancy", probability: str = "probability", - batched: bool = True, ): self.guidance = guidance self.discrepancy = discrepancy self.probability = probability - self.batched = batched self._will_interact = None def randomize(self, data=None): probability = data[self.probability] - if not self.batched: - self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability]) - else: - self._will_interact = [] - for p in probability: - self._will_interact.append(self.R.choice([True, False], p=[p, 1.0 - p])) + self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability]) def find_guidance(self, discrepancy): distance = distance_transform_cdt(discrepancy).flatten() @@ -368,24 +325,16 @@ def add_guidance(self, discrepancy, will_interact): def _apply(self, guidance, discrepancy): guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance - if not self.batched: - pos, neg = self.add_guidance(discrepancy, self._will_interact) - if pos: - guidance[0].append(pos) - guidance[1].append([-1] * len(pos)) - if neg: - guidance[0].append([-1] * len(neg)) - guidance[1].append(neg) - else: - for g, d, w in zip(guidance, discrepancy, self._will_interact): - pos, neg = self.add_guidance(d, w) - if pos: - g[0].append(pos) - g[1].append([-1] * len(pos)) - if neg: - g[0].append([-1] * len(neg)) - g[1].append(neg) - return np.asarray(guidance) + guidance = json.loads(guidance) if isinstance(guidance, str) else guidance + pos, neg = self.add_guidance(discrepancy, self._will_interact) + if pos: + guidance[0].append(pos) + guidance[1].append([-1] * len(pos)) + if neg: + guidance[0].append([-1] * len(neg)) + guidance[1].append(neg) + + return json.dumps(np.asarray(guidance).astype(int).tolist()) def __call__(self, data): d = dict(data) diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py index 77c37bf5f3..e5c7bf9051 100644 --- a/tests/test_deepgrow_interaction.py +++ b/tests/test_deepgrow_interaction.py @@ -11,17 +11,25 @@ import unittest +import numpy as np import torch from monai.apps.deepgrow.interaction import Interaction +from monai.apps.deepgrow.transforms import ( + AddGuidanceSignald, + AddInitialSeedPointd, + AddRandomGuidanced, + FindAllValidSlicesd, + FindDiscrepancyRegionsd, +) from monai.data import Dataset from monai.engines import SupervisedTrainer from monai.engines.utils import IterationEvents -from monai.transforms import Activationsd, Compose, ToNumpyd +from monai.transforms import Activationsd, Compose, ToNumpyd, ToTensord def add_one(engine): - if engine.state.best_metric is -1: + if engine.state.best_metric == -1: engine.state.best_metric = 0 else: engine.state.best_metric = engine.state.best_metric + 1 @@ -29,21 +37,34 @@ def add_one(engine): class TestInteractions(unittest.TestCase): def run_interaction(self, train, compose): - data = [] - for i in range(5): - data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])}) - network = torch.nn.Linear(1, 1) + data = [{"image": np.ones((1, 2, 2, 2)).astype(np.float32), "label": np.ones((1, 2, 2, 2))} for _ in range(5)] + network = torch.nn.Linear(2, 2) lr = 1e-3 opt = torch.optim.SGD(network.parameters(), lr) loss = torch.nn.L1Loss() - dataset = Dataset(data, transform=None) + train_transforms = Compose( + [ + FindAllValidSlicesd(label="label", sids="sids"), + AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"), + AddGuidanceSignald(image="image", guidance="guidance"), + ToTensord(keys=("image", "label")), + ] + ) + dataset = Dataset(data, transform=train_transforms) data_loader = torch.utils.data.DataLoader(dataset, batch_size=5) - iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")] + iteration_transforms = [ + Activationsd(keys="pred", sigmoid=True), + ToNumpyd(keys=["image", "label", "pred"]), + FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"), + AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"), + AddGuidanceSignald(image="image", guidance="guidance"), + ToTensord(keys=("image", "label")), + ] iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5) - self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms") + self.assertEqual(len(i.transforms.transforms), 6, "Mismatch in expected transforms") # set up engine engine = SupervisedTrainer( diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index 83bb5ebaa4..f50e92d146 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -30,8 +30,6 @@ IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]) LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]) -BATCH_IMAGE = np.array([[[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]]) -BATCH_LABEL = np.array([[[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]]) DATA_1 = { "image": IMAGE, @@ -61,24 +59,22 @@ } DATA_3 = { - "image": BATCH_IMAGE, - "label": BATCH_LABEL, - "pred": np.array([[[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]]), + "image": IMAGE, + "label": LABEL, + "pred": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]), } DATA_4 = { - "image": BATCH_IMAGE, - "label": BATCH_LABEL, - "guidance": np.array([[[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]]), + "image": IMAGE, + "label": LABEL, + "guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]), "discrepancy": np.array( [ - [ - [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], - [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], - ] + [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], + [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], ] ), - "probability": [1.0], + "probability": 1.0, } DATA_5 = { @@ -192,11 +188,11 @@ ADD_INITIAL_POINT_TEST_CASE_1 = [ {"label": "label", "guidance": "guidance", "sids": "sids"}, DATA_1, - np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]), + "[[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]", ] ADD_GUIDANCE_TEST_CASE_1 = [ - {"image": "image", "guidance": "guidance", "batched": False}, + {"image": "image", "guidance": "guidance"}, DATA_2, np.array( [ @@ -233,18 +229,16 @@ DATA_3, np.array( [ - [ - [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], - [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], - ] + [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], + [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], ] ), ] ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [ - {"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability", "batched": True}, + {"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"}, DATA_4, - np.array([[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]]), + "[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]", ] ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1 = [ @@ -398,7 +392,7 @@ def test_correct_results(self, arguments, input_data, expected_result): add_fn = AddInitialSeedPointd(**arguments) add_fn.set_random_state(seed) result = add_fn(input_data) - np.testing.assert_allclose(result[arguments["guidance"]], expected_result) + self.assertEqual(result[arguments["guidance"]], expected_result) class TestAddGuidanceSignald(unittest.TestCase): @@ -422,7 +416,7 @@ def test_correct_results(self, arguments, input_data, expected_result): add_fn = AddRandomGuidanced(**arguments) add_fn.set_random_state(seed) result = add_fn(input_data) - np.testing.assert_allclose(result[arguments["guidance"]], expected_result, rtol=1e-5) + self.assertEqual(result[arguments["guidance"]], expected_result) class TestAddGuidanceFromPointsd(unittest.TestCase):