Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions monai/networks/blocks/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
bilateral_color_sigma: float = 0.5,
gaussian_spatial_sigma: float = 5.0,
update_factor: float = 3.0,
compatability_matrix: Optional[torch.Tensor] = None,
compatibility_matrix: Optional[torch.Tensor] = None,
):
"""
Args:
Expand All @@ -54,7 +54,7 @@ def __init__(
bilateral_color_sigma: standard deviation in color space for the bilateral term.
gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term.
update_factor: determines the magnitude of each update.
compatability_matrix: a matrix describing class compatability, should be NxN where N is the numer of classes.
compatibility_matrix: a matrix describing class compatibility, should be NxN where N is the numer of classes.
"""
super(CRF, self).__init__()
self.iterations = iterations
Expand All @@ -64,7 +64,7 @@ def __init__(
self.bilateral_color_sigma = bilateral_color_sigma
self.gaussian_spatial_sigma = gaussian_spatial_sigma
self.update_factor = update_factor
self.compatability_matrix = compatability_matrix
self.compatibility_matrix = compatibility_matrix

def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
"""
Expand Down Expand Up @@ -98,10 +98,10 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
# combining filter outputs
combined_output = self.bilateral_weight * bliateral_output + self.gaussian_weight * gaussian_output

# optionally running a compatability transform
if self.compatability_matrix is not None:
# optionally running a compatibility transform
if self.compatibility_matrix is not None:
flat = combined_output.flatten(start_dim=2).permute(0, 2, 1)
flat = torch.matmul(flat, self.compatability_matrix)
flat = torch.matmul(flat, self.compatibility_matrix)
combined_output = flat.permute(0, 2, 1).reshape(combined_output.shape)

# update and normalize
Expand Down