Merged
25 commits
1f75f24
Add missing wandb finish for aborted jobs
C-Achard Mar 4, 2024
c73685a
Update utils.py
C-Achard Mar 4, 2024
ee8f54d
Update plugin_model_training.py
C-Achard Mar 4, 2024
5329da2
Update plugin_model_training.py
C-Achard Mar 4, 2024
cc3d086
Fix num_epochs for aborted training in csv
C-Achard Mar 4, 2024
8f853b9
Fix Wnet weight loading
C-Achard Mar 4, 2024
af4e016
Fix reference to WNet weights for transfer learn
C-Achard Mar 4, 2024
0d2e74a
Update plugin_model_training.py
C-Achard Mar 4, 2024
be62e21
Improved safety for closing previous wandb runs
C-Achard Mar 4, 2024
97a4b35
Fix rogue Path instead of str
C-Achard Mar 4, 2024
80746ed
Update worker_training.py
C-Achard Mar 4, 2024
7ff2b63
Change bounds for WNet train parameters
C-Achard Mar 4, 2024
0a585c8
Shorten file names in train log
C-Achard Mar 4, 2024
86a8384
More range foor rec loss weight
C-Achard Mar 6, 2024
530b7a3
Update Dice coeff calculation
C-Achard Mar 6, 2024
3a418b2
Fix matching len check for csv
C-Achard Mar 11, 2024
8b57657
Update training_wnet.rst
C-Achard Mar 21, 2024
76218bd
Update WNet docs + typos
C-Achard Mar 21, 2024
de51029
Update utils.py
C-Achard Mar 21, 2024
1d13fe5
Change rec loss weight default
C-Achard Mar 25, 2024
3507c32
Change discrete output display for WNet training
C-Achard Mar 25, 2024
a933fd0
Fix incorrect checks for csv saving
C-Achard Mar 25, 2024
c22d2a6
Small improvement to make_csv
C-Achard Mar 25, 2024
6c5f740
Merge branch 'main' into cy/training-fixes
C-Achard Apr 13, 2024
74efbcf
Fix logic issue in make_csv
C-Achard Apr 15, 2024
2 changes: 1 addition & 1 deletion docs/source/guides/cropping_module_guide.rst
@@ -59,7 +59,7 @@ Once you have launched the review process, you will gain control over three slid
you to **adjust the position** of the cropped volumes and labels in the x,y and z positions.

.. note::
* If your **cropped volume isnt visible**, consider changing the **colormap** of the image and the cropped
* If your **cropped volume isn't visible**, consider changing the **colormap** of the image and the cropped
volume to improve their visibility.
* You may want to adjust the **opacity** and **contrast thresholds** depending on your image.
* If the image appears empty:
68 changes: 51 additions & 17 deletions docs/source/guides/training_wnet.rst
@@ -4,22 +4,50 @@ Walkthrough - WNet3D training
===============================

This plugin provides a reimplemented, custom version of the WNet3D model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_.

For training your model, you can choose among:

* Directly within the plugin
* The provided Jupyter notebook (locally)
* Our Colab notebook (inspired by ZeroCostDL4Mic)
* Our Colab notebook (inspired by https://github.com/HenriquesLab/ZeroCostDL4Mic)

Selecting training data
-------------------------

The WNet3D **does not require a large amount of data to train**, but **choosing the right data** to train this unsupervised model **is crucial**.

You may find below some guidelines, based on our own data and testing.

The WNet3D is designed to segment objects based on their brightness, and is particularly well-suited for images with a clear contrast between objects and background.

The WNet3D is not suitable for images with artifacts, therefore care should be taken that the images are clean and that the objects are at least somewhat distinguishable from the background.


.. important::
For optimal performance, the following should be avoided for training:

- Images with very large, bright regions
- Almost-empty and empty images
- Images with large empty regions or "holes"

The WNet3D does not require a large amount of data to train, but during inference images should be similar to those
the model was trained on; you can retrain from our pretrained model to your image dataset to quickly reach good performance.
However, the model may accommodate:

- Uneven brightness distribution
- Varied object shapes and radius
- Noisy images
- Uneven illumination across the image

For optimal results, during inference, images should be similar to those the model was trained on; however this is not a strict requirement.

You may also retrain from our pretrained model on your image dataset to reach good performance more quickly: simply check "Use pre-trained weights" in the training module, and lower the learning rate.

.. note::
- The WNet3D relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider training on one of the supervised models.
- The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**. Unlike the method described in the original paper, these losses are added in a weighted sum and the backward pass is performed for the whole model at once. The SoftNcuts and BCE are bounded between 0 and 1; the MSE may take large positive values. It is recommended to watch for the weighted sum of losses to be **close to one on the first epoch**, for training stability.
- For good performance, you should wait for the SoftNCut to reach a plateau; the reconstruction loss must also decrease but is generally less critical.
- The WNet3D relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider trying one of our supervised models (for lightsheet microscopy).
- The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**.
- For good performance, wait for the SoftNCut to reach a plateau; the reconstruction loss should also be decreasing overall, but this is generally less critical for segmentation performance.

Parameters
----------
-------------

.. figure:: ../images/training_tab_4.png
:scale: 100 %
@@ -29,7 +57,7 @@ Parameters

_`When using the WNet3D training module`, the **Advanced** tab contains a set of additional options:

- **Number of classes** : Dictates the segmentation classes (default is 2). Increasing the number of classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects or artifacts with a significantly different brightness.
- **Number of classes** : Dictates the segmentation classes (default is 2). Increasing the number of classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects, or to approximate boundary labels.
- **Reconstruction loss** : Choose between MSE or BCE (default is MSE). MSE is more precise but also sensitive to outliers; BCE is more robust against outliers at the cost of precision.

- NCuts parameters:
@@ -43,22 +71,28 @@ _`When using the WNet3D training module`, the **Advanced** tab contains a set of

- Weights for the sum of losses :
- **NCuts weight** : Sets the weight of the NCuts loss (default is 0.5).
- **Reconstruction weight** : Sets the weight for the reconstruction loss (default is 0.5*1e-2).
- **Reconstruction weight** : Sets the weight for the reconstruction loss (default is 5*1e-3).

.. note::
The weight of the reconstruction loss should be adjusted to ensure the weighted sum is around one during the first epoch;
ideally the reconstruction loss should be of the same order of magnitude as the NCuts loss after being multiplied by its weight.
.. important::
The weight of the reconstruction loss should be adjusted to ensure that both losses are balanced.

This balance can be assessed using the live view of training outputs :
if the NCuts loss is "taking over", causing the segmentation to only label very large, brighter versus dimmer regions, the reconstruction loss should be increased.

This will help the model to focus on the details of the objects, rather than just the overall brightness of the volume.

Common issues troubleshooting
------------------------------
If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.

- **The NCuts loss explodes after a few epochs** : Lower the learning rate, first by a factor of two, then ten.
.. important::
If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub.


- **The NCuts loss "explodes" after a few epochs** : Lower the learning rate, for example by a factor of two at first, then by a factor of ten.

- **The NCuts loss does not converge and is unstable** :
The normalization step might not be adapted to your images. Disable normalization and change intensity_sigma according to the distribution of values in your image. For reference, by default images are remapped to values between 0 and 100, and intensity_sigma=1.
- **Reconstruction (decoder) performance is poor** : First, try increasing the weight of the reconstruction loss. If this is ineffective, switch to BCE loss and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss.

- **Reconstruction (decoder) performance is poor** : switch to BCE and set the scaling factor of the reconstruction loss to 0.5, OR adjust the weight of the MSE loss to make it closer to 1 in the weighted sum.
- **Segmentation only separates the brighter versus dimmer regions** : Increase the weight of the reconstruction loss.


.. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506
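The weighted sum of losses described in the documentation changes above can be sketched as follows. This is a minimal illustration, not the plugin's code: the function name and the sample loss values are made up, and the default weights are the ones quoted in the documentation (0.5 for NCuts, 5*1e-3 for reconstruction).

```python
def weighted_loss(ncuts_loss, rec_loss, ncuts_weight=0.5, rec_weight=5e-3):
    """Return the weighted sum and each loss's weighted contribution.

    Hypothetical helper for illustration; the weights default to the
    values quoted in the documentation.
    """
    ncuts_term = ncuts_weight * ncuts_loss
    rec_term = rec_weight * rec_loss
    return ncuts_term + rec_term, ncuts_term, rec_term


# Illustrative values: SoftNCuts is bounded by 1, MSE can be much larger.
total, ncuts_term, rec_term = weighted_loss(ncuts_loss=0.8, rec_loss=100.0)
print(f"total={total:.2f} ncuts={ncuts_term:.2f} rec={rec_term:.2f}")
```

Comparing the two weighted terms, rather than only their sum, is one way to see whether either loss is "taking over" as the documentation puts it.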
50 changes: 36 additions & 14 deletions napari_cellseg3d/code_models/worker_training.py
@@ -388,7 +388,7 @@ def log_parameters(self):
self.log("-- Data --")
self.log("Training data :\n")
[
self.log(f"{v}")
self.log(f"{Path(v).stem}")
for d in self.config.train_data_dict
for k, v in d.items()
]
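The `Path(v).stem` change in this hunk logs only the file name, without its directory or final extension. A small sketch of that behaviour follows; it uses `PurePosixPath` so the result does not depend on the operating system, and the file paths are made up for illustration.

```python
from pathlib import PurePosixPath


def short_name(path: str) -> str:
    """Return the file name without its directory or last suffix."""
    return PurePosixPath(path).stem


# Hypothetical training files, for illustration only.
print(short_name("/data/train/volume_01.tif"))     # volume_01
print(short_name("/data/train/volume_02.nii.gz"))  # volume_02.nii
```

Note that `stem` strips only the last suffix, so double extensions such as `.nii.gz` keep their first part in the log.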
@@ -423,6 +423,11 @@ def train(
if WANDB_INSTALLED:
config_dict = self.config.__dict__
logger.debug(f"wandb config : {config_dict}")
if wandb.run is not None:
logger.warning(
"A previous wandb run is still active. It will be stopped before starting a new one."
)
wandb.finish()
wandb.init(
config=config_dict,
project="CellSeg3D - WNet",
@@ -472,13 +477,21 @@ def train(
if WANDB_INSTALLED:
wandb.watch(model, log_freq=100)

if self.config.weights_info.use_custom:
if (
self.config.weights_info.use_pretrained
or self.config.weights_info.use_custom
):
if self.config.weights_info.use_pretrained:
weights_file = "wnet.pth"
from napari_cellseg3d.code_models.models.model_WNet import (
WNet_,
)

weights_file = WNet_.weights_file
self.downloader.download_weights("WNet", weights_file)
weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file)
weights = str(PRETRAINED_WEIGHTS_DIR / Path(weights_file))
self.config.weights_info.path = weights
else:

if self.config.weights_info.use_custom:
weights = str(Path(self.config.weights_info.path))

try:
@@ -624,6 +637,9 @@ def train(
del criterionW
torch.cuda.empty_cache()

if WANDB_INSTALLED:
wandb.finish()

self.ncuts_losses.append(
epoch_ncuts_loss / len(self.dataloader)
)
@@ -642,9 +658,7 @@ def train(
"cmap": "turbo",
},
"Encoder output (discrete)": {
"data": AsDiscrete(threshold=0.5)(
enc_out
).numpy(),
"data": np.where(enc_out > 0.5, enc_out, 0),
"cmap": "bop blue",
},
"Decoder output": {
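The change to the "Encoder output (discrete)" layer replaces a thresholding transform with `np.where(enc_out > 0.5, enc_out, 0)`. The two are not equivalent: thresholding maps values to 0 or 1, while the `np.where` form keeps the original values above the threshold. The difference can be sketched on a plain list (helper names and sample values are illustrative, and the `>=` comparison in `binarize` is an assumption about the thresholding transform):

```python
def binarize(values, threshold=0.5):
    """Map every value to 0 or 1, as a thresholding transform would."""
    return [1 if v >= threshold else 0 for v in values]


def keep_above(values, threshold=0.5):
    """Keep values above the threshold unchanged and zero the rest."""
    return [v if v > threshold else 0 for v in values]


encoder_out = [0.1, 0.5, 0.7, 0.95]  # made-up probabilities
print(binarize(encoder_out))    # [0, 1, 1, 1]
print(keep_above(encoder_out))  # [0, 0, 0.7, 0.95]
```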
@@ -736,7 +750,8 @@ def train(
if epoch % 5 == 0:
torch.save(
model.state_dict(),
self.config.results_path_folder + "/wnet_.pth",
self.config.results_path_folder
+ "/wnet_checkpoint.pth",
)

self.log("Training finished")
@@ -856,8 +871,7 @@ def eval(self, model, epoch) -> TrainingReport:
self.dice_metric(
y_pred=val_outputs[
:,
max_dice_channel : (max_dice_channel + 1),
:,
max_dice_channel:, # : (max_dice_channel + 1),
:,
:,
],
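The slice change in this hunk is worth a second look: `max_dice_channel:` selects every channel from that index onward, whereas the previous `max_dice_channel : (max_dice_channel + 1)` selected exactly one channel. The same slicing rule shown on a plain list (channel names are illustrative):

```python
channels = ["ch0", "ch1", "ch2"]  # stand-in for the channel axis
best = 1                          # stand-in for max_dice_channel

single = channels[best : best + 1]  # previous form: exactly one channel
onward = channels[best:]            # new form: this channel and all after it

print(single)  # ['ch1']
print(onward)  # ['ch1', 'ch2']
```

The two forms agree only when the selected channel is the last one.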
@@ -1120,6 +1134,11 @@ def train(
config_dict = self.config.__dict__
logger.debug(f"wandb config : {config_dict}")
try:
if wandb.run is not None:
logger.warning(
"A previous wandb run is still active. It will be stopped before starting a new one."
)
wandb.finish()
wandb.init(
config=config_dict,
project="CellSeg3D",
@@ -1410,13 +1429,13 @@ def get_patch_loader_func(num_samples):
# time = utils.get_date_time()
logger.debug("Weights")

if weights_config.use_custom:
if weights_config.use_custom or weights_config.use_pretrained:
if weights_config.use_pretrained:
weights_file = model_class.weights_file
self.downloader.download_weights(model_name, weights_file)
weights = PRETRAINED_WEIGHTS_DIR / Path(weights_file)
weights = str(PRETRAINED_WEIGHTS_DIR / Path(weights_file))
weights_config.path = weights
else:
elif weights_config.use_custom:
weights = str(Path(weights_config.path))

try:
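The weight-selection logic in this hunk (pretrained weights take precedence, custom weights otherwise, and the path is always a `str`) can be sketched in isolation. The function and argument names below are hypothetical; only the `if`/`elif` order and the `str(...)` conversion mirror the diff.

```python
from pathlib import Path


def resolve_weights(use_pretrained, use_custom, pretrained_dir, weights_file, custom_path):
    """Return the weights path as a str, or None if no weights are requested."""
    if use_pretrained:
        # Pretrained weights live in a known directory under a known file name.
        return str(Path(pretrained_dir) / weights_file)
    elif use_custom:
        # Custom weights come from a user-supplied path.
        return str(Path(custom_path))
    return None


# Hypothetical directory and file names, for illustration only.
print(resolve_weights(True, False, "weights", "model.pth", None))
print(resolve_weights(False, True, "weights", "model.pth", "my_run/best.pth"))
```

Returning a `str` in both branches keeps downstream code from having to handle a mix of `Path` and `str` values.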
@@ -1523,6 +1542,9 @@ def get_patch_loader_func(num_samples):
if device.type == "cuda":
torch.cuda.empty_cache()

if WANDB_INSTALLED:
wandb.finish()

yield TrainingReport(
show_plot=False,
weights=model.state_dict(),
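The wandb changes in this file follow one pattern: finish any run that is still active before starting a new one, and finish the run when training is aborted. A sketch of the first half, using a hypothetical `FakeTracker` class as a stand-in for the `wandb` module so that it runs without wandb installed:

```python
class FakeTracker:
    """Hypothetical stand-in exposing run, init() and finish() like wandb."""

    def __init__(self):
        self.run = None
        self.finished_runs = []

    def init(self, project):
        self.run = {"project": project}
        return self.run

    def finish(self):
        if self.run is not None:
            self.finished_runs.append(self.run)
        self.run = None


def start_run(tracker, project):
    """Close a leftover run, if any, then start a fresh one."""
    if tracker.run is not None:
        tracker.finish()
    return tracker.init(project=project)


tracker = FakeTracker()
start_run(tracker, "first")   # e.g. a job that was aborted without cleanup
start_run(tracker, "second")  # the leftover run is finished first
print(len(tracker.finished_runs))  # 1
```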
66 changes: 42 additions & 24 deletions napari_cellseg3d/code_plugins/plugin_model_training.py
@@ -1054,7 +1054,7 @@ def _set_worker_config(
logger.debug("Loading config...")
model_config = config.ModelInfo(name=self.model_choice.currentText())

self.weights_config.path = self.weights_config.path
# self.weights_config.path = self.weights_config.path
self.weights_config.use_custom = self.custom_weights_choice.isChecked()

self.weights_config.use_pretrained = (
@@ -1365,31 +1365,44 @@ def on_yield(self, report: TrainingReport):
self.on_stop()
self._stop_requested = False

def _make_csv(self):
size_column = range(1, self.worker_config.max_epochs + 1)
def _check_lens(self, size_column, loss_values):
if len(size_column) != len(loss_values):
logger.info(
f"Training was stopped, setting epochs for csv to {len(loss_values)}"
)
return range(1, len(loss_values) + 1)
return size_column

if len(self.loss_1_values) == 0 or self.loss_1_values is None:
def _handle_loss_values(self, size_column, key):
loss_values = self.loss_1_values.get(key)
if loss_values is None:
return None

if len(loss_values) == 0:
logger.warning("No loss values to add to csv !")
return
return None

try:
self.loss_1_values["Loss"]
supervised = True
except KeyError:
try:
self.loss_1_values["SoftNCuts"]
supervised = False
except KeyError as e:
raise KeyError(
"Error when making csv. Check loss dict keys ?"
) from e
return self._check_lens(size_column, loss_values)

def _make_csv(self): # TODO(cyril) design could use a good rework
size_column = range(1, self.worker_config.max_epochs + 1)

supervised = True
size_column = self._handle_loss_values(size_column, "Loss")
if size_column is None:
size_column = self._handle_loss_values(size_column, "SoftNCuts")
if size_column is None:
raise KeyError("Error when making csv. Check loss dict keys ?")
supervised = False

if supervised:
val = utils.fill_list_in_between(
val = utils.fill_list_in_between( # fills the validation list based on validation interval
self.loss_2_values,
self.worker_config.validation_interval - 1,
"",
)[: len(size_column)]
)[
: len(size_column)
]

self.df = pd.DataFrame(
{
@@ -1404,8 +1417,13 @@ def _make_csv(self):
raise ValueError(err)
else:
ncuts_loss = self.loss_1_values["SoftNCuts"]

logger.debug(f"Epochs : {len(size_column)}")
logger.debug(f"Loss 1 values : {len(ncuts_loss)}")
logger.debug(f"Loss 2 values : {len(self.loss_2_values)}")
try:
dice_metric = self.loss_1_values["Dice metric"]
logger.debug(f"Dice metric : {dice_metric}")
self.df = pd.DataFrame(
{
"Epoch": size_column,
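The epoch-column handling for aborted runs can be sketched as a standalone function. This is an illustration under assumptions, not the plugin's implementation: the function name is hypothetical, and the sketch always compares against the planned range rather than reusing the result of a previous lookup, which may be `None`.

```python
def epochs_for_csv(max_epochs, losses, keys=("Loss", "SoftNCuts")):
    """Return (key, epoch range) for the first key that has recorded values.

    If training stopped early, the range is shortened to the number of
    recorded loss values so that all csv columns have the same length.
    """
    planned = range(1, max_epochs + 1)
    for key in keys:
        values = losses.get(key)
        if not values:  # key missing or nothing recorded yet
            continue
        if len(values) != len(planned):
            return key, range(1, len(values) + 1)
        return key, planned
    raise KeyError("No loss values found. Check loss dict keys?")


# Hypothetical run stopped after 3 of 10 epochs.
key, epochs = epochs_for_csv(10, {"SoftNCuts": [0.9, 0.8, 0.7]})
print(key, list(epochs))  # SoftNCuts [1, 2, 3]
```

The returned key also tells the caller whether the run was supervised ("Loss") or a WNet run ("SoftNCuts").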
@@ -1630,15 +1648,15 @@ def __init__(self, parent):
text_label="Number of classes",
)
self.intensity_sigma_choice = ui.DoubleIncrementCounter(
lower=1.0,
lower=0.01,
upper=100.0,
default=self.default_config.intensity_sigma,
parent=parent,
text_label="Intensity sigma",
)
self.intensity_sigma_choice.setMaximumWidth(20)
self.spatial_sigma_choice = ui.DoubleIncrementCounter(
lower=1.0,
lower=0.01,
upper=100.0,
default=self.default_config.spatial_sigma,
parent=parent,
@@ -1674,10 +1692,10 @@ def __init__(self, parent):
)
self.reconstruction_weight_choice.setMaximumWidth(20)
self.reconstruction_weight_divide_factor_choice = (
ui.IntIncrementCounter(
lower=1,
upper=10000,
default=100,
ui.DoubleIncrementCounter(
lower=0.01,
upper=10000.0,
default=1.0,
parent=parent,
text_label="Reconstruction weight divide factor",
)
2 changes: 1 addition & 1 deletion napari_cellseg3d/config.py
@@ -399,7 +399,7 @@ class WNetTrainingWorkerConfig(TrainingWorkerConfig):
reconstruction_loss: str = "MSE" # or "BCE"
# summed losses weights
n_cuts_weight: float = 0.5
rec_loss_weight: float = 0.5 / 100
rec_loss_weight: float = 0.5 / 1.0 # 0.5 / 100
# normalization params
# normalizing_function: callable = remap_image # FIXME: call directly in worker, not a param
# data params