Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -23,6 +23,14 @@
from monai.transforms.transform import Randomizable
from monai.transforms.utils import allow_missing_keys_mode
from monai.utils.enums import CommonKeys, InverseKeys
from monai.utils.module import optional_import

if TYPE_CHECKING:
from tqdm import tqdm

has_tqdm = True
else:
tqdm, has_tqdm = optional_import("tqdm", name="tqdm")

__all__ = ["TestTimeAugmentation"]

Expand Down Expand Up @@ -57,6 +65,7 @@ class TestTimeAugmentation:
return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the
full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended
equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
progress: whether to display a progress bar.

Example:
.. code-block:: python
Expand All @@ -80,6 +89,7 @@ def __init__(
image_key=CommonKeys.IMAGE,
label_key=CommonKeys.LABEL,
return_full_data: bool = False,
progress: bool = True,
) -> None:
self.transform = transform
self.batch_size = batch_size
Expand All @@ -89,6 +99,7 @@ def __init__(
self.image_key = image_key
self.label_key = label_key
self.return_full_data = return_full_data
self.progress = progress

# check that the transform has at least one random component, and that all random transforms are invertible
self._check_transforms()
Expand Down Expand Up @@ -143,7 +154,7 @@ def __call__(

outputs: List[np.ndarray] = []

for batch_data in dl:
for batch_data in tqdm(dl) if has_tqdm and self.progress else dl:

batch_images = batch_data[self.image_key].to(self.device)

Expand All @@ -156,6 +167,10 @@ def __call__(

# create a dictionary containing the inferred batch and their transforms
inferred_dict = {self.label_key: batch_output, label_transform_key: batch_data[label_transform_key]}
# if meta dict is present, add that too (required for some inverse transforms)
label_meta_dict_key = self.label_key + "_meta_dict"
if label_meta_dict_key in batch_data:
inferred_dict[label_meta_dict_key] = batch_data[label_meta_dict_key]

# do inverse transformation (allow missing keys as only inverting label)
with allow_missing_keys_mode(self.transform): # type: ignore
Expand All @@ -171,7 +186,7 @@ def __call__(
return output

# calculate metrics
mode: np.ndarray = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64))
mode = np.array(torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values)
mean: np.ndarray = np.mean(output, axis=0) # type: ignore
std: np.ndarray = np.std(output, axis=0) # type: ignore
vvc: float = (np.std(output) / np.mean(output)).item()
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def run_testsuit():
"test_ensure_channel_firstd",
"test_handler_early_stop",
"test_handler_transform_inverter",
"test_testtimeaugmentation",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
25 changes: 22 additions & 3 deletions tests/test_testtimeaugmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,37 @@
from monai.networks.nets import UNet
from monai.transforms import Activations, AddChanneld, AsDiscrete, Compose, CropForegroundd, DivisiblePadd, RandAffined
from monai.transforms.croppad.dictionary import SpatialPadd
from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd
from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd, Spacingd
from monai.utils import optional_import, set_determinism

if TYPE_CHECKING:
import tqdm

has_tqdm = True
has_nib = True
else:
tqdm, has_tqdm = optional_import("tqdm")
_, has_nib = optional_import("nibabel")

trange = partial(tqdm.trange, desc="training") if has_tqdm else range


class TestTestTimeAugmentation(unittest.TestCase):
@staticmethod
def get_data(num_examples, input_size):
def get_data(num_examples, input_size, include_label=True):
custom_create_test_image_2d = partial(
create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1
)
data = []
for _ in range(num_examples):
im, label = custom_create_test_image_2d()
data.append({"image": im, "label": label})
d = {}
d["image"] = im
d["image_meta_dict"] = {"affine": np.eye(4)}
if include_label:
d["label"] = label
d["label_meta_dict"] = {"affine": np.eye(4)}
data.append(d)
return data[0] if num_examples == 1 else data

def setUp(self) -> None:
Expand Down Expand Up @@ -138,6 +146,17 @@ def test_single_transform(self):
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x)
tta(self.get_data(1, (20, 20)))

def test_image_no_label(self):
transforms = RandFlipd(["image"])
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, label_key="image")
tta(self.get_data(1, (20, 20), include_label=False))

@unittest.skipUnless(has_nib, "Requires nibabel")
def test_requires_meta_dict(self):
transforms = Compose([RandFlipd("image"), Spacingd("image", (1, 1))])
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, label_key="image")
tta(self.get_data(1, (20, 20), include_label=False))


if __name__ == "__main__":
unittest.main()