diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 736259cb96..219f1a70e2 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -43,7 +43,7 @@ def __init__( bilateral_color_sigma: float = 0.5, gaussian_spatial_sigma: float = 5.0, update_factor: float = 3.0, - compatability_matrix: Optional[torch.Tensor] = None, + compatibility_matrix: Optional[torch.Tensor] = None, ): """ Args: @@ -54,7 +54,7 @@ def __init__( bilateral_color_sigma: standard deviation in color space for the bilateral term. gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term. update_factor: determines the magnitude of each update. - compatability_matrix: a matrix describing class compatability, should be NxN where N is the numer of classes. + compatibility_matrix: a matrix describing class compatibility, should be NxN where N is the numer of classes. """ super(CRF, self).__init__() self.iterations = iterations @@ -64,7 +64,7 @@ def __init__( self.bilateral_color_sigma = bilateral_color_sigma self.gaussian_spatial_sigma = gaussian_spatial_sigma self.update_factor = update_factor - self.compatability_matrix = compatability_matrix + self.compatibility_matrix = compatibility_matrix def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): """ @@ -98,10 +98,10 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): # combining filter outputs combined_output = self.bilateral_weight * bliateral_output + self.gaussian_weight * gaussian_output - # optionally running a compatability transform - if self.compatability_matrix is not None: + # optionally running a compatibility transform + if self.compatibility_matrix is not None: flat = combined_output.flatten(start_dim=2).permute(0, 2, 1) - flat = torch.matmul(flat, self.compatability_matrix) + flat = torch.matmul(flat, self.compatibility_matrix) combined_output = flat.permute(0, 2, 1).reshape(combined_output.shape) # update and normalize