diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py index 7be63e7638..524f13a760 100644 --- a/tests/test_handler_transform_inverter.py +++ b/tests/test_handler_transform_inverter.py @@ -89,6 +89,18 @@ def _train_func(engine, batch): output_keys=["image", "label"], batch_keys="label", nearest_interp=True, + postfix="inverted1", + num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2, + ).attach(engine) + + # test different nearest interpolation values + TransformInverter( + transform=transform, + loader=loader, + output_keys=["image", "label"], + batch_keys="image", + nearest_interp=[True, False], + postfix="inverted2", num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2, ).attach(engine) @@ -96,11 +108,12 @@ def _train_func(engine, batch): set_determinism(seed=None) self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100)) self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100)) - for i in engine.state.output["image_inverted"] + engine.state.output["label_inverted"]: + # check the nearest inerpolation mode + for i in engine.state.output["image_inverted1"] + engine.state.output["label_inverted1"]: torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) self.assertTupleEqual(i.shape, (1, 100, 101, 107)) # check labels match - reverted = engine.state.output["label_inverted"][-1].detach().cpu().numpy()[0].astype(np.int32) + reverted = engine.state.output["label_inverted1"][-1].detach().cpu().numpy()[0].astype(np.int32) original = LoadImaged(KEYS)(data[-1])["label"] n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) reverted_name = engine.state.output["label_meta_dict"]["filename_or_obj"][-1] @@ -112,6 +125,17 @@ def _train_func(engine, batch): # 1824: torch 1.5.1 self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff. in 3 possible values") + # check the case that different items use different interpolation mode to invert transforms + for i in engine.state.output["image_inverted2"]: + # if the interpolation mode is nearest, accumulated diff should be smaller than 1 + self.assertLess(torch.sum(i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 1.0) + self.assertTupleEqual(i.shape, (1, 100, 101, 107)) + + for i in engine.state.output["label_inverted2"]: + # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 + self.assertGreater(torch.sum(i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 10000.0) + self.assertTupleEqual(i.shape, (1, 100, 101, 107)) + if __name__ == "__main__": unittest.main()