Add class-wise metrics logging and confusion matrix to DeepMIL #647
Conversation
| tile_size: int = 224, | ||
| level: int = 1) -> None: | ||
| level: int = 1, | ||
| class_names: List[str] = None) -> None: |
Type annotation has a small discrepancy - if you set the default to None, it should be "Optional[List[str]]".
Thanks, now changed
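A minimal sketch of the annotation the reviewer asks for. The class name `TilesDatasetSketch` is a hypothetical stand-in for the class under review; only the signature is taken from the diff above.

```python
from typing import List, Optional


class TilesDatasetSketch:
    """Hypothetical stand-in for the class under review; only the signature matters here."""

    def __init__(self,
                 tile_size: int = 224,
                 level: int = 1,
                 class_names: Optional[List[str]] = None) -> None:
        self.tile_size = tile_size
        self.level = level
        # A default of None is only consistent with the annotation if the type is Optional.
        self.class_names = class_names
```

With `Optional[List[str]]`, a type checker accepts both `TilesDatasetSketch()` and `TilesDatasetSketch(class_names=["a", "b"])`.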
| 'precision': Precision(), | ||
| 'recall': Recall(), | ||
| 'f1score': F1(), | ||
| 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes+1)}) |
Seeing this here: It would be good to add documentation for n_classes beyond what you have now. For two classes "0" and "1", n_classes should be set to 1, correct?
Thanks, added to the docstring
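To illustrate the convention discussed here, a small pure-Python sketch of the binary case only: labels "0" and "1" with `n_classes` set to 1, so the matrix has `n_classes + 1` rows and columns. The helper `confusion_matrix` is hypothetical and is not the `ConfusionMatrix` metric used in the PR.

```python
from typing import List, Sequence


def confusion_matrix(targets: Sequence[int], preds: Sequence[int], num_classes: int) -> List[List[int]]:
    """Count (true label, predicted label) pairs; rows are true labels, columns are predictions."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for target, pred in zip(targets, preds):
        matrix[target][pred] += 1
    return matrix


# Binary case as discussed above: the labels are 0 and 1, n_classes is 1,
# and the matrix is built with n_classes + 1 = 2 rows and columns.
n_classes = 1
cm = confusion_matrix(targets=[0, 0, 1, 1], preds=[0, 1, 1, 1], num_classes=n_classes + 1)
```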
| for metric_name, metric_object in self.get_metrics_dict(stage).items(): | ||
| self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) | ||
| if not metric_name == "confusion_matrix": | ||
| self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) |
Can you replace the self.log calls with hi-ml's "log_on_epoch" function? That gives you a simpler interface and handles sync_dist better.
I'm wondering whether log_on_epoch will work with torch module metrics objects?
This is now changed, thanks
| raise Exception(f"Invalid stage. Chose one of {valid_stages}") | ||
| for metric_name, metric_object in self.get_metrics_dict(stage).items(): | ||
| self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) | ||
| if not metric_name == "confusion_matrix": |
I see you have some if statements of the form "if not A then something() else otherthing()". If you are handling both cases anyway, the code is easier to read if you do not negate the condition and swap the if/else branches.
Thanks, I swapped if/else
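A sketch of the swapped form, assuming both branches are handled. The function `split_metrics` and its return shape are hypothetical; the point is only that the condition is stated positively.

```python
from typing import Any, Dict, List, Tuple


def split_metrics(metrics: Dict[str, Any]) -> Tuple[List[str], List[str]]:
    """Separate metric names into those logged as scalars and those handled specially."""
    scalar_names: List[str] = []
    special_names: List[str] = []
    for metric_name in metrics:
        # Positive condition first: no "not" to invert mentally when reading either branch.
        if metric_name == "confusion_matrix":
            special_names.append(metric_name)
        else:
            scalar_names.append(metric_name)
    return scalar_names, special_names
```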
| print("Computing and saving confusion matrix...") | ||
| metrics_dict = self.get_metrics_dict('test') | ||
| cf_matrix = metrics_dict["confusion_matrix"].compute() |
This literal "confusion_matrix" needs to be in sync with your other uses of "confusion_matrix". This is an extremely common source of errors: you change the constant somewhere and forget to change it somewhere else (as trivial/benign as this may sound). Your code will be a lot safer (and require fewer tests) if you define those literals as constants, CONF_MATRIX_METRIC = "confusion_matrix", and re-use them.
Thanks for the suggestion, now defined metrics names as constants
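A sketch of what this could look like. `CONF_MATRIX_METRIC` comes from the comment above; `PRECISION_METRIC`, `build_metrics` and the placeholder values are assumptions for illustration, not the names used in the PR.

```python
from typing import Dict

# Single source of truth for metric names; every use site refers to these constants.
CONF_MATRIX_METRIC = "confusion_matrix"
PRECISION_METRIC = "precision"  # hypothetical name, for illustration only


def build_metrics() -> Dict[str, object]:
    """Build a metrics dictionary keyed by the shared constants (values are placeholders)."""
    return {
        PRECISION_METRIC: 0.0,
        CONF_MATRIX_METRIC: [[0, 0], [0, 0]],
    }


metrics_dict = build_metrics()
# The lookup uses the same constant as the definition, so a rename cannot leave a stale literal behind.
cf_matrix = metrics_dict[CONF_MATRIX_METRIC]
```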
| absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(), | ||
| self.checkpoint_folder_path, | ||
| self.best_checkpoint_filename_with_suffix) | ||
| if absolute_checkpoint_path_parent.is_file(): | ||
| return absolute_checkpoint_path_parent |
Confused. This and the variable above are exactly the same?
I have taken this from our other config.
It was only there to make this config work and may not be related to this PR. I will remove it for now and move it to another PR.
| absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(), | ||
| self.checkpoint_folder_path, | ||
| self.best_checkpoint_filename_with_suffix) | ||
| if absolute_checkpoint_path.is_file(): | ||
| return absolute_checkpoint_path |
I'm puzzled. Checkpoints are normally stored in the "outputs" folder, so that they are also available at the end of an AzureML run. Why are the checkpoints here accessed as part of repository root?
I have taken this from our other config
It was only there to make the Crck config work and may not be related to this PR. I will remove it for now and move it to another PR.
| assert file.exists() | ||
| expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png" | ||
| # To update the stored results, uncomment this line: | ||
| expected.write_bytes(file.read_bytes()) |
This should not be checked in - your test will always pass now!
| expected.write_bytes(file.read_bytes()) | |
| # expected.write_bytes(file.read_bytes()) |
Thanks for spotting this, changed now
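A self-contained sketch of why the update line has to stay commented out. The helper `matches_stored_result` and the file names are hypothetical; temporary files stand in for the generated plot and the stored expected result.

```python
import tempfile
from pathlib import Path


def matches_stored_result(file: Path, expected: Path) -> bool:
    """Compare a freshly generated file against the stored expected result, byte for byte."""
    # To update the stored results, uncomment this line:
    # expected.write_bytes(file.read_bytes())
    return file.read_bytes() == expected.read_bytes()


with tempfile.TemporaryDirectory() as tmp_dir:
    expected = Path(tmp_dir) / "expected.png"
    generated = Path(tmp_dir) / "generated.png"
    expected.write_bytes(b"stored plot")

    # Identical output: the comparison succeeds.
    generated.write_bytes(b"stored plot")
    unchanged_ok = matches_stored_result(generated, expected)

    # Changed output: the comparison fails, which is only possible because
    # the stored file was not overwritten before comparing.
    generated.write_bytes(b"regressed plot")
    regression_detected = not matches_stored_result(generated, expected)
```

If the `write_bytes` line were active, the stored file would be overwritten on every run and the second comparison would also succeed, hiding the regression.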
dccastro
left a comment
Great job! No other comments 👍
In this PR: