Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def write_nifti(

if allclose(affine, target_affine, atol=1e-3): # type: ignore
# no affine changes, save (data, affine)
results_img = nib.Nifti1Image(data.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) # type: ignore
results_img = nib.Nifti1Image(
data.astype(output_dtype, copy=False), to_affine_nd(3, target_affine) # type: ignore
)
nib.save(results_img, file_name)
return

Expand Down Expand Up @@ -163,6 +165,8 @@ def write_nifti(
)
data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy()

results_img = nib.Nifti1Image(data_np.astype(output_dtype, copy=False), to_affine_nd(3, target_affine)) # type: ignore
results_img = nib.Nifti1Image(
data_np.astype(output_dtype, copy=False), to_affine_nd(3, target_affine) # type: ignore
)
nib.save(results_img, file_name)
return
7 changes: 4 additions & 3 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from monai.data.dataset import Dataset
from monai.data.utils import decollate_batch, pad_list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.croppad.batch import PadListDataCollate
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.post.dictionary import Invertd
from monai.transforms.transform import Randomizable
Expand Down Expand Up @@ -189,10 +190,10 @@ def __call__(

outs: List = []

for batch_data in tqdm(dl) if has_tqdm and self.progress else dl:
for b in tqdm(dl) if has_tqdm and self.progress else dl:
# do model forward pass
batch_data[self._pred_key] = self.inferrer_fn(batch_data[self.image_key].to(self.device))
outs.extend([self.inverter(i)[self._pred_key] for i in decollate_batch(batch_data)])
b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])

output: NdarrayOrTensor = stack(outs, 0)

Expand Down
8 changes: 4 additions & 4 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,10 @@ def pad_list_data_collate(
tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of
different sizes.

This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added
to the list of invertible transforms.

The inverse can be called using the static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse`.
This can be used on both list and dictionary data.
Note that in the case of the dictionary data, this decollate function may add the transform information of
`PadListDataCollate` to the list of invertible transforms if input batch have different spatial shape, so need to
call static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse` before inverting other transforms.

Args:
batch: batch of data to pad-collate
Expand Down
5 changes: 3 additions & 2 deletions monai/transforms/croppad/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ class PadListDataCollate(InvertibleTransform):
tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of
different sizes.

This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added
to the list of invertible transforms.
This can be used on both list and dictionary data.
Note that in the case of the dictionary data, it may add the transform information to the list of invertible transforms
if input batch have different spatial shape, so need to call static method: `inverse` before inverting other transforms.

Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`.
This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. However, the
Expand Down
4 changes: 3 additions & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def __call__(
_t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0
xform = xform @ _t_r
if not USE_COMPILED:
_t_l = normalize_transform(in_spatial_size, xform.device, xform.dtype, align_corners=True) # type: ignore
_t_l = normalize_transform(
in_spatial_size, xform.device, xform.dtype, align_corners=True # type: ignore
)
xform = _t_l @ xform # type: ignore
affine_xform = Affine(
affine=xform, spatial_size=spatial_size, norm_coords=False, image_only=True, dtype=_dtype
Expand Down
2 changes: 1 addition & 1 deletion tests/test_testtimeaugmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def tearDown(self) -> None:
set_determinism(None)

def test_test_time_augmentation(self):
input_size = (20, 20)
input_size = (20, 40) # test different input data shape to pad list collate
keys = ["image", "label"]
num_training_ims = 10

Expand Down