From e552896397c66d084cac906d298d0321737f619b Mon Sep 17 00:00:00 2001 From: charliebudd Date: Mon, 25 Jan 2021 14:30:47 +0000 Subject: [PATCH 01/16] conditional random field implementation Signed-off-by: charliebudd --- monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/crf.py | 141 ++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 monai/networks/blocks/crf.py diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index c33feb4e2b..69eb31f8c7 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -13,6 +13,7 @@ from .activation import Mish, Swish from .aspp import SimpleASPP from .convolutions import Convolution, ResidualUnit +from .crf import CRF from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py new file mode 100644 index 0000000000..df0faf848b --- /dev/null +++ b/monai/networks/blocks/crf.py @@ -0,0 +1,141 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn.functional import conv1d, conv2d, conv3d, pad, softmax + +from monai.networks.layers.filtering import PHLFilter + +__all__ = ["CRF"] + + +class CRF(torch.nn.Module): + """ + Conditional Random Field: Combines message passing with a class + compatability convolution into an iterative process designed + to successively minimise the energy of the class labeling. + + In this implementation, the message passing step is a weighted + combination of a gaussian filter and a bilateral filter. + The bilateral term is included to respect existing structure + within the reference tensor. + + See: + https://arxiv.org/abs/1502.03240 + + Args: + input_tensor: tensor containing initial class logits. + + referenece_tensor: the reference tensor used to guide the message passing. + + bilateral_weight: the weighting of the bilateral term in the message passing step + + gaussian_weight: the weighting of the gaussian term in the message passing step + + bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term + + bilateral_color_sigma: standard deviation in color space for the bilateral term + + gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term + + compatability_kernel_range: the range of the kernel used in the compatability convolution + + iterations: the number of iterations. + + Returns: + output (torch.Tensor): output tensor. + """ + + def __init__( + self, + bilateral_weight: [float] = 0.8, + gaussian_weight: [float] = 0.2, + bilateral_spatial_sigma: [float] = 64, + bilateral_color_sigma: [float] = 0.2, + gaussian_spatial_sigma: [float] = 64, + compatability_kernel_range: [int] = 1, + iterations: [int] = 5, + ): + super(CRF, self).__init__() + self.bilateral_weight = bilateral_weight + self.gaussian_weight = gaussian_weight + self.bilateral_spatial_sigma = bilateral_spatial_sigma + self.bilateral_color_sigma = bilateral_color_sigma + self.gaussian_spatial_sigma = gaussian_spatial_sigma + self.compatability_kernel_range = compatability_kernel_range + self.iterations = iterations + + def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): + + # useful values + spatial_dim = input_tensor.dim() - 2 + class_count = input_tensor.size(1) + padding = self.compatability_kernel_range + + # constructing spatial feature tensor + spatial_features = _create_coordinate_tensor(reference_tensor) + + # constructing final feature tensors for bilateral and gaussian kernel + bilateral_features = torch.cat( + [spatial_features / self.bilateral_spatial_sigma, reference_tensor / self.bilateral_color_sigma], dim=1 + ) + gaussian_features = spatial_features / self.gaussian_spatial_sigma + + # compatability matrix (potts model (1 - diag) for now) + compatability_matrix = _potts_model_weights(class_count).to(device=input_tensor.device) + + # expanding matrix to kernel + compatability_kernel = _expand_matrix_to_kernel( + compatability_matrix, spatial_dim, self.compatability_kernel_range + ) + + # choosing convolution function + conv = [conv1d, conv2d, conv3d][spatial_dim - 1] + + # seting up output tensor + output_tensor = softmax(input_tensor, dim=1) + + # mean field loop + for _ in range(self.iterations): + + # message passing step for both kernels + bliateral_output = PHLFilter.apply(output_tensor, bilateral_features) + gaussian_output = PHLFilter.apply(output_tensor, gaussian_features) + + # combining filter outputs + combined_output = self.bilateral_weight * bliateral_output + self.gaussian_weight * gaussian_output + combined_output /= self.bilateral_weight + self.gaussian_weight + + # compatibility convolution + combined_output = pad(combined_output, 2 * spatial_dim * [padding], mode="replicate") + compatibility_update = conv(combined_output, compatability_kernel) + + # update and normalize + output_tensor = softmax(input_tensor - compatibility_update, dim=1) + + return output_tensor + + +# helper methods +def _create_coordinate_tensor(tensor): + axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())] + grids = torch.meshgrid(axes) + return torch.stack(grids).unsqueeze(0).to(device=tensor.device, dtype=tensor.dtype) + + +def _potts_model_weights(class_count): + return (1 - torch.diag(torch.ones(class_count))).unsqueeze(-1) + + +def _expand_matrix_to_kernel(matrix, spatial_dim, kernel_range): + reshape_arg = (matrix.size(0), matrix.size(1)) + spatial_dim * (1,) + expand_arg = (-1, -1) + spatial_dim * (1 + 2 * kernel_size,) + return matrix.reshape(reshape_arg).expand(expand_arg) From a73028e21e01eacedf932763bb117e41d2b0b2d2 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Tue, 26 Jan 2021 11:19:11 +0000 Subject: [PATCH 02/16] fixing variable rename typo Signed-off-by: charliebudd --- monai/networks/blocks/crf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index df0faf848b..f4b064f021 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -137,5 +137,5 @@ def _potts_model_weights(class_count): def _expand_matrix_to_kernel(matrix, spatial_dim, kernel_range): reshape_arg = (matrix.size(0), matrix.size(1)) + spatial_dim * (1,) - expand_arg = (-1, -1) + spatial_dim * (1 + 2 * kernel_size,) + expand_arg = (-1, -1) + spatial_dim * (1 + 2 * kernel_range,) return matrix.reshape(reshape_arg).expand(expand_arg) From 01171c034db5a021c2c65ecacc9fd7f6b6fc5c79 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Mon, 1 Mar 2021 10:41:28 +0000 Subject: [PATCH 03/16] changing default parameters Signed-off-by: charliebudd --- monai/networks/blocks/crf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index f4b064f021..d31e0f1936 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -56,11 +56,11 @@ class CRF(torch.nn.Module): def __init__( self, - bilateral_weight: [float] = 0.8, - gaussian_weight: [float] = 0.2, - bilateral_spatial_sigma: [float] = 64, - bilateral_color_sigma: [float] = 0.2, - gaussian_spatial_sigma: [float] = 64, + bilateral_weight: [float] = 1.0, + gaussian_weight: [float] = 1.0, + bilateral_spatial_sigma: [float] = 5.0, + bilateral_color_sigma: [float] = 0.5, + gaussian_spatial_sigma: [float] = 5.0, compatability_kernel_range: [int] = 1, iterations: [int] = 5, ): From 9d05485ef2315bf2437afe855e8029aa92f4a800 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Mon, 1 Mar 2021 15:05:46 +0000 Subject: [PATCH 04/16] fixing error for non-singular batch size Signed-off-by: charliebudd --- monai/networks/blocks/crf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index d31e0f1936..9cfce9045b 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -128,7 +128,8 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): def _create_coordinate_tensor(tensor): axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())] grids = torch.meshgrid(axes) - return torch.stack(grids).unsqueeze(0).to(device=tensor.device, dtype=tensor.dtype) + coords = torch.stack(grids).to(device=tensor.device, dtype=tensor.dtype) + return torch.stack(tensor.size(0) * [coords], dim=0) def _potts_model_weights(class_count): From ed1a47a5be4335931445df03a426d2bd4ee06d9b Mon Sep 17 00:00:00 2001 From: charliebudd Date: Tue, 2 Mar 2021 10:38:18 +0000 Subject: [PATCH 05/16] unit tests Signed-off-by: charliebudd --- tests/test_crf_cpu.py | 460 +++++++++++++++++++++++++++++++++++++++++ tests/test_crf_cuda.py | 459 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 919 insertions(+) create mode 100644 tests/test_crf_cpu.py create mode 100644 tests/test_crf_cuda.py diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py new file mode 100644 index 0000000000..762c060a80 --- /dev/null +++ b/tests/test_crf_cpu.py @@ -0,0 +1,460 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.blocks import CRF +from tests.utils import skip_if_no_cpp_extention + +TEST_CASES = [ + [ + # Case Description + "2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s)", + # Parameters + [ + 3.0, # bilateral_weight + 1.0, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1, # compatability_kernel_range + 5, # iterations + ], + # Input + [ + # Batch 0 + [ + # Class 0 + [0.8, 0.9, 0.6, 0.2, 0.3], + + # Class 1 + [0.1, 0.3, 0.5, 0.8, 0.7] + ], + # Batch 1 + [ + # Class 0 + [0.8, 0.9, 0.6, 0.2, 0.3], + + # Class 1 + [0.1, 0.3, 0.5, 0.8, 0.7] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 1, 1, 0.5, 0], + ], + # Batch 1 + [ + # Channel 0 + [1, 1, 0.5, 0, 0], + ], + ], + # Expected + [ + # Batch 0 + [ + # Class 0 + [0.911299, 0.901745, 0.828803, 0.644988, 0.627997], + + # Class 1 + [0.088701, 0.098255, 0.171197, 0.355012, 0.372003], + ], + # Batch 1 + [ + # Class 0 + [0.902577, 0.870105, 0.725821, 0.465292, 0.466282], + + # Class 1 + [0.097423, 0.129895, 0.274179, 0.534708, 0.533718] + ], + ], + ], + [ + # Case Description + "1 batche(s), 2 dimension(s), 3 classe(s), 2 channel(s)", + # Parameters + [ + 3.0, # bilateral_weight + 1.0, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1, # compatability_kernel_range + 5, # iterations + ], + # Input + [ + # Batch 0 + [ + # Class 0 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + + # Class 1 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + + # Class 2 + [ + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + + # Channel 1 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + ], + ], + # Expected + [ + # Batch 0 + [ + # Class 0 + [ + [0.000124, 0.000124, 0.000124, 0.000045, 0.000045], + [0.000124, 0.000124, 0.000046, 0.000045, 0.000045], + [0.000124, 0.000046, 0.000046, 0.000046, 0.000124], + [0.000046, 0.000046, 0.000046, 0.000336, 0.000337], + [0.000046, 0.000046, 0.000124, 0.000337, 0.000337] + ], + + # Class 1 + [ + [0.000337, 0.000337, 0.000124, 0.000046, 0.000046], + [0.000337, 0.000337, 0.000046, 0.000046, 0.000046], + [0.000124, 0.000046, 0.000046, 0.000046, 0.000124], + [0.000046, 0.000046, 0.000046, 0.000124, 0.000124], + [0.000046, 0.000046, 0.000124, 0.000124, 0.000124] + ], + + # Class 2 + [ + [0.999539, 0.999539, 0.999752, 0.999909, 0.999909], + [0.999539, 0.999540, 0.999909, 0.999909, 0.999909], + [0.999752, 0.999909, 0.999909, 0.999909, 0.999753], + [0.999909, 0.999909, 0.999909, 0.999540, 0.999540], + [0.999909, 0.999909, 0.999753, 0.999540, 0.999539] + ], + ], + ], + ], + [ + # Case Description + "1 batche(s), 3 dimension(s), 2 classe(s), 1 channel(s)", + # Parameters + [ + 8.0, # bilateral_weight + 1.0, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.1, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1, # compatability_kernel_range + 2, # iterations + ], + # Input + [ + # Batch 0 + [ + # Class 0 + [ + # Slice 0 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 1 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 2 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 3 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 4 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ], + + # Class 1 + [ + # Slice 0 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 1 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 2 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 3 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + # Slice 4 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + ], + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + # Slice 0 + [ + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 1 + [ + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 2 + [ + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.8, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + ], + # Slice 3 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + ], + # Slice 4 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + ], + ], + ], + ], + # Expected + [ + # Batch 0 + [ + # Class 0 + [ + # Slice 0 + [ + [1.000000, 1.000000, 0.999999, 0.999049, 0.623107], + [1.000000, 1.000000, 0.999998, 0.998911, 0.586533], + [0.999999, 0.999998, 0.999866, 0.988961, 0.531648], + [0.999001, 0.998872, 0.988757, 0.897140, 0.478895], + [0.607341, 0.574104, 0.523379, 0.475603, 0.444802] + ], + # Slice 1 + [ + [1.000000, 1.000000, 0.999998, 0.998867, 0.575748], + [1.000000, 1.000000, 0.999987, 0.982081, 0.125020], + [0.999998, 0.999987, 0.995514, 0.395557, 0.014057], + [0.998802, 0.981356, 0.390734, 0.007831, 0.001451], + [0.557932, 0.118905, 0.013556, 0.001430, 0.001305] + ], + # Slice 2 + [ + [0.999998, 0.999998, 0.999852, 0.987882, 0.509607], + [0.999998, 0.999986, 0.995232, 0.382256, 0.013357], + [0.999845, 0.995113, 0.379344, 0.001853, 0.000170], + [0.987122, 0.372275, 0.001819, 0.000006, 0.000002], + [0.490950, 0.012637, 0.000164, 0.000002, 0.000002] + ], + # Slice 3 + [ + [0.998755, 0.998626, 0.986658, 0.884026, 0.452782], + [0.998602, 0.978651, 0.362107, 0.007108, 0.001340], + [0.986097, 0.356834, 0.001720, 0.000005, 0.000002], + [0.878055, 0.006855, 0.000005, 0.000000, 0.000000], + [0.436683, 0.001282, 0.000002, 0.000000, 0.000000] + ], + # Slice 4 + [ + [0.532362, 0.506320, 0.466688, 0.436673, 0.419582], + [0.501712, 0.099573, 0.011501, 0.001265, 0.001190], + [0.456888, 0.011278, 0.000150, 0.000002, 0.000002], + [0.423857, 0.001228, 0.000002, 0.000000, 0.000000], + [0.405660, 0.001153, 0.000002, 0.000000, 0.000000] + ], + ], + + # Class 1 + [ + # Slice 0 + [ + [0.000000, 0.000000, 0.000001, 0.000951, 0.376893], + [0.000000, 0.000000, 0.000002, 0.001089, 0.413467], + [0.000001, 0.000002, 0.000134, 0.011039, 0.468352], + [0.000999, 0.001128, 0.011243, 0.102860, 0.521105], + [0.392659, 0.425896, 0.476621, 0.524397, 0.555198] + ], + # Slice 1 + [ + [0.000000, 0.000000, 0.000002, 0.001133, 0.424252], + [0.000000, 0.000000, 0.000013, 0.017919, 0.874980], + [0.000002, 0.000013, 0.004486, 0.604443, 0.985943], + [0.001198, 0.018644, 0.609266, 0.992169, 0.998549], + [0.442068, 0.881095, 0.986444, 0.998570, 0.998695] + ], + # Slice 2 + [ + [0.000002, 0.000002, 0.000148, 0.012118, 0.490393], + [0.000002, 0.000014, 0.004769, 0.617744, 0.986643], + [0.000155, 0.004887, 0.620656, 0.998147, 0.999830], + [0.012878, 0.627725, 0.998181, 0.999994, 0.999998], + [0.509050, 0.987363, 0.999836, 0.999998, 0.999998] + ], + # Slice 3 + [ + [0.001245, 0.001374, 0.013342, 0.115974, 0.547218], + [0.001398, 0.021349, 0.637893, 0.992892, 0.998660], + [0.013903, 0.643166, 0.998280, 0.999995, 0.999998], + [0.121945, 0.993145, 0.999995, 1.000000, 1.000000], + [0.563317, 0.998718, 0.999998, 1.000000, 1.000000] + ], + # Slice 4 + [ + [0.467638, 0.493680, 0.533312, 0.563327, 0.580418], + [0.498288, 0.900427, 0.988499, 0.998735, 0.998810], + [0.543112, 0.988722, 0.999850, 0.999998, 0.999998], + [0.576143, 0.998772, 0.999998, 1.000000, 1.000000], + [0.594340, 0.998847, 0.999998, 1.000000, 1.000000] + ], + ], + ], + ], + ], +] + + +@skip_if_no_cpp_extention +class CRFTestCaseCpu(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test(self, test_case_description, params, input, features, expected): + + # Create input tensors + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cpu")) + feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cpu")) + + # apply filter + crf = CRF(*params) + output = crf(input_tensor, feature_tensor).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-4) + + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py new file mode 100644 index 0000000000..1e663b6670 --- /dev/null +++ b/tests/test_crf_cuda.py @@ -0,0 +1,459 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.blocks import CRF +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Description + "2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s)", + # Parameters + [ + 3.0, # bilateral_weight + 1.0, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1, # compatability_kernel_range + 5, # iterations + ], + # Input + [ + # Batch 0 + [ + # Class 0 + [0.8, 0.9, 0.6, 0.2, 0.3], + + # Class 1 + [0.1, 0.3, 0.5, 0.8, 0.7] + ], + # Batch 1 + [ + # Class 0 + [0.8, 0.9, 0.6, 0.2, 0.3], + + # Class 1 + [0.1, 0.3, 0.5, 0.8, 0.7] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 1, 1, 0.5, 0], + ], + # Batch 1 + [ + # Channel 0 + [1, 1, 0.5, 0, 0], + ], + ], + # Expected + [ + # Batch 0 + [ + # Class 0 + [0.89333 , 0.881837, 0.787194, 0.55363 , 0.507627], + + # Class 1 + [0.10667 , 0.118163, 0.212806, 0.44637 , 0.492373] + ], + # Batch 1 + [ + # Class 0 + [0.846356, 0.777572, 0.536503, 0.241165, 0.232537], + + # Class 1 + [0.153644, 0.222428, 0.463497, 0.758835, 0.767463] + ], + ], + ], + [ + # Case Description + "1 batche(s), 2 dimension(s), 3 classe(s), 2 channel(s)", + # Parameters + [ + 3.0, # bilateral_weight + 1.0, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1, # compatability_kernel_range + 5, # iterations + ], + # Input + [ + # Batch 0 + [ + # Class 0 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + + # Class 1 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + + # Class 2 + [ + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + + # Channel 1 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + ], + ], + # Expected + [ + # Batch 0 + [ + # Class 0 + [ + [0.000257, 0.000283, 0.000311, 0.000126, 0.000126], + [0.000376, 0.000458, 0.000216, 0.000257, 0.000264], + [0.000693, 0.000347, 0.001017, 0.002675, 0.015330], + [0.000604, 0.000824, 0.004701, 0.148476, 0.425911], + [0.000987, 0.001379, 0.038858, 0.516412, 0.896450] + ], + + # Class 1 + [ + [0.000702, 0.000664, 0.000230, 0.000081, 0.000080], + [0.000771, 0.000793, 0.000113, 0.000115, 0.000116], + [0.000348, 0.000144, 0.000237, 0.000370, 0.001437], + [0.000177, 0.000207, 0.000491, 0.002529, 0.003492], + [0.000224, 0.000265, 0.002267, 0.003538, 0.002137] + ], + + # Class 2 + [ + [0.999041, 0.999054, 0.999459, 0.999793, 0.999793], + [0.998852, 0.998749, 0.999672, 0.999628, 0.999621], + [0.998959, 0.999509, 0.998746, 0.996955, 0.983234], + [0.999219, 0.998969, 0.994808, 0.848995, 0.570597], + [0.998789, 0.998356, 0.958874, 0.480050, 0.101413] + ], + ], + ], + ], + [ + # Case Description + "1 batche(s), 3 dimension(s), 2 classe(s), 1 channel(s)", + # Parameters + [ + 8.0, # bilateral_weight + 1.0, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.1, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1, # compatability_kernel_range + 2, # iterations + ], + # Input + [ + # Batch 0 + [ + # Class 0 + [ + # Slice 0 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 1 + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 2 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 3 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 4 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ], + + # Class 1 + [ + # Slice 0 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 1 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 2 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 3 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + # Slice 4 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + ], + ], + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + # Slice 0 + [ + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 1 + [ + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + # Slice 2 + [ + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.5, 0.0, 0.0], + [0.5, 0.5, 0.8, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + ], + # Slice 3 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + ], + # Slice 4 + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + ], + ], + ], + ], + # Expected + [ + # Batch 0 + [ + # Class 0 + [ + # Slice 0 + [ + [1.000000, 1.000000, 0.999999, 0.999400, 0.672704], + [1.000000, 1.000000, 0.999999, 0.999311, 0.634444], + [0.999999, 0.999999, 0.999923, 0.992283, 0.570609], + [0.999347, 0.999270, 0.992103, 0.915680, 0.497763], + [0.641561, 0.611654, 0.561265, 0.496694, 0.447266] + ], + # Slice 1 + [ + [1.000000, 1.000000, 0.999999, 0.999297, 0.628315], + [1.000000, 1.000000, 0.999996, 0.991844, 0.143119], + [0.999999, 0.999995, 0.998187, 0.561012, 0.015132], + [0.999211, 0.991375, 0.560541, 0.013006, 0.001419], + [0.584354, 0.132132, 0.015294, 0.001480, 0.001224] + ], + # Slice 2 + [ + [0.999999, 0.999999, 0.999920, 0.991880, 0.556050], + [0.999999, 0.999995, 0.998105, 0.545511, 0.014186], + [0.999915, 0.998068, 0.566432, 0.003078, 0.000158], + [0.990786, 0.537158, 0.003170, 0.000008, 0.000002], + [0.510227, 0.013575, 0.000172, 0.000002, 0.000002] + ], + # Slice 3 + [ + [0.999278, 0.999206, 0.991367, 0.909220, 0.479173], + [0.999193, 0.990973, 0.534392, 0.011364, 0.001236], + [0.990854, 0.535103, 0.002982, 0.000007, 0.000002], + [0.902482, 0.011722, 0.000008, 0.000000, 0.000000], + [0.456381, 0.001320, 0.000002, 0.000000, 0.000000] + ], + # Slice 4 + [ + [0.606149, 0.580932, 0.528721, 0.472177, 0.429050], + [0.576264, 0.124628, 0.013063, 0.001210, 0.001002], + [0.517166, 0.013384, 0.000150, 0.000002, 0.000001], + [0.467132, 0.001316, 0.000002, 0.000000, 0.000000], + [0.432737, 0.001164, 0.000002, 0.000000, 0.000000] + ], + ], + + # Class 1 + [ + # Slice 0 + [ + [0.000000, 0.000000, 0.000001, 0.000600, 0.327296], + [0.000000, 0.000000, 0.000001, 0.000689, 0.365556], + [0.000001, 0.000001, 0.000077, 0.007717, 0.429391], + [0.000653, 0.000729, 0.007897, 0.084320, 0.502237], + [0.358439, 0.388346, 0.438735, 0.503306, 0.552734] + ], + # Slice 1 + [ + [0.000000, 0.000000, 0.000001, 0.000703, 0.371685], + [0.000000, 0.000000, 0.000004, 0.008156, 0.856880], + [0.000001, 0.000005, 0.001814, 0.438988, 0.984868], + [0.000789, 0.008625, 0.439459, 0.986994, 0.998581], + [0.415646, 0.867868, 0.984706, 0.998520, 0.998776] + ], + # Slice 2 + [ + [0.000001, 0.000001, 0.000080, 0.008120, 0.443950], + [0.000001, 0.000005, 0.001895, 0.454489, 0.985814], + [0.000085, 0.001932, 0.433568, 0.996922, 0.999842], + [0.009214, 0.462842, 0.996830, 0.999992, 0.999998], + [0.489773, 0.986425, 0.999828, 0.999998, 0.999998] + ], + # Slice 3 + [ + [0.000722, 0.000794, 0.008633, 0.090780, 0.520827], + [0.000807, 0.009027, 0.465608, 0.988636, 0.998764], + [0.009146, 0.464897, 0.997018, 0.999993, 0.999998], + [0.097518, 0.988278, 0.999992, 1.000000, 1.000000], + [0.543619, 0.998680, 0.999998, 1.000000, 1.000000] + ], + # Slice 4 + [ + [0.393851, 0.419068, 0.471279, 0.527823, 0.570950], + [0.423736, 0.875372, 0.986937, 0.998790, 0.998998], + [0.482834, 0.986616, 0.999850, 0.999998, 0.999999], + [0.532868, 0.998684, 0.999998, 1.000000, 1.000000], + [0.567263, 0.998836, 0.999998, 1.000000, 1.000000] + ], + ], + ], + ], + ], +] + + +@skip_if_no_cpp_extention +class CRFTestCaseCuda(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test(self, test_case_description, params, input, features, expected): + + # Create input tensors + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cuda")) + feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cuda")) + + # apply filter + crf = CRF(*params) + output = crf(input_tensor, feature_tensor).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() From 3bb1db4771748df3b1628cdcc5455daa4327bd32 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Thu, 18 Mar 2021 14:01:49 +0000 Subject: [PATCH 06/16] Removing filter weight normalisation and adding update power factor Signed-off-by: chaliebudd --- monai/networks/blocks/crf.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 9cfce9045b..119b1b7622 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -36,17 +36,19 @@ class CRF(torch.nn.Module): referenece_tensor: the reference tensor used to guide the message passing. - bilateral_weight: the weighting of the bilateral term in the message passing step + bilateral_weight: the weighting of the bilateral term in the message passing step. - gaussian_weight: the weighting of the gaussian term in the message passing step + gaussian_weight: the weighting of the gaussian term in the message passing step. - bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term + bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term. - bilateral_color_sigma: standard deviation in color space for the bilateral term + bilateral_color_sigma: standard deviation in color space for the bilateral term. - gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term + gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term. - compatability_kernel_range: the range of the kernel used in the compatability convolution + update_factor: determines the magnitude of each update. + + compatability_kernel_range: the range of the kernel used in the compatability convolution. iterations: the number of iterations. @@ -61,6 +63,7 @@ def __init__( bilateral_spatial_sigma: [float] = 5.0, bilateral_color_sigma: [float] = 0.5, gaussian_spatial_sigma: [float] = 5.0, + update_factor : [float] = 3.0, compatability_kernel_range: [int] = 1, iterations: [int] = 5, ): @@ -70,6 +73,7 @@ def __init__( self.bilateral_spatial_sigma = bilateral_spatial_sigma self.bilateral_color_sigma = bilateral_color_sigma self.gaussian_spatial_sigma = gaussian_spatial_sigma + self.update_factor = update_factor self.compatability_kernel_range = compatability_kernel_range self.iterations = iterations @@ -112,14 +116,13 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): # combining filter outputs combined_output = self.bilateral_weight * bliateral_output + self.gaussian_weight * gaussian_output - combined_output /= self.bilateral_weight + self.gaussian_weight # compatibility convolution combined_output = pad(combined_output, 2 * spatial_dim * [padding], mode="replicate") compatibility_update = conv(combined_output, compatability_kernel) # update and normalize - output_tensor = softmax(input_tensor - compatibility_update, dim=1) + output_tensor = softmax(input_tensor - self.update_factor * compatibility_update, dim=1) return output_tensor From b03647a7cb51a64c0afbd6e8a9ffdbc0723cc855 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Thu, 18 Mar 2021 14:40:48 +0000 Subject: [PATCH 07/16] updating unit tests Signed-off-by: chaliebudd --- monai/networks/blocks/crf.py | 2 +- tests/test_crf_cpu.py | 200 ++++++++++++++++----------------- tests/test_crf_cuda.py | 209 +++++++++++++++++------------------ 3 files changed, 196 insertions(+), 215 deletions(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 119b1b7622..48ed265495 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -63,7 +63,7 @@ def __init__( bilateral_spatial_sigma: [float] = 5.0, bilateral_color_sigma: [float] = 0.5, gaussian_spatial_sigma: [float] = 5.0, - update_factor : [float] = 3.0, + update_factor: [float] = 3.0, compatability_kernel_range: [int] = 1, iterations: [int] = 5, ): diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py index 762c060a80..660278167c 100644 --- a/tests/test_crf_cpu.py +++ b/tests/test_crf_cpu.py @@ -24,13 +24,14 @@ "2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s)", # Parameters [ - 3.0, # bilateral_weight - 1.0, # gaussian_weight - 5.0, # bilateral_spatial_sigma - 0.5, # bilateral_color_sigma - 5.0, # gaussian_spatial_sigma - 1, # compatability_kernel_range - 5, # iterations + 1.0, # bilateral_weight + 0.3, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1.0, # update_factor + 1, # compatability_kernel_range + 5, # iterations ], # Input [ @@ -38,17 +39,15 @@ [ # Class 0 [0.8, 0.9, 0.6, 0.2, 0.3], - # Class 1 - [0.1, 0.3, 0.5, 0.8, 0.7] + [0.1, 0.3, 0.5, 0.8, 0.7], ], # Batch 1 [ # Class 0 [0.8, 0.9, 0.6, 0.2, 0.3], - # Class 1 - [0.1, 0.3, 0.5, 0.8, 0.7] + [0.1, 0.3, 0.5, 0.8, 0.7], ], ], # Features @@ -69,18 +68,16 @@ # Batch 0 [ # Class 0 - [0.911299, 0.901745, 0.828803, 0.644988, 0.627997], - + [0.976472, 0.973789, 0.951958, 0.882982, 0.876651], # Class 1 - [0.088701, 0.098255, 0.171197, 0.355012, 0.372003], + [0.023528, 0.026211, 0.048042, 0.117018, 0.123349], ], # Batch 1 [ # Class 0 - [0.902577, 0.870105, 0.725821, 0.465292, 0.466282], - + [0.963642, 0.946892, 0.858650, 0.633639, 0.617334], # Class 1 - [0.097423, 0.129895, 0.274179, 0.534708, 0.533718] + [0.036358, 0.053108, 0.141350, 0.366361, 0.382666], ], ], ], @@ -89,13 +86,14 @@ "1 batche(s), 2 dimension(s), 3 classe(s), 2 channel(s)", # Parameters [ - 3.0, # bilateral_weight - 1.0, # gaussian_weight - 5.0, # bilateral_spatial_sigma - 0.5, # bilateral_color_sigma - 5.0, # gaussian_spatial_sigma - 1, # compatability_kernel_range - 5, # iterations + 1.0, # bilateral_weight + 0.3, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1.0, # update_factor + 1, # compatability_kernel_range + 5, # iterations ], # Input [ @@ -109,7 +107,6 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - # Class 1 [ [1.0, 1.0, 0.0, 0.0, 0.0], @@ -118,7 +115,6 @@ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], ], - # Class 2 [ [0.0, 0.0, 0.0, 1.0, 1.0], @@ -141,7 +137,6 @@ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], ], - # Channel 1 [ [0.0, 0.0, 0.0, 0.0, 0.0], @@ -158,29 +153,27 @@ [ # Class 0 [ - [0.000124, 0.000124, 0.000124, 0.000045, 0.000045], - [0.000124, 0.000124, 0.000046, 0.000045, 0.000045], - [0.000124, 0.000046, 0.000046, 0.000046, 0.000124], - [0.000046, 0.000046, 0.000046, 0.000336, 0.000337], - [0.000046, 0.000046, 0.000124, 0.000337, 0.000337] + [0.000008, 0.000008, 0.000008, 0.000003, 0.000003], + [0.000008, 0.000008, 0.000003, 0.000003, 0.000003], + [0.000008, 0.000003, 0.000003, 0.000003, 0.000008], + [0.000003, 0.000003, 0.000003, 0.000023, 0.000023], + [0.000003, 0.000003, 0.000008, 0.000023, 0.000023], ], - # Class 1 [ - [0.000337, 0.000337, 0.000124, 0.000046, 0.000046], - [0.000337, 0.000337, 0.000046, 0.000046, 0.000046], - [0.000124, 0.000046, 0.000046, 0.000046, 0.000124], - [0.000046, 0.000046, 0.000046, 0.000124, 0.000124], - [0.000046, 0.000046, 0.000124, 0.000124, 0.000124] + [0.000023, 0.000023, 0.000008, 0.000003, 0.000003], + [0.000023, 0.000023, 0.000003, 0.000003, 0.000003], + [0.000008, 0.000003, 0.000003, 0.000003, 0.000008], + [0.000003, 0.000003, 0.000003, 0.000008, 0.000008], + [0.000003, 0.000003, 0.000008, 0.000008, 0.000008], ], - # Class 2 [ - [0.999539, 0.999539, 0.999752, 0.999909, 0.999909], - [0.999539, 0.999540, 0.999909, 0.999909, 0.999909], - [0.999752, 0.999909, 0.999909, 0.999909, 0.999753], - [0.999909, 0.999909, 0.999909, 0.999540, 0.999540], - [0.999909, 0.999909, 0.999753, 0.999540, 0.999539] + [0.999969, 0.999969, 0.999983, 0.999994, 0.999994], + [0.999969, 0.999969, 0.999994, 0.999994, 0.999994], + [0.999983, 0.999994, 0.999994, 0.999994, 0.999983], + [0.999994, 0.999994, 0.999994, 0.999969, 0.999969], + [0.999994, 0.999994, 0.999983, 0.999969, 0.999969], ], ], ], @@ -190,13 +183,14 @@ "1 batche(s), 3 dimension(s), 2 classe(s), 1 channel(s)", # Parameters [ - 8.0, # bilateral_weight - 1.0, # gaussian_weight - 5.0, # bilateral_spatial_sigma - 0.1, # bilateral_color_sigma - 5.0, # gaussian_spatial_sigma - 1, # compatability_kernel_range - 2, # iterations + 1.0, # bilateral_weight + 0.3, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.1, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1.0, # update_factor + 1, # compatability_kernel_range + 2, # iterations ], # Input [ @@ -245,7 +239,6 @@ [0.0, 0.0, 0.0, 0.0, 0.0], ], ], - # Class 1 [ # Slice 0 @@ -348,87 +341,86 @@ [ # Slice 0 [ - [1.000000, 1.000000, 0.999999, 0.999049, 0.623107], - [1.000000, 1.000000, 0.999998, 0.998911, 0.586533], - [0.999999, 0.999998, 0.999866, 0.988961, 0.531648], - [0.999001, 0.998872, 0.988757, 0.897140, 0.478895], - [0.607341, 0.574104, 0.523379, 0.475603, 0.444802] + [1.000000, 1.000000, 1.000000, 0.999808, 0.721122], + [1.000000, 1.000000, 1.000000, 0.999758, 0.666025], + [1.000000, 1.000000, 0.999979, 0.995894, 0.577459], + [0.999787, 0.999739, 0.995725, 0.934170, 0.488704], + [0.691645, 0.641028, 0.560127, 0.481718, 0.431180], ], # Slice 1 [ - [1.000000, 1.000000, 0.999998, 0.998867, 0.575748], - [1.000000, 1.000000, 0.999987, 0.982081, 0.125020], - [0.999998, 0.999987, 0.995514, 0.395557, 0.014057], - [0.998802, 0.981356, 0.390734, 0.007831, 0.001451], - [0.557932, 0.118905, 0.013556, 0.001430, 0.001305] + [1.000000, 1.000000, 1.000000, 0.999743, 0.650416], + [1.000000, 1.000000, 0.999999, 0.992747, 0.108034], + [1.000000, 0.999999, 0.998541, 0.402370, 0.007122], + [0.999711, 0.992109, 0.391941, 0.003358, 0.000440], + [0.615523, 0.097120, 0.006599, 0.000427, 0.000365], ], # Slice 2 [ - [0.999998, 0.999998, 0.999852, 0.987882, 0.509607], - [0.999998, 0.999986, 0.995232, 0.382256, 0.013357], - [0.999845, 0.995113, 0.379344, 0.001853, 0.000170], - [0.987122, 0.372275, 0.001819, 0.000006, 0.000002], - [0.490950, 0.012637, 0.000164, 0.000002, 0.000002] + [1.000000, 1.000000, 0.999975, 0.995241, 0.543122], + [1.000000, 0.999998, 0.998394, 0.381981, 0.006586], + [0.999973, 0.998313, 0.370238, 0.000596, 0.000034], + [0.994611, 0.361317, 0.000573, 0.000001, 0.000000], + [0.505392, 0.005862, 0.000032, 0.000000, 0.000000], ], # Slice 3 [ - [0.998755, 0.998626, 0.986658, 0.884026, 0.452782], - [0.998602, 0.978651, 0.362107, 0.007108, 0.001340], - [0.986097, 0.356834, 0.001720, 0.000005, 0.000002], - [0.878055, 0.006855, 0.000005, 0.000000, 0.000000], - [0.436683, 0.001282, 0.000002, 0.000000, 0.000000] + [0.999692, 0.999639, 0.994364, 0.919713, 0.446683], + [0.999626, 0.990123, 0.347190, 0.002895, 0.000390], + [0.993872, 0.336665, 0.000525, 0.000001, 0.000000], + [0.910704, 0.002676, 0.000001, 0.000000, 0.000000], + [0.413964, 0.000354, 0.000000, 0.000000, 0.000000], ], # Slice 4 [ - [0.532362, 0.506320, 0.466688, 0.436673, 0.419582], - [0.501712, 0.099573, 0.011501, 0.001265, 0.001190], - [0.456888, 0.011278, 0.000150, 0.000002, 0.000002], - [0.423857, 0.001228, 0.000002, 0.000000, 0.000000], - [0.405660, 0.001153, 0.000002, 0.000000, 0.000000] + [0.574496, 0.533306, 0.469335, 0.419446, 0.390403], + [0.524093, 0.072180, 0.005087, 0.000354, 0.000318], + [0.449362, 0.004875, 0.000028, 0.000000, 0.000000], + [0.393431, 0.000331, 0.000000, 0.000000, 0.000000], + [0.362417, 0.000295, 0.000000, 0.000000, 0.000000], ], ], - # Class 1 [ # Slice 0 [ - [0.000000, 0.000000, 0.000001, 0.000951, 0.376893], - [0.000000, 0.000000, 0.000002, 0.001089, 0.413467], - [0.000001, 0.000002, 0.000134, 0.011039, 0.468352], - [0.000999, 0.001128, 0.011243, 0.102860, 0.521105], - [0.392659, 0.425896, 0.476621, 0.524397, 0.555198] + [0.000000, 0.000000, 0.000000, 0.000192, 0.278878], + [0.000000, 0.000000, 0.000000, 0.000242, 0.333975], + [0.000000, 0.000000, 0.000021, 0.004106, 0.422541], + [0.000213, 0.000261, 0.004275, 0.065830, 0.511296], + [0.308355, 0.358972, 0.439873, 0.518282, 0.568820], ], # Slice 1 [ - [0.000000, 0.000000, 0.000002, 0.001133, 0.424252], - [0.000000, 0.000000, 0.000013, 0.017919, 0.874980], - [0.000002, 0.000013, 0.004486, 0.604443, 0.985943], - [0.001198, 0.018644, 0.609266, 0.992169, 0.998549], - [0.442068, 0.881095, 0.986444, 0.998570, 0.998695] + [0.000000, 0.000000, 0.000000, 0.000257, 0.349584], + [0.000000, 0.000000, 0.000001, 0.007253, 0.891966], + [0.000000, 0.000001, 0.001459, 0.597630, 0.992878], + [0.000289, 0.007891, 0.608059, 0.996642, 0.999560], + [0.384477, 0.902880, 0.993401, 0.999573, 0.999635], ], # Slice 2 [ - [0.000002, 0.000002, 0.000148, 0.012118, 0.490393], - [0.000002, 0.000014, 0.004769, 0.617744, 0.986643], - [0.000155, 0.004887, 0.620656, 0.998147, 0.999830], - [0.012878, 0.627725, 0.998181, 0.999994, 0.999998], - [0.509050, 0.987363, 0.999836, 0.999998, 0.999998] + [0.000000, 0.000000, 0.000025, 0.004759, 0.456878], + [0.000000, 0.000002, 0.001606, 0.618019, 0.993414], + [0.000027, 0.001687, 0.629762, 0.999404, 0.999966], + [0.005389, 0.638683, 0.999427, 0.999999, 1.000000], + [0.494608, 0.994138, 0.999968, 1.000000, 1.000000], ], # Slice 3 [ - [0.001245, 0.001374, 0.013342, 0.115974, 0.547218], - [0.001398, 0.021349, 0.637893, 0.992892, 0.998660], - [0.013903, 0.643166, 0.998280, 0.999995, 0.999998], - [0.121945, 0.993145, 0.999995, 1.000000, 1.000000], - [0.563317, 0.998718, 0.999998, 1.000000, 1.000000] + [0.000308, 0.000361, 0.005636, 0.080287, 0.553317], + [0.000374, 0.009877, 0.652810, 0.997105, 0.999610], + [0.006128, 0.663335, 0.999475, 0.999999, 1.000000], + [0.089296, 0.997324, 0.999999, 1.000000, 1.000000], + [0.586036, 0.999646, 1.000000, 1.000000, 1.000000], ], # Slice 4 [ - [0.467638, 0.493680, 0.533312, 0.563327, 0.580418], - [0.498288, 0.900427, 0.988499, 0.998735, 0.998810], - [0.543112, 0.988722, 0.999850, 0.999998, 0.999998], - [0.576143, 0.998772, 0.999998, 1.000000, 1.000000], - [0.594340, 0.998847, 0.999998, 1.000000, 1.000000] + [0.425504, 0.466694, 0.530665, 0.580554, 0.609597], + [0.475907, 0.927820, 0.994913, 0.999646, 0.999682], + [0.550638, 0.995125, 0.999972, 1.000000, 1.000000], + [0.606569, 0.999669, 1.000000, 1.000000, 1.000000], + [0.637583, 0.999705, 1.000000, 1.000000, 1.000000], ], ], ], @@ -439,7 +431,6 @@ @skip_if_no_cpp_extention class CRFTestCaseCpu(unittest.TestCase): - @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): @@ -453,7 +444,6 @@ def test(self, test_case_description, params, input, features, expected): # Ensure result are as expected np.testing.assert_allclose(output, expected, atol=1e-4) - if __name__ == "__main__": diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py index 1e663b6670..deca250aea 100644 --- a/tests/test_crf_cuda.py +++ b/tests/test_crf_cuda.py @@ -24,13 +24,14 @@ "2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s)", # Parameters [ - 3.0, # bilateral_weight - 1.0, # gaussian_weight - 5.0, # bilateral_spatial_sigma - 0.5, # bilateral_color_sigma - 5.0, # gaussian_spatial_sigma - 1, # compatability_kernel_range - 5, # iterations + 1.0, # bilateral_weight + 0.3, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1.0, # update_factor + 1, # compatability_kernel_range + 5, # iterations ], # Input [ @@ -38,17 +39,15 @@ [ # Class 0 [0.8, 0.9, 0.6, 0.2, 0.3], - # Class 1 - [0.1, 0.3, 0.5, 0.8, 0.7] + [0.1, 0.3, 0.5, 0.8, 0.7], ], # Batch 1 [ # Class 0 [0.8, 0.9, 0.6, 0.2, 0.3], - # Class 1 - [0.1, 0.3, 0.5, 0.8, 0.7] + [0.1, 0.3, 0.5, 0.8, 0.7], ], ], # Features @@ -69,18 +68,16 @@ # Batch 0 [ # Class 0 - [0.89333 , 0.881837, 0.787194, 0.55363 , 0.507627], - + [0.965345, 0.961201, 0.920527, 0.772525, 0.711900], # Class 1 - [0.10667 , 0.118163, 0.212806, 0.44637 , 0.492373] + [0.034655, 0.038799, 0.079473, 0.227475, 0.288100], ], # Batch 1 [ # Class 0 - [0.846356, 0.777572, 0.536503, 0.241165, 0.232537], - + [0.897615, 0.816166, 0.500186, 0.158644, 0.133245], # Class 1 - [0.153644, 0.222428, 0.463497, 0.758835, 0.767463] + [0.102385, 0.183834, 0.499814, 0.841356, 0.866755], ], ], ], @@ -89,13 +86,14 @@ "1 batche(s), 2 dimension(s), 3 classe(s), 2 channel(s)", # Parameters [ - 3.0, # bilateral_weight - 1.0, # gaussian_weight - 5.0, # bilateral_spatial_sigma - 0.5, # bilateral_color_sigma - 5.0, # gaussian_spatial_sigma - 1, # compatability_kernel_range - 5, # iterations + 1.0, # bilateral_weight + 0.3, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.5, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1.0, # update_factor + 1, # compatability_kernel_range + 5, # iterations ], # Input [ @@ -109,7 +107,6 @@ [0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0], ], - # Class 1 [ [1.0, 1.0, 0.0, 0.0, 0.0], @@ -118,14 +115,13 @@ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], ], - # Class 2 [ - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 1.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.5, 1.0], + [0.0, 0.0, 0.5, 1.0, 0.5], + [0.0, 0.5, 1.0, 0.5, 0.0], + [0.5, 1.0, 0.5, 0.0, 0.0], + [1.0, 0.5, 0.0, 0.0, 0.0], ], ], ], @@ -141,7 +137,6 @@ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], ], - # Channel 1 [ [0.0, 0.0, 0.0, 0.0, 0.0], @@ -158,29 +153,27 @@ [ # Class 0 [ - [0.000257, 0.000283, 0.000311, 0.000126, 0.000126], - [0.000376, 0.000458, 0.000216, 0.000257, 0.000264], - [0.000693, 0.000347, 0.001017, 0.002675, 0.015330], - [0.000604, 0.000824, 0.004701, 0.148476, 0.425911], - [0.000987, 0.001379, 0.038858, 0.516412, 0.896450] + [0.001529, 0.000798, 0.000323, 0.000093, 0.000053], + [0.001365, 0.000966, 0.000422, 0.000178, 0.000281], + [0.001405, 0.001007, 0.002425, 0.013078, 0.064707], + [0.001239, 0.001263, 0.033857, 0.665830, 0.951172], + [0.001534, 0.004486, 0.263298, 0.973852, 0.999018], ], - # Class 1 [ - [0.000702, 0.000664, 0.000230, 0.000081, 0.000080], - [0.000771, 0.000793, 0.000113, 0.000115, 0.000116], - [0.000348, 0.000144, 0.000237, 0.000370, 0.001437], - [0.000177, 0.000207, 0.000491, 0.002529, 0.003492], - [0.000224, 0.000265, 0.002267, 0.003538, 0.002137] + [0.230989, 0.025518, 0.000764, 0.000057, 0.000029], + [0.037540, 0.008348, 0.000381, 0.000055, 0.000075], + [0.001987, 0.000665, 0.000363, 0.000499, 0.001170], + [0.000187, 0.000143, 0.000805, 0.001361, 0.000533], + [0.000131, 0.000286, 0.002139, 0.000410, 0.000069], ], - # Class 2 [ - [0.999041, 0.999054, 0.999459, 0.999793, 0.999793], - [0.998852, 0.998749, 0.999672, 0.999628, 0.999621], - [0.998959, 0.999509, 0.998746, 0.996955, 0.983234], - [0.999219, 0.998969, 0.994808, 0.848995, 0.570597], - [0.998789, 0.998356, 0.958874, 0.480050, 0.101413] + [0.767482, 0.973685, 0.998913, 0.999850, 0.999919], + [0.961095, 0.990687, 0.999197, 0.999768, 0.999644], + [0.996608, 0.998328, 0.997212, 0.986423, 0.934124], + [0.998574, 0.998594, 0.965337, 0.332809, 0.048295], + [0.998334, 0.995228, 0.734563, 0.025738, 0.000912], ], ], ], @@ -190,13 +183,14 @@ "1 batche(s), 3 dimension(s), 2 classe(s), 1 channel(s)", # Parameters [ - 8.0, # bilateral_weight - 1.0, # gaussian_weight - 5.0, # bilateral_spatial_sigma - 0.1, # bilateral_color_sigma - 5.0, # gaussian_spatial_sigma - 1, # compatability_kernel_range - 2, # iterations + 1.0, # bilateral_weight + 0.3, # gaussian_weight + 5.0, # bilateral_spatial_sigma + 0.1, # bilateral_color_sigma + 5.0, # gaussian_spatial_sigma + 1.0, # update_factor + 1, # compatability_kernel_range + 2, # iterations ], # Input [ @@ -245,7 +239,6 @@ [0.0, 0.0, 0.0, 0.0, 0.0], ], ], - # Class 1 [ # Slice 0 @@ -348,87 +341,86 @@ [ # Slice 0 [ - [1.000000, 1.000000, 0.999999, 0.999400, 0.672704], - [1.000000, 1.000000, 0.999999, 0.999311, 0.634444], - [0.999999, 0.999999, 0.999923, 0.992283, 0.570609], - [0.999347, 0.999270, 0.992103, 0.915680, 0.497763], - [0.641561, 0.611654, 0.561265, 0.496694, 0.447266] + [1.000000, 1.000000, 1.000000, 0.999884, 0.769625], + [1.000000, 1.000000, 1.000000, 0.999851, 0.714004], + [1.000000, 1.000000, 0.999988, 0.997150, 0.614165], + [0.999862, 0.999832, 0.996976, 0.945058, 0.497088], + [0.720345, 0.672450, 0.590360, 0.490120, 0.416671], ], # Slice 1 [ - [1.000000, 1.000000, 0.999999, 0.999297, 0.628315], - [1.000000, 1.000000, 0.999996, 0.991844, 0.143119], - [0.999999, 0.999995, 0.998187, 0.561012, 0.015132], - [0.999211, 0.991375, 0.560541, 0.013006, 0.001419], - [0.584354, 0.132132, 0.015294, 0.001480, 0.001224] + [1.000000, 1.000000, 1.000000, 0.999848, 0.707997], + [1.000000, 1.000000, 1.000000, 0.997064, 0.127893], + [1.000000, 1.000000, 0.999469, 0.591574, 0.007791], + [0.999812, 0.996663, 0.582521, 0.006041, 0.000427], + [0.637809, 0.107586, 0.007432, 0.000437, 0.000333], ], # Slice 2 [ - [0.999999, 0.999999, 0.999920, 0.991880, 0.556050], - [0.999999, 0.999995, 0.998105, 0.545511, 0.014186], - [0.999915, 0.998068, 0.566432, 0.003078, 0.000158], - [0.990786, 0.537158, 0.003170, 0.000008, 0.000002], - [0.510227, 0.013575, 0.000172, 0.000002, 0.000002] + [1.000000, 1.000000, 0.999987, 0.996994, 0.600095], + [1.000000, 1.000000, 0.999441, 0.575839, 0.007303], + [0.999986, 0.999411, 0.587268, 0.001117, 0.000033], + [0.996210, 0.550023, 0.001114, 0.000001, 0.000000], + [0.520757, 0.006334, 0.000034, 0.000000, 0.000000], ], # Slice 3 [ - [0.999278, 0.999206, 0.991367, 0.909220, 0.479173], - [0.999193, 0.990973, 0.534392, 0.011364, 0.001236], - [0.990854, 0.535103, 0.002982, 0.000007, 0.000002], - [0.902482, 0.011722, 0.000008, 0.000000, 0.000000], - [0.456381, 0.001320, 0.000002, 0.000000, 0.000000] + [0.999834, 0.999807, 0.996617, 0.940887, 0.482334], + [0.999799, 0.996410, 0.553696, 0.005287, 0.000376], + [0.996193, 0.546801, 0.001047, 0.000001, 0.000000], + [0.930515, 0.005142, 0.000001, 0.000000, 0.000000], + [0.430705, 0.000371, 0.000000, 0.000000, 0.000000], ], # Slice 4 [ - [0.606149, 0.580932, 0.528721, 0.472177, 0.429050], - [0.576264, 0.124628, 0.013063, 0.001210, 0.001002], - [0.517166, 0.013384, 0.000150, 0.000002, 0.000001], - [0.467132, 0.001316, 0.000002, 0.000000, 0.000000], - [0.432737, 0.001164, 0.000002, 0.000000, 0.000000] + [0.665227, 0.627316, 0.550517, 0.467839, 0.406319], + [0.617408, 0.098325, 0.006247, 0.000359, 0.000278], + [0.524800, 0.006229, 0.000030, 0.000000, 0.000000], + [0.443054, 0.000372, 0.000000, 0.000000, 0.000000], + [0.388126, 0.000305, 0.000000, 0.000000, 0.000000], ], ], - # Class 1 [ # Slice 0 [ - [0.000000, 0.000000, 0.000001, 0.000600, 0.327296], - [0.000000, 0.000000, 0.000001, 0.000689, 0.365556], - [0.000001, 0.000001, 0.000077, 0.007717, 0.429391], - [0.000653, 0.000729, 0.007897, 0.084320, 0.502237], - [0.358439, 0.388346, 0.438735, 0.503306, 0.552734] + [0.000000, 0.000000, 0.000000, 0.000116, 0.230375], + [0.000000, 0.000000, 0.000000, 0.000149, 0.285996], + [0.000000, 0.000000, 0.000012, 0.002850, 0.385835], + [0.000138, 0.000168, 0.003024, 0.054942, 0.502912], + [0.279655, 0.327550, 0.409640, 0.509880, 0.583329], ], # Slice 1 [ - [0.000000, 0.000000, 0.000001, 0.000703, 0.371685], - [0.000000, 0.000000, 0.000004, 0.008156, 0.856880], - [0.000001, 0.000005, 0.001814, 0.438988, 0.984868], - [0.000789, 0.008625, 0.439459, 0.986994, 0.998581], - [0.415646, 0.867868, 0.984706, 0.998520, 0.998776] + [0.000000, 0.000000, 0.000000, 0.000152, 0.292003], + [0.000000, 0.000000, 0.000000, 0.002936, 0.872107], + [0.000000, 0.000000, 0.000531, 0.408426, 0.992209], + [0.000188, 0.003337, 0.417479, 0.993959, 0.999574], + [0.362191, 0.892414, 0.992568, 0.999564, 0.999667], ], # Slice 2 [ - [0.000001, 0.000001, 0.000080, 0.008120, 0.443950], - [0.000001, 0.000005, 0.001895, 0.454489, 0.985814], - [0.000085, 0.001932, 0.433568, 0.996922, 0.999842], - [0.009214, 0.462842, 0.996830, 0.999992, 0.999998], - [0.489773, 0.986425, 0.999828, 0.999998, 0.999998] + [0.000000, 0.000000, 0.000013, 0.003006, 0.399905], + [0.000000, 0.000000, 0.000559, 0.424161, 0.992697], + [0.000014, 0.000589, 0.412732, 0.998884, 0.999967], + [0.003790, 0.449977, 0.998886, 0.999999, 1.000000], + [0.479243, 0.993666, 0.999966, 1.000000, 1.000000], ], # Slice 3 [ - [0.000722, 0.000794, 0.008633, 0.090780, 0.520827], - [0.000807, 0.009027, 0.465608, 0.988636, 0.998764], - [0.009146, 0.464897, 0.997018, 0.999993, 0.999998], - [0.097518, 0.988278, 0.999992, 1.000000, 1.000000], - [0.543619, 0.998680, 0.999998, 1.000000, 1.000000] + [0.000166, 0.000193, 0.003383, 0.059113, 0.517666], + [0.000201, 0.003590, 0.446304, 0.994713, 0.999624], + [0.003807, 0.453199, 0.998953, 0.999999, 1.000000], + [0.069485, 0.994858, 0.999999, 1.000000, 1.000000], + [0.569295, 0.999629, 1.000000, 1.000000, 1.000000], ], # Slice 4 [ - [0.393851, 0.419068, 0.471279, 0.527823, 0.570950], - [0.423736, 0.875372, 0.986937, 0.998790, 0.998998], - [0.482834, 0.986616, 0.999850, 0.999998, 0.999999], - [0.532868, 0.998684, 0.999998, 1.000000, 1.000000], - [0.567263, 0.998836, 0.999998, 1.000000, 1.000000] + [0.334773, 0.372684, 0.449483, 0.532161, 0.593681], + [0.382592, 0.901675, 0.993753, 0.999641, 0.999722], + [0.475200, 0.993771, 0.999970, 1.000000, 1.000000], + [0.556946, 0.999628, 1.000000, 1.000000, 1.000000], + [0.611874, 0.999695, 1.000000, 1.000000, 1.000000], ], ], ], @@ -439,7 +431,6 @@ @skip_if_no_cpp_extention class CRFTestCaseCuda(unittest.TestCase): - @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): From 0ecc530ff15b20dc49270a7268e8b9925d1555d2 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Thu, 18 Mar 2021 14:53:44 +0000 Subject: [PATCH 08/16] skipping cuda test if no cuda Signed-off-by: chaliebudd --- tests/test_crf_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py index deca250aea..1b41bcc0e5 100644 --- a/tests/test_crf_cuda.py +++ b/tests/test_crf_cuda.py @@ -430,6 +430,7 @@ @skip_if_no_cpp_extention +@skip_if_no_cuda class CRFTestCaseCuda(unittest.TestCase): @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): From bfde89d7b0182f799b06615e7a229e3d8b5ad418 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Thu, 18 Mar 2021 15:07:39 +0000 Subject: [PATCH 09/16] adding crf to docs Signed-off-by: chaliebudd --- docs/source/networks.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index cf383d2908..fc644fa8ff 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -20,6 +20,11 @@ Blocks .. autoclass:: Convolution :members: +`CRF` +~~~~~~~~~~~~~ +.. autoclass:: CRF + :members: + `ResidualUnit` ~~~~~~~~~~~~~~ .. autoclass:: ResidualUnit From 20de01d1c274994491b4b708cb0a951a82d1d955 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Wed, 24 Mar 2021 09:21:57 +0000 Subject: [PATCH 10/16] fixing typo Signed-off-by: chaliebudd --- tests/test_crf_cpu.py | 4 ++-- tests/test_crf_cuda.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_crf_cpu.py b/tests/test_crf_cpu.py index 660278167c..f6e82d16a5 100644 --- a/tests/test_crf_cpu.py +++ b/tests/test_crf_cpu.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.networks.blocks import CRF -from tests.utils import skip_if_no_cpp_extention +from tests.utils import skip_if_no_cpp_extension TEST_CASES = [ [ @@ -429,7 +429,7 @@ ] -@skip_if_no_cpp_extention +@skip_if_no_cpp_extension class CRFTestCaseCpu(unittest.TestCase): @parameterized.expand(TEST_CASES) def test(self, test_case_description, params, input, features, expected): diff --git a/tests/test_crf_cuda.py b/tests/test_crf_cuda.py index 1b41bcc0e5..4decd433fa 100644 --- a/tests/test_crf_cuda.py +++ b/tests/test_crf_cuda.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.networks.blocks import CRF -from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda +from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda TEST_CASES = [ [ @@ -429,7 +429,7 @@ ] -@skip_if_no_cpp_extention +@skip_if_no_cpp_extension @skip_if_no_cuda class CRFTestCaseCuda(unittest.TestCase): @parameterized.expand(TEST_CASES) From 6c1d21c79c0d8dac9b96d23b79caca943b0ad079 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Wed, 24 Mar 2021 09:38:42 +0000 Subject: [PATCH 11/16] fixing type declaration syntax Signed-off-by: chaliebudd --- monai/networks/blocks/crf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 48ed265495..8bd056f00e 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -58,14 +58,14 @@ class CRF(torch.nn.Module): def __init__( self, - bilateral_weight: [float] = 1.0, - gaussian_weight: [float] = 1.0, - bilateral_spatial_sigma: [float] = 5.0, - bilateral_color_sigma: [float] = 0.5, - gaussian_spatial_sigma: [float] = 5.0, - update_factor: [float] = 3.0, - compatability_kernel_range: [int] = 1, - iterations: [int] = 5, + bilateral_weight: float = 1.0, + gaussian_weight: float = 1.0, + bilateral_spatial_sigma: float = 5.0, + bilateral_color_sigma: float = 0.5, + gaussian_spatial_sigma: float = 5.0, + update_factor: float = 3.0, + compatability_kernel_range: int = 1, + iterations: int = 5, ): super(CRF, self).__init__() self.bilateral_weight = bilateral_weight From 749b1cf3057975904244b5093900b5d85fc07330 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Wed, 24 Mar 2021 09:41:29 +0000 Subject: [PATCH 12/16] correcting docstrings Signed-off-by: chaliebudd --- monai/networks/blocks/crf.py | 42 +++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 8bd056f00e..5772c66a46 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -30,12 +30,21 @@ class CRF(torch.nn.Module): See: https://arxiv.org/abs/1502.03240 + """ - Args: - input_tensor: tensor containing initial class logits. - - referenece_tensor: the reference tensor used to guide the message passing. - + def __init__( + self, + bilateral_weight: float = 1.0, + gaussian_weight: float = 1.0, + bilateral_spatial_sigma: float = 5.0, + bilateral_color_sigma: float = 0.5, + gaussian_spatial_sigma: float = 5.0, + update_factor: float = 3.0, + compatability_kernel_range: int = 1, + iterations: int = 5, + ): + """ + Args: bilateral_weight: the weighting of the bilateral term in the message passing step. gaussian_weight: the weighting of the gaussian term in the message passing step. @@ -51,22 +60,7 @@ class CRF(torch.nn.Module): compatability_kernel_range: the range of the kernel used in the compatability convolution. iterations: the number of iterations. - - Returns: - output (torch.Tensor): output tensor. """ - - def __init__( - self, - bilateral_weight: float = 1.0, - gaussian_weight: float = 1.0, - bilateral_spatial_sigma: float = 5.0, - bilateral_color_sigma: float = 0.5, - gaussian_spatial_sigma: float = 5.0, - update_factor: float = 3.0, - compatability_kernel_range: int = 1, - iterations: int = 5, - ): super(CRF, self).__init__() self.bilateral_weight = bilateral_weight self.gaussian_weight = gaussian_weight @@ -78,6 +72,14 @@ def __init__( self.iterations = iterations def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): + """ + Args: + input_tensor: tensor containing initial class logits. + referenece_tensor: the reference tensor used to guide the message passing. + + Returns: + output (torch.Tensor): output tensor. + """ # useful values spatial_dim = input_tensor.dim() - 2 From 86567bb4cc7513a969eacdbf6757db8c0b8fcc3c Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Wed, 24 Mar 2021 09:42:06 +0000 Subject: [PATCH 13/16] removing whitespace in docstring Signed-off-by: chaliebudd --- monai/networks/blocks/crf.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 5772c66a46..619e1751ef 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -46,19 +46,12 @@ def __init__( """ Args: bilateral_weight: the weighting of the bilateral term in the message passing step. - gaussian_weight: the weighting of the gaussian term in the message passing step. - bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term. - bilateral_color_sigma: standard deviation in color space for the bilateral term. - gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term. - update_factor: determines the magnitude of each update. - compatability_kernel_range: the range of the kernel used in the compatability convolution. - iterations: the number of iterations. """ super(CRF, self).__init__() From 1efdd39feee5b594fdbc7eecf02117b4383f8470 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Wed, 24 Mar 2021 10:01:01 +0000 Subject: [PATCH 14/16] fixed docstring indent Signed-off-by: chaliebudd --- monai/networks/blocks/crf.py | 38 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 619e1751ef..27556a2c72 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -43,17 +43,17 @@ def __init__( compatability_kernel_range: int = 1, iterations: int = 5, ): - """ - Args: - bilateral_weight: the weighting of the bilateral term in the message passing step. - gaussian_weight: the weighting of the gaussian term in the message passing step. - bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term. - bilateral_color_sigma: standard deviation in color space for the bilateral term. - gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term. - update_factor: determines the magnitude of each update. - compatability_kernel_range: the range of the kernel used in the compatability convolution. - iterations: the number of iterations. - """ + """ + Args: + bilateral_weight: the weighting of the bilateral term in the message passing step. + gaussian_weight: the weighting of the gaussian term in the message passing step. + bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term. + bilateral_color_sigma: standard deviation in color space for the bilateral term. + gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term. + update_factor: determines the magnitude of each update. + compatability_kernel_range: the range of the kernel used in the compatability convolution. + iterations: the number of iterations. + """ super(CRF, self).__init__() self.bilateral_weight = bilateral_weight self.gaussian_weight = gaussian_weight @@ -65,14 +65,14 @@ def __init__( self.iterations = iterations def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): - """ - Args: - input_tensor: tensor containing initial class logits. - referenece_tensor: the reference tensor used to guide the message passing. - - Returns: - output (torch.Tensor): output tensor. - """ + """ + Args: + input_tensor: tensor containing initial class logits. + referenece_tensor: the reference tensor used to guide the message passing. + + Returns: + output (torch.Tensor): output tensor. + """ # useful values spatial_dim = input_tensor.dim() - 2 From 52d07e1ae19a6915e793679e16ebc0e2b7b54bef Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Wed, 24 Mar 2021 13:19:41 +0000 Subject: [PATCH 15/16] fixing backwards method in PHLFilter Signed-off-by: chaliebudd --- monai/networks/layers/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 83a33bc609..8de1c7168c 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -94,5 +94,5 @@ def forward(ctx, input, features, sigmas=None): @staticmethod def backward(ctx, grad_output): scaled_features = ctx.saved_variables - grad_input = PHLFilter.scale(grad_output, scaled_features) + grad_input = _C.phl_filter(grad_output, scaled_features) return grad_input From 1d0d374238c29939ddeb584290d8695283ebbd46 Mon Sep 17 00:00:00 2001 From: chaliebudd Date: Wed, 24 Mar 2021 16:35:29 +0000 Subject: [PATCH 16/16] raising error when attempting to backprop through PHLFilter Signed-off-by: chaliebudd --- monai/networks/layers/filtering.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 2bf10d98d6..7eca03a280 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -93,6 +93,7 @@ def forward(ctx, input, features, sigmas=None): @staticmethod def backward(ctx, grad_output): - scaled_features = ctx.saved_variables - grad_input = _C.phl_filter(grad_output, scaled_features) - return grad_input + raise NotImplementedError("PHLFilter does not currently support backpropergation") + # scaled_features, = ctx.saved_variables + # grad_input = _C.phl_filter(grad_output, scaled_features) + # return grad_input