From 403b3ade537bd4ba6e39338885d7b1c43eaadcdb Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 20 Feb 2022 19:24:19 +0000 Subject: [PATCH] fixes #3788 according to the suggestions Signed-off-by: Wenqi Li --- monai/metrics/surface_distance.py | 20 +++++--------------- tests/test_surface_distance.py | 14 +++++++------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index fce4b735e5..efa0d17da4 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -16,7 +16,7 @@ import torch from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background -from monai.utils import MetricReduction +from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -153,20 +153,10 @@ def compute_average_surface_distance( warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") if not np.any(edges_pred): warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) - if surface_distance.shape == (0,): - avg_surface_distance = np.nan - else: - avg_surface_distance = surface_distance.mean() - if not symmetric: - asd[b, c] = avg_surface_distance - else: + if symmetric: surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) - if surface_distance_2.shape == (0,): - avg_surface_distance_2 = np.nan - else: - avg_surface_distance_2 = surface_distance_2.mean() - asd[b, c] = np.mean((avg_surface_distance, avg_surface_distance_2)) + surface_distance = np.concatenate([surface_distance, surface_distance_2]) + asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean() - return torch.from_numpy(asd) + return convert_data_type(asd, torch.Tensor)[0] diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index 781c71c23a..edfe9e8663 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -61,14 +61,14 @@ def create_spherical_seg_3d( create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), ], - [0.35021200688332677, 0.3483278807706289], + [0.350217, 0.3483278807706289], ], [ [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), ], - [13.975673696300824, 12.040033513150455], + [15.117741, 12.040033513150455], ], [ [ @@ -76,7 +76,7 @@ def create_spherical_seg_3d( create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), "chessboard", ], - [10.792254295459173, 9.605067064083457], + [11.492719, 9.605067064083457], ], [ [ @@ -84,7 +84,7 @@ def create_spherical_seg_3d( create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), "taxicab", ], - [17.32691760951026, 12.432687531048186], + [20.214613, 12.432687531048186], ], [[np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22))], [np.inf, np.inf]], [[create_spherical_seg_3d(), np.zeros([99, 99, 99]), "taxicab"], [np.inf, np.inf]], @@ -121,7 +121,7 @@ def test_value(self, input_data, expected_value): sur_metric(batch_seg_1, batch_seg_2) result = sur_metric.aggregate() expected_value_curr = expected_value[ct] - np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) + np.testing.assert_allclose(expected_value_curr, result, rtol=1e-5) ct += 1 @parameterized.expand(TEST_CASES_NANS) @@ -135,8 +135,8 @@ def test_nans(self, input_data): batch_seg_2 = [seg_2.unsqueeze(0)] sur_metric(batch_seg_1, batch_seg_2) result, not_nans = sur_metric.aggregate() - np.testing.assert_allclose(0, result, rtol=1e-7) - np.testing.assert_allclose(0, not_nans, rtol=1e-7) + np.testing.assert_allclose(0, result, rtol=1e-5) + np.testing.assert_allclose(0, not_nans, rtol=1e-5) if __name__ == "__main__":