Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ gets uploaded to AzureML, by skipping all test folders.

### Fixed

- ([#682](https://github.com/microsoft/InnerEye-DeepLearning/pull/682)) Ensure the shape of input patches is compatible with model constraints.
- ([#681](https://github.com/microsoft/InnerEye-DeepLearning/pull/681)) Pad model outputs if they are smaller than the inputs.
- ([#683](https://github.com/microsoft/InnerEye-DeepLearning/pull/683)) Fix missing separator error in docs Makefile.
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Fix caching and checkpointing for TCGA CRCk dataset.
Expand Down
45 changes: 37 additions & 8 deletions InnerEye/ML/pipelines/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict

import numpy as np
import torch
Expand Down Expand Up @@ -235,7 +235,7 @@ def post_process_posteriors(self, posteriors: np.ndarray, mask: np.ndarray = Non
@torch.no_grad()
def predict_whole_image(self, image_channels: np.ndarray,
voxel_spacing_mm: TupleFloat3,
mask: np.ndarray = None,
mask: Optional[np.ndarray] = None,
patient_id: int = 0) -> InferencePipeline.Result:
"""
Performs a single inference pass through the pipeline for the provided image
Expand All @@ -255,12 +255,26 @@ def predict_whole_image(self, image_channels: np.ndarray,
self.model.eval()

image = tio.ScalarImage(tensor=image_channels)
subject = tio.Subject(image=image)
INPUT = 'input_image'
MASK = 'mask'

subject_dict: Dict[str, tio.Image] = {INPUT: image}
if mask is not None:
subject_dict[MASK] = tio.LabelMap(tensor=mask[np.newaxis])
subject = tio.Subject(subject_dict)

constraints = self.model.model.crop_size_constraints

# Make sure the image size is compatible with the model
multiple_constraints = constraints.multiple_of # type: ignore
if multiple_constraints is not None:
ensure_shape_multiple = tio.EnsureShapeMultiple(constraints.multiple_of) # type: ignore
subject = ensure_shape_multiple(subject) # type: ignore

# There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
# to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
restrict_patch_size = self.model.model.crop_size_constraints.restrict_crop_size_to_image # type: ignore
effective_patch_size, effective_stride = restrict_patch_size(image.spatial_shape, # type: ignore
restrict_patch_size = constraints.restrict_crop_size_to_image # type: ignore
effective_patch_size, effective_stride = restrict_patch_size(subject.spatial_shape, # type: ignore
self.model_config.test_crop_size,
self.model_config.inference_stride_size)

Expand All @@ -276,10 +290,10 @@ def predict_whole_image(self, image_channels: np.ndarray,
aggregator = tio.inference.GridAggregator(grid_sampler)

logging.debug(
f"Inference on image size {image.spatial_shape} will run "
f"Inference on image size {subject.spatial_shape} will run "
f"with crop size {effective_patch_size} and stride {effective_stride}")
for patches_batch in patch_loader:
input_tensor = patches_batch['image'][tio.DATA].float()
input_tensor = patches_batch[INPUT][tio.DATA].float()
if self.model_config.use_gpu:
input_tensor = input_tensor.cuda()
locations = patches_batch[tio.LOCATION]
Expand All @@ -296,9 +310,24 @@ def predict_whole_image(self, image_channels: np.ndarray,
# collect the predictions over each of the batches
aggregator.add_batch(patches_posteriors, locations)
posteriors = aggregator.get_output_tensor().numpy()
posteriors, segmentation = self.post_process_posteriors(posteriors, mask=mask)
posteriors_mask = None if mask is None else subject[MASK].numpy()[0]
posteriors, segmentation = self.post_process_posteriors(posteriors, mask=posteriors_mask)

image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")

# Make sure the final shape matches the input shape by undoing the padding in EnsureShapeMultiple (if any)
posteriors_image = tio.ScalarImage(tensor=posteriors, affine=image.affine)
segmentation_image = tio.LabelMap(tensor=segmentation[np.newaxis], affine=image.affine)
subject.add_image(posteriors_image, 'posteriors')
subject.add_image(segmentation_image, 'segmentation')
# Remove some images to avoid unnecessary computations
subject.remove_image(INPUT)
if mask is not None:
subject.remove_image(MASK)
subject_original_space = subject.apply_inverse_transform() if subject.applied_transforms else subject
posteriors = subject_original_space.posteriors.numpy() # type: ignore
segmentation = subject_original_space.segmentation.numpy()[0] # type: ignore

# prepare pipeline results from the processed batch
return InferencePipeline.Result(
patient_id=patient_id,
Expand Down
9 changes: 0 additions & 9 deletions Tests/ML/pipelines/test_inference_smallimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,6 @@ def run_inference_on_unet(size: TupleInt3) -> None:
image_util.check_array_range(p)


def test_inference_on_too_small_image() -> None:
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why delete a test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because images will never be too small after merging

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other PR does not have tests. Is it not possible to have a test to demonstrate that this is correct?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running inference on a simplified Unet model when the input image is too small along an axis.
"""
with pytest.raises(ValueError) as ex:
run_inference_on_unet((5, 10, 64))
assert "input image must have at least a size of (16, 16, 16)" in str(ex)


@pytest.mark.parametrize("size", [(26, 20, 50), (16, 16, 16)])
def test_inference_on_small_image(size: TupleInt3) -> None:
"""
Expand Down