diff --git a/.github/workflows/preview_metadata.yml b/.github/workflows/preview_metadata.yml index 1a96f99f..98651d4b 100644 --- a/.github/workflows/preview_metadata.yml +++ b/.github/workflows/preview_metadata.yml @@ -4,7 +4,7 @@ name: napari hub Preview Page # we use this name to find your preview page artif on: pull_request: branches: - - '**' + - 'test' # '**' for all jobs: preview-page: @@ -16,6 +16,6 @@ jobs: uses: actions/checkout@v2 - name: napari hub Preview Page Builder - uses: chanzuckerberg/napari-hub-preview-action@v0.1.5 + uses: chanzuckerberg/napari-hub-preview-action@v0.1.6 with: hub-ref: main \ No newline at end of file diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 033dfb6d..11b8776f 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -22,21 +22,22 @@ jobs: runs-on: ${{ matrix.platform }} strategy: matrix: - platform: [ubuntu-latest, windows-latest, macos-latest] - python-version: [3.8, 3.9, "3.10"] +# platform: [ubuntu-latest, windows-latest, macos-latest] + platform: [windows-latest] + python-version: [3.8, 3.9] # 3.10 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - # these libraries enable testing on Qt on linux +# these libraries enable testing on Qt on linux - uses: tlambert03/setup-qt-libs@v1 - # strategy borrowed from vispy for installing opengl libs on windows +# strategy borrowed from vispy for installing opengl libs on windows - name: Install Windows OpenGL if: runner.os == 'Windows' run: | @@ -44,18 +45,18 @@ jobs: powershell gl-ci-helpers/appveyor/install_opengl.ps1 if (Test-Path -Path "C:\Windows\system32\opengl32.dll" -PathType Leaf) {Exit 0} else {Exit 1} - # note: if you need dependencies from conda, considering using - # setup-miniconda: https://github.com/conda-incubator/setup-miniconda - # and - # tox-conda: https://github.com/tox-dev/tox-conda +# note: if you need dependencies from conda, considering using +# setup-miniconda: https://github.com/conda-incubator/setup-miniconda +# and +# tox-conda: https://github.com/tox-dev/tox-conda - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install setuptools tox tox-gh-actions - # this runs the platform-specific tests declared in tox.ini +# this runs the platform-specific tests declared in tox.ini - name: Test with tox - uses: GabrielBB/xvfb-action@v1 + uses: GabrielBB/xvfb-action@v1 # aganders3/headless-gui@v1 with: run: python -m tox env: @@ -65,9 +66,9 @@ jobs: uses: codecov/codecov-action@v2 deploy: - # this will run when you have tagged a commit, starting with "v*" - # and requires that you have put your twine API key in your - # github secrets (see readme for details) +# this will run when you have tagged a commit, starting with "v*" +# and requires that you have put your twine API key in your +# github secrets (see readme for details) needs: [test] runs-on: ubuntu-latest if: contains(github.ref, 'tags') diff --git a/.gitignore b/.gitignore index 74feefbe..ffe6e1f8 100644 --- a/.gitignore +++ b/.gitignore @@ -95,7 +95,12 @@ venv/ ######## #project specific #dataset, weights, old logos, requirements -/napari_cellseg3d/models/dataset/ -/napari_cellseg3d/models/saved_weights/ +/napari_cellseg3d/code_models/models/dataset/ +/napari_cellseg3d/code_models/models/saved_weights/ /docs/res/logo/old_logo/ /reqs/ +/Loss_plots/ +notebooks/csv_cell_plot.html +notebooks/full_plot.html +*.csv +*.png diff --git a/docs/conf.py b/docs/conf.py index cd40efa0..bc766d97 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -52,7 +52,7 @@ # General information about the project. project = "napari-cellseg3d" -copyright = "2022, Cyril Achard, Maxime Vidal" +copyright = "2022-2023, Cyril Achard, Maxime Vidal" author = "Cyril Achard, Maxime Vidal" # The version info for the project you're documenting, acts as replacement for diff --git a/docs/index.rst b/docs/index.rst index e0bcde3f..90f430a0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,7 +16,7 @@ Welcome to napari-cellseg3d's documentation! :caption: Utilities : res/guides/metrics_module_guide - res/guides/convert_module_guide + res/guides/utils_module_guide res/guides/cropping_module_guide @@ -34,8 +34,7 @@ Welcome to napari-cellseg3d's documentation! res/code/interface res/code/plugin_base res/code/plugin_review - res/code/launch_review - res/code/plugin_dock + res/code/plugin_review_dock res/code/plugin_crop res/code/plugin_convert res/code/plugin_metrics diff --git a/docs/res/code/interface.rst b/docs/res/code/interface.rst index a11c7047..3bc4f914 100644 --- a/docs/res/code/interface.rst +++ b/docs/res/code/interface.rst @@ -4,6 +4,27 @@ interface.py Classes ------------- +QWidgetSingleton +************************************** +.. autoclass:: napari_cellseg3d.interface::QWidgetSingleton + :members: __call__ + +UtilsDropdown +************************************** +.. autoclass:: napari_cellseg3d.interface::UtilsDropdown + :members: __init__, dropdown_menu_call, show_utils_menu + +Log +************************************** +.. autoclass:: napari_cellseg3d.interface::Log + :members: __init__, write, replace_last_line, print_and_log, warn + + +ContainerWidget +************************************** +.. autoclass:: napari_cellseg3d.interface::ContainerWidget + :members: __init__ + Button ************************************** .. autoclass:: napari_cellseg3d.interface::Button @@ -22,13 +43,13 @@ CheckBox AnisotropyWidgets ************************************** .. autoclass:: napari_cellseg3d.interface::AnisotropyWidgets - :members: __init__, build, get_anisotropy_resolution_xyz, get_anisotropy_resolution_zyx, anisotropy_zoom_factor,is_enabled,toggle_permanent_visibility + :members: __init__, build, scaling_zyx, resolution_zyx, scaling_xyz, resolution_xyz,enabled FilePathWidget ************************************** .. autoclass:: napari_cellseg3d.interface::FilePathWidget - :members: __init__, build, get_text_field, get_button, check_ready, set_required, update_field_color, set_description + :members: __init__, build, text_field, button, check_ready, required, update_field_color, tooltips ScrollArea ************************************** @@ -38,7 +59,7 @@ ScrollArea DoubleIncrementCounter ************************************** .. autoclass:: napari_cellseg3d.interface::DoubleIncrementCounter - :members: __init__, set_precision, make_n + :members: __init__, precision, make_n IntIncrementCounter ************************************** @@ -49,22 +70,21 @@ IntIncrementCounter Functions ------------- -open_url +handle_adjust_errors ************************************** -.. autofunction:: napari_cellseg3d.interface::open_url - +.. autofunction:: napari_cellseg3d.interface::handle_adjust_errors -make_group +handle_adjust_errors_wrapper ************************************** -.. autofunction:: napari_cellseg3d.interface::make_group +.. autofunction:: napari_cellseg3d.interface::handle_adjust_errors_wrapper -add_to_group +open_url ************************************** -.. autofunction:: napari_cellseg3d.interface::add_to_group +.. autofunction:: napari_cellseg3d.interface::open_url -make_container +make_group ************************************** -.. autofunction:: napari_cellseg3d.interface::make_container +.. autofunction:: napari_cellseg3d.interface::make_group combine_blocks ************************************** @@ -74,6 +94,10 @@ add_blank ************************************** .. autofunction:: napari_cellseg3d.interface::add_blank +add_label +************************************** +.. autofunction:: napari_cellseg3d.interface::add_label + toggle_visibility ************************************** .. autofunction:: napari_cellseg3d.interface::toggle_visibility diff --git a/docs/res/code/launch_review.rst b/docs/res/code/launch_review.rst deleted file mode 100644 index e98be895..00000000 --- a/docs/res/code/launch_review.rst +++ /dev/null @@ -1,10 +0,0 @@ -launch_review.py -=================================== - -Functions -------------------------------------- - -launch_review -******************************* - -.. autofunction:: napari_cellseg3d.launch_review::launch_review diff --git a/docs/res/code/model_framework.rst b/docs/res/code/model_framework.rst index c26dd5f2..59f6d004 100644 --- a/docs/res/code/model_framework.rst +++ b/docs/res/code/model_framework.rst @@ -11,13 +11,13 @@ Class : ModelFramework Methods ********************** -.. autoclass:: napari_cellseg3d.model_framework::ModelFramework - :members: __init__, load_dataset_paths,save_log, save_log_to_path,get_model, get_loss,display_status_report,load_dataset_paths,create_train_dataset_dict, load_image_dataset, load_label_dataset,load_results_path,load_model_path,get_device, update_default, remove_from_viewer +.. autoclass:: napari_cellseg3d.code_models.model_framework::ModelFramework + :members: __init__, send_log, save_log, save_log_to_path, display_status_report, create_train_dataset_dict, get_model, get_available_models, get_device, empty_cuda_cache :noindex: Attributes ********************* -.. autoclass:: napari_cellseg3d.model_framework::ModelFramework - :members: _viewer, worker, docked_widgets, images_filepaths, labels_filepaths, results_path, model_path \ No newline at end of file +.. autoclass:: napari_cellseg3d.code_models.model_framework::ModelFramework + :members: _viewer, worker, docked_widgets, images_filepaths, labels_filepaths, results_path \ No newline at end of file diff --git a/docs/res/code/model_instance_seg.rst b/docs/res/code/model_instance_seg.rst index f8bc78b6..e4146ec1 100644 --- a/docs/res/code/model_instance_seg.rst +++ b/docs/res/code/model_instance_seg.rst @@ -7,24 +7,24 @@ Functions binary_connected ************************************** -.. autofunction:: napari_cellseg3d.model_instance_seg::binary_connected +.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_connected binary_watershed ************************************** -.. autofunction:: napari_cellseg3d.model_instance_seg::binary_watershed +.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::binary_watershed volume_stats ************************************** -.. autofunction:: napari_cellseg3d.model_instance_seg::volume_stats +.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::volume_stats clear_small_objects ************************************** -.. autofunction:: napari_cellseg3d.model_instance_seg::clear_small_objects +.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::clear_small_objects to_instance ************************************** -.. autofunction:: napari_cellseg3d.model_instance_seg::to_instance +.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_instance to_semantic ************************************** -.. autofunction:: napari_cellseg3d.model_instance_seg::to_semantic +.. autofunction:: napari_cellseg3d.code_models.model_instance_seg::to_semantic diff --git a/docs/res/code/model_workers.rst b/docs/res/code/model_workers.rst index 1a4f2856..914f6507 100644 --- a/docs/res/code/model_workers.rst +++ b/docs/res/code/model_workers.rst @@ -10,7 +10,7 @@ Class : LogSignal Attributes ************************ -.. autoclass:: napari_cellseg3d.model_workers::LogSignal +.. autoclass:: napari_cellseg3d.code_models.model_workers::LogSignal :members: log_signal :noindex: @@ -24,7 +24,7 @@ Class : InferenceWorker Methods ************************ -.. autoclass:: napari_cellseg3d.model_workers::InferenceWorker +.. autoclass:: napari_cellseg3d.code_models.model_workers::InferenceWorker :members: __init__, log, create_inference_dict, inference :noindex: @@ -39,7 +39,7 @@ Class : TrainingWorker Methods ************************ -.. autoclass:: napari_cellseg3d.model_workers::TrainingWorker +.. autoclass:: napari_cellseg3d.code_models.model_workers::TrainingWorker :members: __init__, log, train :noindex: diff --git a/docs/res/code/plugin_base.rst b/docs/res/code/plugin_base.rst index 2e1b0277..af4a4f4e 100644 --- a/docs/res/code/plugin_base.rst +++ b/docs/res/code/plugin_base.rst @@ -8,17 +8,15 @@ Class : BasePluginSingleImage Methods ********************** -.. autoclass:: napari_cellseg3d.plugin_base::BasePluginSingleImage - :members: __init__, remove_from_viewer, show_dialog_images, show_dialog_labels, update_default +.. autoclass:: napari_cellseg3d.code_plugins.plugin_base::BasePluginSingleImage + :members: __init__, enable_utils_menu, remove_from_viewer, remove_docked_widgets :noindex: - - Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_base::BasePluginSingleImage - :members: _viewer, image_path, label_path, filetype, file_handling_box +.. autoclass:: napari_cellseg3d.code_plugins.plugin_base::BasePluginSingleImage + :members: _viewer, image_path, label_path, image_layer_loader, label_layer_loader @@ -29,13 +27,12 @@ Class : BasePluginFolder Methods *********************** -.. autoclass:: napari_cellseg3d.plugin_base::BasePluginFolder - :members: __init__, remove_from_viewer,make_close_button,make_prev_button,make_next_button, load_dataset_paths,load_image_dataset,load_label_dataset,load_results_path, update_default,remove_docked_widgets +.. autoclass:: napari_cellseg3d.code_plugins.plugin_base::BasePluginFolder + :members: __init__, load_dataset_paths,load_image_dataset,load_label_dataset :noindex: - Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_base::BasePluginFolder - :members: _viewer, images_filepaths, labels_filepaths,results_path, filetype_choice \ No newline at end of file +.. autoclass:: napari_cellseg3d.code_plugins.plugin_base::BasePluginFolder + :members: _viewer, images_filepaths, labels_filepaths, results_path \ No newline at end of file diff --git a/docs/res/code/plugin_convert.rst b/docs/res/code/plugin_convert.rst index b32acde7..7a244040 100644 --- a/docs/res/code/plugin_convert.rst +++ b/docs/res/code/plugin_convert.rst @@ -1,21 +1,50 @@ plugin_convert.py ================================== +Classes +---------------------------------- -Class : ConvertUtils ------------------------------------------- +AnisoUtils +********************************** +.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::AnisoUtils + :members: __init__ -.. important:: - Inherits from : :doc:`plugin_base` +RemoveSmallUtils +********************************** +.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::RemoveSmallUtils + :members: __init__ -Methods -********************** -.. autoclass:: napari_cellseg3d.plugin_convert::ConvertUtils - :members: __init__, build, folder_to_semantic, layer_to_semantic, folder_to_instance, layer_to_instance, layer_remove_small, folder_remove_small , check_ready_layer,check_ready_folder - :noindex: +ToSemanticUtils +********************************** +.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToSemanticUtils + :members: __init__ -Attributes -********************* +InstanceWidgets +********************************** +.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::InstanceWidgets + :members: __init__, run_method -.. autoclass:: napari_cellseg3d.plugin_convert::ConvertUtils - :members: _viewer \ No newline at end of file +ToInstanceUtils +********************************** +.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ToInstanceUtils + :members: __init__ + +ThresholdUtils +********************************** +.. autoclass:: napari_cellseg3d.code_plugins.plugin_convert::ThresholdUtils + :members: __init__ + +Functions +----------------------------------- + +save_folder +***************************************** +.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_folder + +save_layer +**************************************** +.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::save_layer + +show_result +**************************************** +.. autofunction:: napari_cellseg3d.code_plugins.plugin_convert::show_result \ No newline at end of file diff --git a/docs/res/code/plugin_crop.rst b/docs/res/code/plugin_crop.rst index d1a6e8d9..cb311d74 100644 --- a/docs/res/code/plugin_crop.rst +++ b/docs/res/code/plugin_crop.rst @@ -11,8 +11,8 @@ Class : Cropping Methods ********************** -.. autoclass:: napari_cellseg3d.plugin_crop::Cropping - :members: __init__, start, quicksave, build, remove_from_viewer +.. autoclass:: napari_cellseg3d.code_plugins.plugin_crop::Cropping + :members: __init__, _start, quicksave, remove_from_viewer :noindex: @@ -21,8 +21,8 @@ Methods Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_crop::Cropping - :members: _viewer, image_path, label_path, filetype +.. autoclass:: napari_cellseg3d.code_plugins.plugin_crop::Cropping + :members: _viewer, image_path, label_path diff --git a/docs/res/code/plugin_metrics.rst b/docs/res/code/plugin_metrics.rst index fcc9fad3..f9014edb 100644 --- a/docs/res/code/plugin_metrics.rst +++ b/docs/res/code/plugin_metrics.rst @@ -11,12 +11,12 @@ Class : MetricsUtils Methods ********************** -.. autoclass:: napari_cellseg3d.plugin_metrics::MetricsUtils - :members: __init__, build, plot_dice, remove_plots, compute_dice +.. autoclass:: napari_cellseg3d.code_plugins.plugin_metrics::MetricsUtils + :members: __init__, plot_dice, remove_plots, compute_dice :noindex: Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_metrics::MetricsUtils +.. autoclass:: napari_cellseg3d.code_plugins.plugin_metrics::MetricsUtils :members: _viewer, layout, canvas, plots \ No newline at end of file diff --git a/docs/res/code/plugin_model_inference.rst b/docs/res/code/plugin_model_inference.rst index f6f29b28..cdd4d6eb 100644 --- a/docs/res/code/plugin_model_inference.rst +++ b/docs/res/code/plugin_model_inference.rst @@ -10,8 +10,8 @@ Class : Inferer Methods ********************** -.. autoclass:: napari_cellseg3d.plugin_model_inference::Inferer - :members: __init__, start, build,on_start,on_error,on_finish,on_yield,check_ready,toggle_display_number,toggle_display_thresh, remove_from_viewer +.. autoclass:: napari_cellseg3d.code_plugins.plugin_model_inference::Inferer + :members: __init__, start,on_start,on_error,on_finish,on_yield, check_ready, remove_from_viewer :noindex: @@ -20,5 +20,5 @@ Methods Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_model_inference::Inferer - :members: _viewer, models_dict +.. autoclass:: napari_cellseg3d.code_plugins.plugin_model_inference::Inferer + :members: _viewer, worker, config, instance_config, post_process_config, worker_config, model_info diff --git a/docs/res/code/plugin_model_training.rst b/docs/res/code/plugin_model_training.rst index c9e561e5..a531b877 100644 --- a/docs/res/code/plugin_model_training.rst +++ b/docs/res/code/plugin_model_training.rst @@ -10,8 +10,8 @@ Class : Trainer Methods ********************** -.. autoclass:: napari_cellseg3d.plugin_model_training::Trainer - :members: __init__, build, show_dialog_lab, show_dialog_dat, check_ready, start, on_start, on_finish, on_error, on_yield, plot_loss, update_loss_plot, remove_from_viewer +.. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer + :members: __init__, get_loss, check_ready, send_log, start, on_start, on_finish, on_error, on_yield, plot_loss, update_loss_plot :noindex: @@ -19,5 +19,5 @@ Methods Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_model_training::Trainer - :members: _viewer, worker, models_dict, loss_dict, canvas, train_loss_plot, dice_metric_plot \ No newline at end of file +.. autoclass:: napari_cellseg3d.code_plugins.plugin_model_training::Trainer + :members: _viewer, worker, loss_dict, canvas, train_loss_plot, dice_metric_plot \ No newline at end of file diff --git a/docs/res/code/plugin_review.rst b/docs/res/code/plugin_review.rst index fb69ee19..f61e5661 100644 --- a/docs/res/code/plugin_review.rst +++ b/docs/res/code/plugin_review.rst @@ -10,8 +10,8 @@ Class : Loader Methods ********************** -.. autoclass:: napari_cellseg3d.plugin_review::Reviewer - :members: __init__, run_review, build, remove_from_viewer +.. autoclass:: napari_cellseg3d.code_plugins.plugin_review::Reviewer + :members: __init__, run_review, launch_review, check_image_data, remove_from_viewer :noindex: @@ -22,8 +22,8 @@ Methods Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_review::Reviewer - :members: _viewer, image_path, label_path, filetype +.. autoclass:: napari_cellseg3d.code_plugins.plugin_review::Reviewer + :members: _viewer, image_path, label_path diff --git a/docs/res/code/plugin_dock.rst b/docs/res/code/plugin_review_dock.rst similarity index 66% rename from docs/res/code/plugin_dock.rst rename to docs/res/code/plugin_review_dock.rst index 24ca3c66..597b30e9 100644 --- a/docs/res/code/plugin_dock.rst +++ b/docs/res/code/plugin_review_dock.rst @@ -6,7 +6,7 @@ Datamanager Methods ********************** -.. autoclass:: napari_cellseg3d.plugin_dock::Datamanager +.. autoclass:: napari_cellseg3d.code_plugins.plugin_review_dock::Datamanager :members: __init__, prepare, update, load_csv, create :noindex: @@ -17,7 +17,7 @@ Methods Attributes ********************* -.. autoclass:: napari_cellseg3d.plugin_dock::Datamanager +.. autoclass:: napari_cellseg3d.code_plugins.plugin_review_dock::Datamanager :members: viewer diff --git a/docs/res/code/utils.rst b/docs/res/code/utils.rst index f5fd0eee..15d7fc1d 100644 --- a/docs/res/code/utils.rst +++ b/docs/res/code/utils.rst @@ -1,9 +1,15 @@ utils.py ============= -Functions +Classes ------------- +Singleton +************************************** +.. autoclass:: napari_cellseg3d.utils::Singleton + +Functions +------------- get_date_time ************************************** diff --git a/docs/res/guides/cropping_module_guide.rst b/docs/res/guides/cropping_module_guide.rst index 70713658..a862ffff 100644 --- a/docs/res/guides/cropping_module_guide.rst +++ b/docs/res/guides/cropping_module_guide.rst @@ -7,33 +7,35 @@ This module allows you to crop your volumes and labels dynamically, by selecting a fixed size volume and moving it around the image. You can then save the cropped volume and labels directly using napari, -by selecting the layer and then using **File -> Save selected layer**, +by using the **Quicksave** button, +or by selecting the layer and then using **File -> Save selected layer**, or simply **CTRL+S** once you have selected the correct layer. - Launching the cropping process --------------------------------- -First, you will be asked to load your images and labels; you can use the checkbox above the Open buttons to -choose whether you want to load a single 3D **.tif** image or a folder of 2D images as a 3D stack. -Folders can be stacks of either .png or .tif files, ideally numbered with the index of the slice at the end. - -.. note:: - Only single 3D **.tif** files or one folder of several **.png** or **.tif** (stack of 2D images) are supported. +First, simply pick your images using the layer selection dropdown menu. +If you'd like to crop a second image, e.g. labels, at the same time, +simply check the *"Crop another image simultaneously"* checkbox and +pick the corresponding layer. You can then choose the size of the cropped volume, which will be constant throughout the process; make sure it is correct beforehand. -Setting a larger size than the size of the image will cause issues. You can also opt to correct the anisotropy if your image is anisotropic : simply set the resolution to the one of your microscope. .. important:: - This will simply scale the image in the viewer, but saved images will **still be anisotropic.** To resize your image, see :doc:`convert_module_guide` + This will simply scale the image in the viewer, but saved images will **still be anisotropic.** To resize your image, see :doc:`utils_module_guide` Once you are ready, you can press **Start** to start the review process. +If you'd like to change the size of the volume, change the parameters as previously to your desired size and hit start again. - +Creating new layers +--------------------------------- +To "zoom in" your volume, you can use the "Create new layers" checkbox to make a new layer not controlled by the plugin next +time you hit Start. This way, you can first select your region of interest by using the tool as described above, +the enable the option, select the cropped layer, and define a smaller crop size to have easier access to your region of interest. Interface & functionalities --------------------------------------------------------------- @@ -56,11 +58,9 @@ you **change the position** of the cropped volumes and labels in the x,y and z p If you want more options (name, format) you can save by selecting the layer and then using **File -> Save selected layer**, or simply **CTRL+S** once you have selected the correct layer. +.. + Source code + ------------------------------------------------- - - -Source code -------------------------------------------------- - -* :doc:`../code/plugin_crop` -* :doc:`../code/plugin_base` + * :doc:`../code/plugin_crop` + * :doc:`../code/plugin_base` diff --git a/docs/res/guides/inference_module_guide.rst b/docs/res/guides/inference_module_guide.rst index f200c39e..00e67078 100644 --- a/docs/res/guides/inference_module_guide.rst +++ b/docs/res/guides/inference_module_guide.rst @@ -3,12 +3,12 @@ Inference module guide ================================= -This module allows you to use pre-trained segmentation algorithms (written in Pytorch) on 3D volumes +This module allows you to use pre-trained segmentation algorithms (written in Pytorch) on 3D volumes to automatically label cells. .. important:: Currently, only inference on **3D volumes is supported**. Your image and label folders should both contain a set of - **3D image files**, currently either **.tif** or **.tiff**. Loading a folder of 2D images as a stack is not supported as of yet. + **3D image files**, currently either **.tif** or **.tiff**. Currently, the following pre-trained models are available : @@ -17,14 +17,16 @@ Model Link to original paper ============== ================================================================================================ VNet `Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation`_ SegResNet `3D MRI brain tumor segmentation using autoencoder regularization`_ -TRAILMAP_MS A PyTorch implementation of the `TRAILMAP project on GitHub`_ pretrained with MesoSpim data +TRAILMAP_MS A PyTorch implementation of the `TRAILMAP project on GitHub`_ pretrained with mesoSPIM data TRAILMAP An implementation of the `TRAILMAP project on GitHub`_ using a `3DUNet for PyTorch`_ +SwinUNetR `Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images`_ ============== ================================================================================================ .. _Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation: https://arxiv.org/pdf/1606.04797.pdf .. _3D MRI brain tumor segmentation using autoencoder regularization: https://arxiv.org/pdf/1810.11654.pdf .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP .. _3DUnet for Pytorch: https://github.com/wolny/pytorch-3dunet +.. _Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images: https://arxiv.org/abs/2201.01266 Interface and functionalities -------------------------------- @@ -35,28 +37,29 @@ Interface and functionalities * **Loading data** : - | When launching the module, you will be asked to provide an **image folder** containing all the 3D volumes you'd like to be labeled. - | All images with the chosen extension (**.tif** or **.tiff** currently supported) in this folder will be labeled. + | When launching the module, you will be asked to provide an **image layer** or an **image folder** with the 3D volumes you'd like to be labeled. + | If loading from folder : All images with the chosen extension (**.tif** or **.tiff** currently supported) in this folder will be labeled. | You can then choose an **output folder**, where all the results will be saved. - * **Model choice** : | You can then choose one of the provided **models** above, which will be used for inference. | You may also choose to **load custom weights** rather than the pre-trained ones, simply ensure they are **compatible** (e.g. produced from the training module for the same model) - | If you choose to use a SegResNet with custom weights, you will have to provide the size of images it was trained on to ensure compatibility. (See note below) + | If you choose to use SegResNet or SwinUNetR with custom weights, you will have to provide the size of images it was trained on to ensure compatibility. (See note below) + +.. note:: + Currently the SegResNet and SwinUNetR models requires you to provide the size of the images the model was trained with. + Provided weights use a size of 128, please leave it on the default value if you're not using custom weights. * **Inference parameters** : | You can choose to use inference on the whole image at once, which generally yields better performance at the cost of more memory, or you can use a specific window size to run inference on smaller chunks one by one, for lower memory usage. | You can also choose to keep the dataset in the RAM rather than the VRAM (cpu vs cuda device) to avoid running out of VRAM if you have several images. - * **Anisotropy** : | If you want to see your results without **anisotropy** when you have anisotropic images, you can specify that you have anisotropic data and set the **resolution of your imaging method in micron**, this wil save and show the results without anisotropy. - * **Thresholding** : | You can perform thresholding to **binarize your labels**, @@ -107,11 +110,6 @@ Once it has finished, results will be saved then displayed in napari; each outpu On the left side, a progress bar and a log will keep you informed on the process. -.. note:: - Currently the SegResNet model requires you to provide the size of the images the model was trained with due to the VAE module. - Provided weights use a size of 128, please leave it as is if you're not using custom weights. - - .. note:: | The files will be saved using the following format : | ``{original_name}_{model}_{date & time}_pred{id}.file_ext`` diff --git a/docs/res/guides/review_module_guide.rst b/docs/res/guides/review_module_guide.rst index a26adf55..ffecf9a0 100644 --- a/docs/res/guides/review_module_guide.rst +++ b/docs/res/guides/review_module_guide.rst @@ -23,7 +23,7 @@ Launching the review process This will scale the images to visually remove the anisotropy, so as to make review easier. .. important:: - Results will still be saved as anisotropic images. If you wish to resize your images, see the :doc:`convert_module_guide` + Results will still be saved as anisotropic images. If you wish to resize your images, see the :doc:`utils_module_guide` * CSV file name : You can then provide a model name, which will be used to name the csv file recording the status of each slice. @@ -59,5 +59,5 @@ Source code ------------------------------------------------- * :doc:`../code/plugin_review` -* :doc:`../code/launch_review` -* :doc:`../code/plugin_base` \ No newline at end of file +* :doc:`../code/plugin_review_dock` +* :doc:`../code/plugin_base` diff --git a/docs/res/guides/training_module_guide.rst b/docs/res/guides/training_module_guide.rst index 45be6ccd..fb8992d2 100644 --- a/docs/res/guides/training_module_guide.rst +++ b/docs/res/guides/training_module_guide.rst @@ -60,6 +60,7 @@ The training module is comprised of several tabs. * The size of patches to be extracted (ideally, please use a value **close or equal to a power of two**, such as 120 or 60 to ensure correct size. See above note.) * The number of samples to extract from each of your images. A larger number will likely mean better performances, but longer training and larger memory usage. +.. note:: If you're using a single image (preferably large) it is recommended to enable patch extraction. * Whether to perform data augmentation or not (elastic deforms, intensity shifts. random flipping,etc). Ideally it should always be enabled, but you can disable it if it causes issues. @@ -73,6 +74,7 @@ The training module is comprised of several tabs. * The **batch size** (larger means quicker training and possibly better performance but increased memory usage) * The **number of epochs** (a possibility is to start with 60 epochs, and decrease or increase depending on performance.) * The **epoch interval** for validation (for example, if set to two, the module will use the validation dataset to evaluate the model with the dice metric every two epochs.) +* Whether to use deterministic training, and the seed to use. .. note:: If the dice metric is better on a given validation interval, the model weights will be saved in the results folder. diff --git a/docs/res/guides/convert_module_guide.rst b/docs/res/guides/utils_module_guide.rst similarity index 81% rename from docs/res/guides/convert_module_guide.rst rename to docs/res/guides/utils_module_guide.rst index bdf2dbc2..fd9f7401 100644 --- a/docs/res/guides/convert_module_guide.rst +++ b/docs/res/guides/utils_module_guide.rst @@ -1,4 +1,4 @@ -.. _convert_module_guide: +.. _utils_module_guide: Label conversion utility guide ================================== @@ -18,16 +18,18 @@ You can : You can specify a size threshold in pixels; all objects smaller than this size will be removed in the image. * Resize anisotropic images : - Specifiy the resolution of your microscope to remove anisotropy from images. + Specify the resolution of your microscope to remove anisotropy from images. .. important:: Does not work for instance labels currently. +* Threshold images : + Remove all values below a threshold in an image. .. figure:: ../images/converted_labels.png :scale: 30 % :align: center - Example of instance labels (left) converted to instance labels (right) + Example of instance labels (left) converted to semantic labels (right) Source code ------------------------------------------------- diff --git a/docs/res/images/cropping_process_example.png b/docs/res/images/cropping_process_example.png index b3c408de..d691650d 100644 Binary files a/docs/res/images/cropping_process_example.png and b/docs/res/images/cropping_process_example.png differ diff --git a/docs/res/images/inference_plugin_layout.png b/docs/res/images/inference_plugin_layout.png index b7e2838f..760b572d 100644 Binary files a/docs/res/images/inference_plugin_layout.png and b/docs/res/images/inference_plugin_layout.png differ diff --git a/docs/res/images/inference_results_example.png b/docs/res/images/inference_results_example.png index bbc94176..3f9d9b46 100644 Binary files a/docs/res/images/inference_results_example.png and b/docs/res/images/inference_results_example.png differ diff --git a/docs/res/welcome.rst b/docs/res/welcome.rst index 20c92066..6832e71e 100644 --- a/docs/res/welcome.rst +++ b/docs/res/welcome.rst @@ -2,7 +2,7 @@ Introduction =================== -Here you will find instructions on how to use the plugin for direct-to-3D segmentation. +Here you will find instructions on how to use the plugin for direct segmentation in 3D. If the installation was successful, you'll see the napari-cellseg3D plugin in the Plugins section of napari. @@ -20,8 +20,13 @@ From this page you can access the guides on the several modules available for yo * Review : :ref:`loader_module_guide` * Utilities : * Cropping (3D) : :ref:`cropping_module_guide` - * Convert labels : :ref:`convert_module_guide` + * Other utilities : :ref:`utils_module_guide` + +.. + * Convert labels : :ref:`utils_module_guide` +.. * Compute scores : :ref:`metrics_module_guide` + * Advanced : * Defining custom models directly in the plugin (WIP) : :ref:`custom_model_guide` @@ -66,10 +71,11 @@ To use the plugin, please run: Then go into Plugins > napari-cellseg3d, and choose which tool to use: -- **Train**: This module allows you to train segmentation algorithms from labeled volumes. -- **Infer**: This module allows you to use pre-trained segmentation algorithms on volumes to automatically label cells. -- **Review**: This module allows you to review your labels, from predictions or manual labeling, and correct them if needed. It then saves the status of each file in a csv, for easier monitoring. +- **Review**: This module allows you to review your labels, from predictions or manual labeling, and correct them if needed. It then saves the status of each file in a csv, for easier monitoring +- **Inference**: This module allows you to use pre-trained segmentation algorithms on volumes to automatically label cells +- **Training**: This module allows you to train segmentation algorithms from labeled volumes - **Utilities**: This module allows you to use several utilities, e.g. to crop your volumes and labels, compute prediction scores or convert labels +- **Help/About...** : Quick access to version info, Github page and docs See above for links to detailed guides regarding the usage of the modules. diff --git a/napari_cellseg3d/__init__.py b/napari_cellseg3d/__init__.py index a12d527a..6e2681e8 100644 --- a/napari_cellseg3d/__init__.py +++ b/napari_cellseg3d/__init__.py @@ -1 +1 @@ -__version__ = "0.0.1rc4" +__version__ = "0.0.2rc1" diff --git a/napari_cellseg3d/_tests/fixtures.py b/napari_cellseg3d/_tests/fixtures.py new file mode 100644 index 00000000..8fcf56db --- /dev/null +++ b/napari_cellseg3d/_tests/fixtures.py @@ -0,0 +1,15 @@ +import warnings +from qtpy.QtWidgets import QTextEdit + + +class LogFixture(QTextEdit): + """Fixture for testing, replaces napari_cellseg3d.interface.Log in model_workers during testing""" + + def __init__(self): + super(LogFixture, self).__init__() + + def print_and_log(self, text, printing=None): + print(text) + + def warn(self, warning): + warnings.warn(warning) diff --git a/napari_cellseg3d/_tests/test_dock_widget.py b/napari_cellseg3d/_tests/test_dock_widget.py index 2ce9f6d3..f621dba4 100644 --- a/napari_cellseg3d/_tests/test_dock_widget.py +++ b/napari_cellseg3d/_tests/test_dock_widget.py @@ -1,24 +1,21 @@ -import os from pathlib import Path from tifffile import imread -from napari_cellseg3d.plugin_dock import Datamanager +from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager def test_prepare(make_napari_viewer): - path_image = Path( - os.path.dirname(os.path.realpath(__file__)) + "/res/test.tif" - ) - image = imread(path_image) + path_image = str(Path(__file__).resolve().parent / "res/test.tif") + image = imread(str(path_image)) viewer = make_napari_viewer() viewer.add_image(image) widget = Datamanager(viewer) - widget.prepare(path_image, ".tif", "", False, False) + widget.prepare(path_image, ".tif", "", False) assert widget.filetype == ".tif" assert widget.as_folder == False - assert Path(widget.csv_path) == Path( - os.path.dirname(os.path.realpath(__file__)) + "/res/_train0.csv" + assert Path(widget.csv_path) == ( + Path(__file__).resolve().parent / "res/_train0.csv" ) diff --git a/napari_cellseg3d/_tests/test_helper.py b/napari_cellseg3d/_tests/test_helper.py new file mode 100644 index 00000000..b35fc111 --- /dev/null +++ b/napari_cellseg3d/_tests/test_helper.py @@ -0,0 +1,16 @@ +from napari_cellseg3d.code_plugins.plugin_helper import Helper + + +def test_helper(make_napari_viewer): + + viewer = make_napari_viewer() + widget = Helper(viewer) + + dock = viewer.window.add_dock_widget(widget) + children = len(viewer.window._dock_widgets) + + assert dock is not None + + widget.btnc.click() + + assert len(viewer.window._dock_widgets) == children - 1 diff --git a/napari_cellseg3d/_tests/test_interface.py b/napari_cellseg3d/_tests/test_interface.py new file mode 100644 index 00000000..be811721 --- /dev/null +++ b/napari_cellseg3d/_tests/test_interface.py @@ -0,0 +1,14 @@ +from napari_cellseg3d.interface import Log + + +def test_log(qtbot): + log = Log() + log.print_and_log("test") + + assert log.toPlainText() == "\ntest" + + log.replace_last_line("test2") + + assert log.toPlainText() == "\ntest2" + + qtbot.add_widget(log) diff --git a/napari_cellseg3d/_tests/test_model_framework.py b/napari_cellseg3d/_tests/test_model_framework.py index aa7a17f8..5734b329 100644 --- a/napari_cellseg3d/_tests/test_model_framework.py +++ b/napari_cellseg3d/_tests/test_model_framework.py @@ -1,34 +1,40 @@ -from napari_cellseg3d import model_framework +from pathlib import Path + +from napari_cellseg3d.code_models import model_framework + + +def pth(path): + return str(Path(path)) def test_update_default(make_napari_viewer): view = make_napari_viewer() widget = model_framework.ModelFramework(view) - widget.images_filepaths = [""] - widget.results_path = "" + widget.images_filepaths = [] + widget.results_path = None - widget.update_default() + widget._update_default() - assert widget._default_path == [] + assert widget._default_folders == [] widget.images_filepaths = [ - "C:/test/test/images.tif", - "C:/images/test/data.png", + pth("C:/test/test/images.tif"), + pth("C:/images/test/data.png"), ] widget.labels_filepaths = [ - "C:/dataset/labels/lab1.tif", - "C:/data/labels/lab2.tif", + pth("C:/dataset/labels/lab1.tif"), + pth("C:/data/labels/lab2.tif"), ] - widget.results_path = "D:/dataset/res" - widget.model_path = "" + widget.results_path = pth("D:/dataset/res") + # widget.model_path = None - widget.update_default() + widget._update_default() - assert widget._default_path == [ - "C:/test/test", - "C:/dataset/labels", - "D:/dataset/res", + assert widget._default_folders == [ + pth("C:/test/test"), + pth("C:/dataset/labels"), + pth("D:/dataset/res"), ] diff --git a/napari_cellseg3d/_tests/test_plugin_inference.py b/napari_cellseg3d/_tests/test_plugin_inference.py new file mode 100644 index 00000000..1ec7e77e --- /dev/null +++ b/napari_cellseg3d/_tests/test_plugin_inference.py @@ -0,0 +1,41 @@ +from tifffile import imread +from pathlib import Path + +from napari_cellseg3d.config import MODEL_LIST +from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.code_models.models.model_test import TestModel + + +def test_inference(make_napari_viewer, qtbot): + + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + image = imread(im_path) + + assert image.shape == (6, 6, 6) + + viewer = make_napari_viewer() + widget = Inferer(viewer) + widget.log = LogFixture() + viewer.window.add_dock_widget(widget) + viewer.add_image(image) + + assert len(viewer.layers) == 1 + + widget.window_infer_box.setChecked(True) + widget.window_overlap_slider.setValue(0) + widget.keep_data_on_cpu_box.setChecked(True) + + assert widget.check_ready() + + MODEL_LIST["test"] = TestModel + widget.model_choice.addItem("test") + widget.setCurrentIndex(-1) + + # widget.start() # takes too long on Github Actions + # assert widget.worker is not None + + # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000, raising=False) as blocker: + # blocker.connect(widget.worker.errored) + + # assert len(viewer.layers) == 2 diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py new file mode 100644 index 00000000..f0fac98a --- /dev/null +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -0,0 +1,14 @@ +from napari_cellseg3d.code_plugins.plugin_utilities import Utilities +from napari_cellseg3d.code_plugins.plugin_utilities import UTILITIES_WIDGETS + + +def test_utils_plugin(make_napari_viewer): + view = make_napari_viewer() + widget = Utilities(view) + + view.window.add_dock_widget(widget) + for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): + widget.utils_choice.setCurrentIndex(i) + assert isinstance( + widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] + ) diff --git a/napari_cellseg3d/_tests/test_review.py b/napari_cellseg3d/_tests/test_review.py index 97dfaa4b..d2b49061 100644 --- a/napari_cellseg3d/_tests/test_review.py +++ b/napari_cellseg3d/_tests/test_review.py @@ -1,6 +1,6 @@ -import os +from pathlib import Path -from napari_cellseg3d import plugin_review as rev +from napari_cellseg3d.code_plugins import plugin_review as rev def test_launch_review(make_napari_viewer): @@ -10,15 +10,15 @@ def test_launch_review(make_napari_viewer): # widget.filetype_choice.setCurrentIndex(0) - im_path = os.path.dirname(os.path.realpath(__file__)) + "/res/test.tif" + im_path = str(Path(__file__).resolve().parent / "res/test.tif") - widget.image_path = im_path - widget.label_path = im_path + widget.folder_choice.setChecked(True) + widget.image_filewidget.text_field = im_path + widget.labels_filewidget.text_field = im_path + widget.results_filewidget.text_field = str( + Path(__file__).resolve().parent / "res" + ) - print(widget.image_path) - print(widget.label_path) - print(widget.as_folder) - print(widget.filetype) widget.run_review() widget._viewer.close() diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 1cf236eb..70b79b31 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -1,28 +1,55 @@ -from napari_cellseg3d import plugin_model_training as train +from pathlib import Path +from napari_cellseg3d import config +from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d._tests.fixtures import LogFixture +from napari_cellseg3d.config import MODEL_LIST +from napari_cellseg3d.code_models.models.model_test import TestModel -def test_check_ready(make_napari_viewer): - view = make_napari_viewer() - widget = train.Trainer(view) - widget.images_filepath = [""] - widget.labels_filepaths = [""] +def test_training(make_napari_viewer, qtbot): + im_path = str(Path(__file__).resolve().parent / "res/test.tif") + + viewer = make_napari_viewer() + widget = Trainer(viewer) + widget.log = LogFixture() + viewer.window.add_dock_widget(widget) + + widget.images_filepath = None + widget.labels_filepaths = None + + assert not widget.check_ready() + + assert widget.filetype_choice.currentText() == ".tif" + + widget.images_filepaths = [im_path] + widget.labels_filepaths = [im_path] + widget.epoch_choice.setValue(1) + widget.val_interval_choice.setValue(1) + + assert widget.check_ready() + + ################# + # Training is too long to test properly this way. Do not use on Github + ################# + MODEL_LIST["test"] = TestModel() + widget.model_choice.addItem("test") + widget.model_choice.setCurrentIndex(len(MODEL_LIST.keys()) - 1) - res = widget.check_ready() - assert not res + # widget.start() + # assert widget.worker is not None - # widget.images_filepath = ["C:/test/something.tif"] - # widget.labels_filepaths = ["C:/test/lab_something.tif"] - # res = widget.check_ready() - # - # assert res + # with qtbot.waitSignal(signal=widget.worker.finished, timeout=10000, raising=False) as blocker: # wait only for 60 seconds. + # blocker.connect(widget.worker.errored) def test_update_loss_plot(make_napari_viewer): view = make_napari_viewer() - widget = train.Trainer(view) + widget = Trainer(view) - widget.val_interval = 1 + widget.worker_config = config.TrainingWorkerConfig() + widget.worker_config.validation_interval = 1 + widget.worker_config.results_path_folder = "." epoch_loss_values = [1] metric_values = [] @@ -32,7 +59,7 @@ def test_update_loss_plot(make_napari_viewer): assert widget.dice_metric_plot is None assert widget.train_loss_plot is None - widget.val_interval = 2 + widget.worker_config.validation_interval = 2 epoch_loss_values = [0, 1] metric_values = [0.2] diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py index 45416c53..6a7b6eeb 100644 --- a/napari_cellseg3d/_tests/test_utils.py +++ b/napari_cellseg3d/_tests/test_utils.py @@ -112,10 +112,10 @@ def test_normalize_x(): def test_parse_default_path(): user_path = os.path.expanduser("~") - assert utils.parse_default_path([""]) == user_path + assert utils.parse_default_path([None]) == user_path - path = ["C:/test/test", "", [""]] + path = ["C:/test/test", None, None] assert utils.parse_default_path(path) == "C:/test/test" - path = ["C:/test/test", "", [""], "D:/very/long/path/what/a/bore", ""] + path = ["C:/test/test", None, None, "D:/very/long/path/what/a/bore", ""] assert utils.parse_default_path(path) == "D:/very/long/path/what/a/bore" diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py new file mode 100644 index 00000000..306fdf6c --- /dev/null +++ b/napari_cellseg3d/_tests/test_weight_download.py @@ -0,0 +1,12 @@ +from napari_cellseg3d.code_models.model_workers import ( + WeightsDownloader, + WEIGHTS_DIR, +) + +# DISABLED, causes GitHub actions to freeze +def test_weight_download(): + downloader = WeightsDownloader() + downloader.download_weights("test", "test.pth") + result_path = WEIGHTS_DIR / "test.pth" + + assert result_path.is_file() diff --git a/napari_cellseg3d/models/__init__.py b/napari_cellseg3d/code_models/__init__.py similarity index 100% rename from napari_cellseg3d/models/__init__.py rename to napari_cellseg3d/code_models/__init__.py diff --git a/napari_cellseg3d/model_framework.py b/napari_cellseg3d/code_models/model_framework.py similarity index 51% rename from napari_cellseg3d/model_framework.py rename to napari_cellseg3d/code_models/model_framework.py index 47208616..88d3f887 100644 --- a/napari_cellseg3d/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -1,5 +1,5 @@ -import os import warnings +from pathlib import Path import napari import torch @@ -9,24 +9,26 @@ from qtpy.QtWidgets import QSizePolicy # local +from napari_cellseg3d import config from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.log_utility import Log -from napari_cellseg3d.models import model_SegResNet as SegResNet -from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR - -# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP -from napari_cellseg3d.models import model_VNet as VNet -from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS -from napari_cellseg3d.plugin_base import BasePluginFolder +from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder warnings.formatwarning = utils.format_Warning +logger = utils.LOGGER class ModelFramework(BasePluginFolder): """A framework with buttons to use for loading images, labels, models, etc. for both inference and training""" - def __init__(self, viewer: "napari.viewer.Viewer"): + def __init__( + self, + viewer: "napari.viewer.Viewer", + parent=None, + loads_images=True, + loads_labels=True, + has_results=True, + ): """Creates a plugin framework with the following elements : * A button to choose an image folder containing the images of a dataset (e.g. dataset/images) @@ -41,36 +43,30 @@ def __init__(self, viewer: "napari.viewer.Viewer"): Args: viewer (napari.viewer.Viewer): viewer to load the widget in + parent: parent QWidget + loads_images: if True, will contain UI elements used to load napari image layers + loads_labels: if True, will contain UI elements used to load napari label layers + has_results: if True, will add UI to choose a results path """ - super().__init__(viewer) + super().__init__( + viewer, parent, loads_images, loads_labels, has_results + ) self._viewer = viewer """Viewer to display the widget in""" - self.model_path = "" - """str: path to custom model defined by user""" - self.weights_path = "" + # self.model_path = "" # TODO add custom models + # """str: path to custom model defined by user""" + + self.weights_config = config.WeightsInfo() """str : path to custom weights defined by user""" - self._default_path = [ - self.images_filepaths, - self.labels_filepaths, - self.model_path, - self.weights_path, - self.results_path, - ] - """Update defaults from PluginBaseFolder with model_path""" + self._default_weights_folder = self.weights_config.path + """Default path for plugin weights""" - self.models_dict = { - "VNet": VNet, - "SegResNet": SegResNet, - # "TRAILMAP": TRAILMAP, - "TRAILMAP_MS": TRAILMAP_MS, - "SwinUNetR": SwinUNetR, - } - """dict: dictionary of available models, with string for widget display as key + self.available_models = config.MODEL_LIST - Currently implemented : SegResNet, VNet, TRAILMAP_MS""" + """dict: dictionary of available models, with string as key for name in widget display""" self.worker = None """Worker from model_workers.py, either inference or training""" @@ -79,81 +75,73 @@ def __init__(self, viewer: "napari.viewer.Viewer"): # interface # TODO : implement custom model - self.model_filewidget = ui.FilePathWidget( - "Model path", self.load_model_path, self - ) - self.btn_model_path = self.model_filewidget.get_button() - self.lbl_model_path = self.model_filewidget.get_text_field() + # self.model_filewidget = ui.FilePathWidget( + # "Model path", self.load_model_path, self + # ) self.model_choice = ui.DropdownMenu( - sorted(self.models_dict.keys()), label="Model name" + sorted(self.available_models.keys()), label="Model name" ) - self.lbl_model_choice = self.model_choice.label self.weights_filewidget = ui.FilePathWidget( - "Weights path", self.load_weights_path, self - ) - self.btn_weights_path = self.weights_filewidget.get_button() - self.lbl_weights_path = self.weights_filewidget.get_text_field() - - self.weights_path_container = ui.combine_blocks( - self.btn_weights_path, self.lbl_weights_path, b=0 + "Weights path", self._load_weights_path, self ) - self.weights_path_container.setVisible(False) - self.custom_weights_choice = ui.make_checkbox( - "Load custom weights", self.toggle_weights_path, self + self.custom_weights_choice = ui.CheckBox( + "Load custom weights", self._toggle_weights_path, self ) ################################################### # status report docked widget - ( - self.container_report, - self.container_report_layout, - ) = ui.make_container(10, 5, 5, 5) - self.container_report.setSizePolicy( + + self.report_container = ui.ContainerWidget(l=10, t=5, r=5, b=5) + + self.report_container.setSizePolicy( QSizePolicy.Fixed, QSizePolicy.Minimum ) self.container_docked = False # check if already docked - self.progress = QProgressBar(self.container_report) + self.progress = QProgressBar(self.report_container) self.progress.setVisible(False) """Widget for the progress bar""" - self.log = Log(self.container_report) + self.log = ui.Log(self.report_container) self.log.setVisible(False) """Read-only display for process-related info. Use only for info destined to user.""" self.btn_save_log = ui.Button( "Save log in results folder", func=self.save_log, - parent=self.container_report, + parent=self.report_container, fixed=False, ) self.btn_save_log.setVisible(False) - ##################################################### def send_log(self, text): """Emit a signal to print in a Log""" - self.log.print_and_log(text) + if self.log is not None: + self.log.print_and_log(text) def save_log(self): """Saves the worker's log to disk at self.results_path when called""" - log = self.log.toPlainText() - - path = self.results_path - - if len(log) != 0: - with open( - path + f"/Log_report_{utils.get_date_time()}.txt", - "x", - ) as f: - f.write(log) - f.close() + if self.log is not None: + log = self.log.toPlainText() + + path = self.results_path + + if len(log) != 0: + with open( + path + f"/Log_report_{utils.get_date_time()}.txt", + "x", + ) as f: + f.write(log) + f.close() + else: + warnings.warn( + "No job has been completed yet, please start one or re-open the log window." + ) else: - warnings.warn( - "No job has been completed yet, please start one or re-open the log window." - ) + warnings.warn(f"No logger defined : Log is {self.log}") def save_log_to_path(self, path): """Saves the worker log to a specific path. Cannot be used with connect. @@ -163,10 +151,13 @@ def save_log_to_path(self, path): """ log = self.log.toPlainText() + path = str( + Path(path) / Path(f"Log_report_{utils.get_date_time()}.txt") + ) if len(log) != 0: with open( - path + f"/Log_report_{utils.get_date_time()}.txt", + path, "x", ) as f: f.write(log) @@ -202,19 +193,26 @@ def display_status_report(self): elif not self.container_docked: ui.add_widgets( - self.container_report_layout, + self.report_container.layout, [self.progress, self.log, self.btn_save_log], alignment=None, ) - self.container_report.setLayout(self.container_report_layout) + self.report_container.setLayout(self.report_container.layout) report_dock = self._viewer.window.add_dock_widget( - self.container_report, + self.report_container, name="Status report", area="left", allowed_areas=["left"], ) + report_dock._close_btn = False + + # TODO move to activity log once they figure out _qt_window access and private attrib. + # activity_log = self._viewer.window._qt_window._activity_dialog + # activity_layout = activity_log._activityLayout + # activity_layout.addWidget(self.container_report) + self.docked_widgets.append(report_dock) self.container_docked = True @@ -223,33 +221,32 @@ def display_status_report(self): self.btn_save_log.setVisible(True) self.progress.setValue(0) - def toggle_weights_path(self): + def _toggle_weights_path(self): """Toggle visibility of weight path""" ui.toggle_visibility( - self.custom_weights_choice, self.weights_path_container + self.custom_weights_choice, self.weights_filewidget ) def create_train_dataset_dict(self): """Creates data dictionary for MONAI transforms and training. - Returns: a dict with the following : - **Keys:** + Returns: + A dict with the following keys * "image": image - * "label" : corresponding label """ if len(self.images_filepaths) == 0 or len(self.labels_filepaths) == 0: raise ValueError("Data folders are empty") - print("Images :\n") + logger.info("Images :\n") for file in self.images_filepaths: - print(os.path.basename(file).split(".")[0]) - print("*" * 10) - print("\nLabels :\n") + logger.info(Path(file).name) + logger.info("*" * 10) + logger.info("Labels :\n") for file in self.labels_filepaths: - print(os.path.basename(file).split(".")[0]) + logger.info(Path(file).name) data_dicts = [ {"image": image_name, "label": label_name} @@ -257,34 +254,44 @@ def create_train_dataset_dict(self): self.images_filepaths, self.labels_filepaths ) ] + logger.debug(f"Training data dict : {data_dicts}") return data_dicts - def get_model(self, key): + def get_model(self, key): # TODO remove """Getter for module (class and functions) associated to currently selected model""" return self.models_dict[key] - def get_loss(self, key): - """Getter for loss function selected by user""" - return self.loss_dict[key] + @staticmethod + def get_available_models(): + """Getter for module (class and functions) associated to currently selected model""" + return config.MODEL_LIST - def load_model_path(self): - """Show file dialog to set :py:attr:`model_path`""" - dir = ui.open_file_dialog(self, self._default_path) - if dir != "" and type(dir) is str and os.path.isdir(dir): - self.model_path = dir - self.lbl_model_path.setText(self.results_path) - # self.update_default() + # def load_model_path(self): # TODO add custom models + # """Show file dialog to set :py:attr:`model_path`""" + # folder = ui.open_folder_dialog(self, self._default_folders) + # if folder is not None and type(folder) is str and os.path.isdir(folder): + # self.model_path = folder + # self.lbl_model_path.setText(self.model_path) + # # self.update_default() - def load_weights_path(self): + def _load_weights_path(self): """Show file dialog to set :py:attr:`model_path`""" + + # logger.debug(self._default_weights_folder) + file = ui.open_file_dialog( - self, self._default_path, filetype="Weights file (*.pth)" + self, + [self._default_weights_folder], + filetype="Weights file (*.pth)", ) - if file != "": - self.weights_path = file[0] - self.lbl_weights_path.setText(self.weights_path) - self.update_default() + if file[0] == self._default_weights_folder: + return + if file is not None: + if file[0] != "": + self.weights_config.path = file[0] + self.weights_filewidget.text_field.setText(file[0]) + self._default_weights_folder = str(Path(file[0]).parent) @staticmethod def get_device(show=True): @@ -292,30 +299,43 @@ def get_device(show=True): If none is available (CUDA not installed), uses cpu instead.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if show: - print(f"Using {device} device") - print("Using torch :") - print(torch.__version__) + logger.info(f"Using {device} device") + logger.info("Using torch :") + logger.info(torch.__version__) return device def empty_cuda_cache(self): """Empties the cuda cache if the device is a cuda device""" if self.get_device(show=False).type == "cuda": - print("Empyting cache...") + logger.info("Attempting to empty cache...") torch.cuda.empty_cache() - print("Cache emptied") - - def update_default(self): - """Update default path for smoother file dialogs, here with :py:attr:`~model_path` included""" - self._default_path = [ - path - for path in [ - os.path.dirname(self.images_filepaths[0]), - os.path.dirname(self.labels_filepaths[0]), - self.model_path, - self.results_path, - ] - if (path != [""] and path != "") - ] - - def build(self): + logger.info("Attempt complete : Cache emptied") + + # def update_default(self): # TODO add custom models + # """Update default path for smoother file dialogs, here with :py:attr:`~model_path` included""" + # + # if len(self.images_filepaths) != 0: + # from_images = str(Path(self.images_filepaths[0]).parent) + # else: + # from_images = None + # + # if len(self.labels_filepaths) != 0: + # from_labels = str(Path(self.labels_filepaths[0]).parent) + # else: + # from_labels = None + # + # possible_paths = [ + # path + # for path in [ + # from_images, + # from_labels, + # # self.model_path, + # self.results_path, + # ] + # if path is not None + # ] + # self._default_folders = possible_paths + # update if model_path is used again + + def _build(self): raise NotImplementedError("Should be defined in children classes") diff --git a/napari_cellseg3d/model_instance_seg.py b/napari_cellseg3d/code_models/model_instance_seg.py similarity index 72% rename from napari_cellseg3d/model_instance_seg.py rename to napari_cellseg3d/code_models/model_instance_seg.py index db16ddf3..88940d7d 100644 --- a/napari_cellseg3d/model_instance_seg.py +++ b/napari_cellseg3d/code_models/model_instance_seg.py @@ -1,7 +1,11 @@ from __future__ import division from __future__ import print_function +from dataclasses import dataclass +from typing import List + import numpy as np +from skimage.filters import thresholding # from skimage.measure import marching_cubes # from skimage.measure import mesh_surface_area @@ -18,8 +22,48 @@ # from napari_cellseg3d.utils import sphericity_volume_area +@dataclass +class ImageStats: + volume: List[float] + centroid_x: List[float] + centroid_y: List[float] + centroid_z: List[float] + sphericity_ax: List[float] + image_size: List[int] + total_image_volume: int + total_filled_volume: int + filling_ratio: float + number_objects: int + + def get_dict(self): + return { + "Volume": self.volume, + "Centroid x": self.centroid_x, + "Centroid y": self.centroid_y, + "Centroid z": self.centroid_z, + # "Sphericity (volume/area)": sphericity_va, + "Sphericity (axes)": self.sphericity_ax, + "Image size": self.image_size, + "Total image volume": self.total_image_volume, + "Total object volume (pixels)": self.total_filled_volume, + "Filling ratio": self.filling_ratio, + "Number objects": self.number_objects, + } + + +def threshold(volume, thresh): + im = np.squeeze(volume) + binary = im > thresh + return np.where(binary, im, np.zeros_like(im)) + + def binary_connected( - volume, thres=0.5, thres_small=3, scale_factors=(1.0, 1.0, 1.0) + volume, + thres=0.5, + thres_small=3, + # scale_factors=(1.0, 1.0, 1.0), + *args, + **kwargs ): r"""Convert binary foreground probability maps to instance masks via connected-component labeling. @@ -35,30 +79,32 @@ def binary_connected( segm = label(foreground) segm = remove_small_objects(segm, thres_small) - if not all(x == 1.0 for x in scale_factors): - target_size = ( - int(semantic.shape[0] * scale_factors[0]), - int(semantic.shape[1] * scale_factors[1]), - int(semantic.shape[2] * scale_factors[2]), - ) - segm = resize( - segm, - target_size, - order=0, - anti_aliasing=False, - preserve_range=True, - ) + # if not all(x == 1.0 for x in scale_factors): + # target_size = ( + # int(semantic.shape[0] * scale_factors[0]), + # int(semantic.shape[1] * scale_factors[1]), + # int(semantic.shape[2] * scale_factors[2]), + # ) + # segm = resize( + # segm, + # target_size, + # order=0, + # anti_aliasing=False, + # preserve_range=True, + # ) return segm def binary_watershed( volume, - thres_seeding=0.9, - thres_small=10, thres_objects=0.3, - scale_factors=(1.0, 1.0, 1.0), + thres_small=10, + thres_seeding=0.9, + # scale_factors=(1.0, 1.0, 1.0), rem_seed_thres=3, + *args, + **kwargs ): r"""Convert binary foreground probability maps to instance masks via watershed segmentation algorithm. @@ -83,19 +129,19 @@ def binary_watershed( segm = watershed(-semantic.astype(np.float64), seed, mask=foreground) segm = remove_small_objects(segm, thres_small) - if not all(x == 1.0 for x in scale_factors): - target_size = ( - int(semantic.shape[0] * scale_factors[0]), - int(semantic.shape[1] * scale_factors[1]), - int(semantic.shape[2] * scale_factors[2]), - ) - segm = resize( - segm, - target_size, - order=0, - anti_aliasing=False, - preserve_range=True, - ) + # if not all(x == 1.0 for x in scale_factors): + # target_size = ( + # int(semantic.shape[0] * scale_factors[0]), + # int(semantic.shape[1] * scale_factors[1]), + # int(semantic.shape[2] * scale_factors[2]), + # ) + # segm = resize( + # segm, + # target_size, + # order=0, + # anti_aliasing=False, + # preserve_range=True, + # ) return np.array(segm) @@ -227,16 +273,15 @@ def fill(lst, n=len(properties) - 1): else: ratio = 0 - return { - "Volume": volume, - "Centroid x": [region.centroid[0] for region in properties], - "Centroid y": [region.centroid[1] for region in properties], - "Centroid z": [region.centroid[2] for region in properties], - # "Sphericity (volume/area)": sphericity_va, - "Sphericity (axes)": sphericity_ax, - "Image size": fill([volume_image.shape]), - "Total image volume": fill([len(volume_image.flatten())]), - "Total object volume (pixels)": fill([np.sum(volume)]), - "Filling ratio": ratio, - "Number objects": fill([len(properties)]), - } + return ImageStats( + volume, + [region.centroid[0] for region in properties], + [region.centroid[0] for region in properties], + [region.centroid[2] for region in properties], + sphericity_ax, + fill([volume_image.shape]), + fill([len(volume_image.flatten())]), + fill([np.sum(volume)]), + ratio, + fill([len(properties)]), + ) diff --git a/napari_cellseg3d/model_workers.py b/napari_cellseg3d/code_models/model_workers.py similarity index 67% rename from napari_cellseg3d/model_workers.py rename to napari_cellseg3d/code_models/model_workers.py index d857e22a..e85548a7 100644 --- a/napari_cellseg3d/model_workers.py +++ b/napari_cellseg3d/code_models/model_workers.py @@ -1,13 +1,12 @@ -import os import platform +from dataclasses import dataclass +from math import ceil from pathlib import Path -import importlib.util +from typing import List from typing import Optional import numpy as np -from tifffile import imwrite import torch -from tqdm import tqdm # MONAI from monai.data import CacheDataset @@ -44,14 +43,20 @@ # Qt from qtpy.QtCore import Signal +from tifffile import imwrite +from tqdm import tqdm +from napari_cellseg3d import config +from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d import log_utility # local -from napari_cellseg3d.model_instance_seg import binary_connected -from napari_cellseg3d.model_instance_seg import binary_watershed -from napari_cellseg3d.model_instance_seg import volume_stats +from napari_cellseg3d.code_models.model_instance_seg import binary_connected +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed +from napari_cellseg3d.code_models.model_instance_seg import ImageStats +from napari_cellseg3d.code_models.model_instance_seg import volume_stats + +logger = utils.LOGGER """ Writing something to log messages from outside the main thread is rather problematic (plenty of silent crashes...) @@ -63,20 +68,19 @@ # https://www.pythoncentral.io/pysidepyqt-tutorial-creating-your-own-signals-and-slots/ # https://napari-staging-site.github.io/guides/stable/threading.html -WEIGHTS_DIR = os.path.dirname(os.path.realpath(__file__)) + str( - Path("/models/pretrained") -) +WEIGHTS_DIR = Path(__file__).parent.resolve() / Path("models/pretrained") +logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {WEIGHTS_DIR}") class WeightsDownloader: """A utility class the downloads the weights of a model when needed.""" - def __init__(self, log_widget: Optional[log_utility.Log] = None): + def __init__(self, log_widget: Optional[ui.Log] = None): """ Creates a WeightsDownloader, optionally with a log widget to display the progress. Args: - log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print() + log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses logger.info() """ self.log_widget = log_widget @@ -96,24 +100,17 @@ def download_weights(self, model_name: str, model_weights_filename: str): def show_progress(count, block_size, total_size): pbar.update(block_size) - cellseg3d_path = os.path.split( - importlib.util.find_spec("napari_cellseg3d").origin - )[0] - pretrained_folder_path = os.path.join( - cellseg3d_path, "models", "pretrained" - ) - json_path = os.path.join( - pretrained_folder_path, "pretrained_model_urls.json" - ) + logger.info("*" * 20) + pretrained_folder_path = WEIGHTS_DIR + json_path = pretrained_folder_path / Path("pretrained_model_urls.json") - check_path = os.path.join( - pretrained_folder_path, model_weights_filename - ) - if os.path.exists(check_path): + check_path = pretrained_folder_path / Path(model_weights_filename) + + if Path(check_path).is_file(): message = f"Weight file {model_weights_filename} already exists, skipping download" if self.log_widget is not None: self.log_widget.print_and_log(message, printing=False) - print(message) + logger.info(message) return with open(json_path) as f: @@ -122,10 +119,10 @@ def show_progress(count, block_size, total_size): url = neturls[model_name] response = urllib.request.urlopen(url) - start_message = f"Downloading the model from the M.W. Mathis Lab server {url}...." + start_message = f"Downloading the model from HuggingFace {url}...." total_size = int(response.getheader("Content-Length")) if self.log_widget is None: - print(start_message) + logger.info(start_message) pbar = tqdm(unit="B", total=total_size, position=0) else: self.log_widget.print_and_log(start_message) @@ -140,26 +137,33 @@ def show_progress(count, block_size, total_size): url, reporthook=show_progress ) with tarfile.open(filename, mode="r:gz") as tar: + def is_within_directory(directory, target): - - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - + + abs_directory = Path(directory).resolve() + abs_target = Path(target).resolve() + # prefix = os.path.commonprefix([abs_directory, abs_target]) + logger.debug(abs_directory) + logger.debug(abs_target) + logger.debug(abs_directory.parents) + + return abs_directory in abs_target.parents + + def safe_extract( + tar, path=".", members=None, *, numeric_owner=False + ): + for member in tar.getmembers(): - member_path = os.path.join(path, member.name) + member_path = str(Path(path) / member.name) if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - + raise Exception( + "Attempted Path Traversal in Tar File" + ) + + tar.extractall(path, members, numeric_owner=numeric_owner) + safe_extract(tar, pretrained_folder_path) + # tar.extractall(pretrained_folder_path) else: raise ValueError( f"Unknown model: {model_name}. Should be one of {', '.join(neturls)}" @@ -183,7 +187,16 @@ def __init__(self): super().__init__() -# TODO : use dataclass for config instead ? +@dataclass +class InferenceResult: + """Class to record results of a segmentation job""" + + image_id: int = 0 + original: np.array = None + instance_labels: np.array = None + stats: ImageStats = None + result: np.array = None + model_name: str = None class InferenceWorker(GeneratorWorker): @@ -192,24 +205,12 @@ class InferenceWorker(GeneratorWorker): def __init__( self, - device, - model_dict, - weights_dict, - results_path, - filetype, - transforms, - instance, - use_window, - window_infer_size, - window_overlap, - keep_on_cpu, - stats_csv, - images_filepaths=None, - layer=None, # FIXME + worker_config: config.InferenceWorkerConfig, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function. Args: + * config (config.InferenceWorkerConfig): dataclass containing the proper configuration elements * device: cuda or cpu device to use for torch * model_dict: the :py:attr:`~self.models_dict` dictionary to obtain the model name, class and instance @@ -242,26 +243,8 @@ def __init__( self._signals = LogSignal() # add custom signals self.log_signal = self._signals.log_signal self.warn_signal = self._signals.warn_signal - ########################################### - ########################################### - self.device = device - self.model_dict = model_dict - self.weights_dict = weights_dict - self.results_path = results_path - self.filetype = filetype - self.transforms = transforms - self.instance_params = instance - self.use_window = use_window - self.window_infer_size = window_infer_size - self.window_overlap_percentage = window_overlap - self.keep_on_cpu = keep_on_cpu - self.stats_to_csv = stats_csv - ############################################ - ############################################ - self.layer = layer - self.images_filepaths = images_filepaths - ############################################ - ############################################ + + self.config = worker_config """These attributes are all arguments of :py:func:~inference, please see that for reference""" @@ -294,48 +277,60 @@ def warn(self, warning): def log_parameters(self): + config = self.config + self.log("-" * 20) self.log("\nParameters summary :") - self.log(f"Model is : {self.model_dict['name']}") - if self.transforms["thresh"][0]: + self.log(f"Model is : {config.model_info.name}") + if config.post_process_config.thresholding.enabled: self.log( - f"Thresholding is enabled at {self.transforms['thresh'][1]}" + f"Thresholding is enabled at {config.post_process_config.thresholding.threshold_value}" ) - if self.use_window: + if config.sliding_window_config.is_enabled(): status = "enabled" else: status = "disabled" self.log(f"Window inference is {status}\n") + if status == "enabled": + self.log( + f"Window size is {self.config.sliding_window_config.window_size}" + ) + self.log( + f"Window overlap is {self.config.sliding_window_config.window_overlap}" + ) - if self.keep_on_cpu: + if config.keep_on_cpu: self.log(f"Dataset loaded to CPU") else: - self.log(f"Dataset loaded on {self.device}") + self.log(f"Dataset loaded on {config.device}") - if self.transforms["zoom"][0]: - self.log(f"Scaling factor : {self.transforms['zoom'][1]} (x,y,z)") + if config.post_process_config.zoom.enabled: + self.log( + f"Scaling factor : {config.post_process_config.zoom.zoom_values} (x,y,z)" + ) - if self.instance_params["do_instance"]: + instance_config = config.post_process_config.instance + if instance_config.enabled: self.log( - f"Instance segmentation enabled, method : {self.instance_params['method']}\n" - f"Probability threshold is {self.instance_params['threshold']:.2f}\n" - f"Objects smaller than {self.instance_params['size_small']} pixels will be removed\n" + f"Instance segmentation enabled, method : {instance_config.method}\n" + f"Probability threshold is {instance_config.threshold.threshold_value:.2f}\n" + f"Objects smaller than {instance_config.small_object_removal_threshold.threshold_value} pixels will be removed\n" ) self.log("-" * 20) def load_folder(self): - images_dict = self.create_inference_dict(self.images_filepaths) + images_dict = self.create_inference_dict(self.config.images_filepaths) # TODO : better solution than loading first image always ? data_check = LoadImaged(keys=["image"])(images_dict[0]) check = data_check["image"].shape - self.log("\nChecking dimensions...") + # self.log("\nChecking dimensions...") pad = utils.get_padding_dim(check) # dims = self.model_dict["model_input_size"] @@ -358,11 +353,11 @@ def load_folder(self): # # self.log_parameters() # - # model.to(self.device) + # model.to(self.config.device) - # print("FILEPATHS PRINT") - # print(self.images_filepaths) - if self.use_window: + # logger.debug("FILEPATHS PRINT") + # logger.debug(self.images_filepaths) + if self.config.sliding_window_config.is_enabled(): load_transforms = Compose( [ LoadImaged(keys=["image"]), @@ -395,8 +390,8 @@ def load_folder(self): return inference_loader def load_layer(self): - - data = np.squeeze(self.layer.data) + self.log("Loading layer\n") + data = np.squeeze(self.config.layer) volume = np.array(data, dtype=np.int16) @@ -410,14 +405,14 @@ def load_layer(self): volume = np.swapaxes( volume, 0, 2 ) # for anisotropy to be monai-like, i.e. zyx # FIXME rotation not always correct - print("Loading layer\n") + dims_check = volume.shape - self.log("\nChecking dimensions...") + # self.log("\nChecking dimensions...") pad = utils.get_padding_dim(dims_check) - # print(volume.shape) - # print(volume.dtype) - if self.use_window: + # logger.debug(volume.shape) + # logger.debug(volume.dtype) + if self.config.sliding_window_config.is_enabled(): load_transforms = Compose( [ ToTensor(), @@ -444,7 +439,6 @@ def load_layer(self): log_stats=True, ) - self.log("\nLoading dataset...") input_image = load_transforms(volume) self.log("Done") return input_image @@ -461,31 +455,47 @@ def model_output( inputs = inputs.to("cpu") model_output = lambda inputs: post_process_transforms( - self.model_dict["class"].get_output(model, inputs) + self.config.model_info.get_model().get_output( + model, inputs + ) # TODO(cyril) refactor those functions ) - if self.keep_on_cpu: + def model_output(inputs): + return post_process_transforms( + self.config.model_info.get_model().get_output(model, inputs) + ) + + if self.config.keep_on_cpu: dataset_device = "cpu" else: - dataset_device = self.device + dataset_device = self.config.device - if self.use_window: - window_size = self.window_infer_size - window_overlap = self.window_overlap_percentage - else: - window_size = None - window_overlap = 0.25 + window_size = self.config.sliding_window_config.window_size + window_overlap = self.config.sliding_window_config.window_overlap + + # FIXME + # import sys + + # old_stdout = sys.stdout + # old_stderr = sys.stderr + + # sys.stdout = self.downloader.log_widget + # sys.stdout = self.downloader.log_widget outputs = sliding_window_inference( inputs, - roi_size=window_size, - sw_batch_size=1, + roi_size=[window_size, window_size, window_size], + sw_batch_size=1, # TODO add param predictor=model_output, - sw_device=self.device, + sw_device=self.config.device, device=dataset_device, overlap=window_overlap, + progress=True, ) + # sys.stdout = old_stdout + # sys.stderr = old_stderr + out = outputs.detach().cpu() if aniso_transform is not None: @@ -498,13 +508,13 @@ def model_output( else: return out - def create_result_dict( + def create_result_dict( # FIXME replace with result class self, semantic_labels, instance_labels, from_layer: bool, original=None, - data_dict=None, + stats=None, i=0, ): @@ -520,17 +530,17 @@ def create_result_dict( ) semantic_labels = np.swapaxes(semantic_labels, 0, 2) - return { - "image_id": i + 1, - "original": original, - "instance_labels": instance_labels, - "object stats": data_dict, - "result": semantic_labels, - "model_name": self.model_dict["name"], - } + return InferenceResult( + image_id=i + 1, + original=original, + instance_labels=instance_labels, + stats=stats, + result=semantic_labels, + model_name=self.config.model_info.name, + ) def get_original_filename(self, i): - return os.path.basename(self.images_filepaths[i]).split(".")[0] + return Path(self.config.images_filepaths[i]).stem def get_instance_result(self, semantic_labels, from_layer=False, i=-1): @@ -539,7 +549,7 @@ def get_instance_result(self, semantic_labels, from_layer=False, i=-1): "An ID should be provided when running from a file" ) - if self.instance_params["do_instance"]: + if self.config.post_process_config.instance.enabled: instance_labels = self.instance_seg( semantic_labels, i + 1, @@ -567,17 +577,16 @@ def save_image( time = utils.get_date_time() file_path = ( - self.results_path + self.config.results_path + "/" + f"Prediction_{i+1}" + original_filename - + self.model_dict["name"] + + self.config.model_info.name + f"_{time}_" - + self.filetype + + self.config.filetype ) - imwrite(file_path, image) - filename = os.path.split(file_path)[1] + filename = Path(file_path).stem if from_layer: self.log(f"\nLayer prediction saved as : {filename}") @@ -585,24 +594,32 @@ def save_image( self.log(f"\nFile n°{i+1} saved as : {filename}") def aniso_transform(self, image): - zoom = self.transforms["zoom"][1] - anisotropic_transform = Zoom( - zoom=zoom, - keep_size=False, - padding_mode="empty", - ) - return anisotropic_transform(image[0]) + + if self.config.post_process_config.zoom.enabled: + zoom = self.config.post_process_config.zoom.zoom_values + anisotropic_transform = Zoom( + zoom=zoom, + keep_size=False, + padding_mode="empty", + ) + return anisotropic_transform(image[0]) + else: + return image def instance_seg(self, to_instance, image_id=0, original_filename="layer"): if image_id is not None: self.log(f"\nRunning instance segmentation for image n°{image_id}") - threshold = self.instance_params["threshold"] - size_small = self.instance_params["size_small"] - method_name = self.instance_params["method"] + threshold = ( + self.config.post_process_config.instance.threshold.threshold_value + ) + size_small = ( + self.config.post_process_config.instance.small_object_removal_threshold.threshold_value + ) + method_name = self.config.post_process_config.instance.method - if method_name == "Watershed": + if method_name == "Watershed": # FIXME use dict in config instead def method(image): return binary_watershed(image, threshold, size_small) @@ -620,21 +637,21 @@ def method(image): instance_labels = method(to_instance) instance_filepath = ( - self.results_path + self.config.results_path + "/" + f"Instance_seg_labels_{image_id}_" + original_filename + "_" - + self.model_dict["name"] + + self.config.model_info.name + f"_{utils.get_date_time()}_" - + self.filetype + + self.config.filetype ) imwrite(instance_filepath, instance_labels) self.log( f"Instance segmentation results for image n°{image_id} have been saved as:" ) - self.log(os.path.split(instance_filepath)[1]) + self.log(Path(instance_filepath).name) return instance_labels def inference_on_folder(self, inf_data, i, model, post_process_transforms): @@ -652,30 +669,27 @@ def inference_on_folder(self, inf_data, i, model, post_process_transforms): ) self.save_image(out, i=i) - instance_labels, data_dict = self.get_instance_result(out, i=i) + instance_labels, stats = self.get_instance_result(out, i=i) original = np.array(inf_data["image"]).astype(np.float32) - self.log(f"Inference completed on layer") + self.log(f"Inference completed on image n°{i+1}") return self.create_result_dict( out, instance_labels, from_layer=False, original=original, - data_dict=data_dict, + stats=stats, i=i, ) def stats_csv(self, instance_labels): - if self.stats_to_csv: - - # try: - - data_dict = volume_stats( + if self.config.compute_stats: + stats = volume_stats( instance_labels ) # TODO test with area mesh function - return data_dict + return stats # except ValueError as e: # self.log(f"Error occurred during stats computing : {e}") @@ -698,12 +712,13 @@ def inference_on_layer(self, image, model, post_process_transforms): self.save_image(out, from_layer=True) - instance_labels, data_dict = self.get_instance_result( - out, from_layer=True - ) + instance_labels, stats = self.get_instance_result(out, from_layer=True) return self.create_result_dict( - out, instance_labels, from_layer=True, data_dict=data_dict + semantic_labels=out, + instance_labels=instance_labels, + from_layer=True, + stats=stats, ) def inference(self): @@ -741,36 +756,41 @@ def inference(self): """ sys = platform.system() - print(f"OS is {sys}") + logger.debug(f"OS is {sys}") if sys == "Darwin": torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") try: - dims = self.model_dict["model_input_size"] - self.log(f"MODEL DIMS : {dims}") - self.log(self.model_dict["name"]) + dims = self.config.model_info.model_input_size + # self.log(f"MODEL DIMS : {dims}") + model_name = self.config.model_info.name + model_class = self.config.model_info.get_model() + self.log(model_name) - if self.model_dict["name"] == "SegResNet": - model = self.model_dict["class"].get_net( + weights_config = self.config.weights_config + post_process_config = self.config.post_process_config + + if model_name == "SegResNet": + model = model_class.get_net( input_image_size=[ dims, dims, dims, ], # TODO FIX ! find a better way & remove model-specific code ) - elif self.model_dict["name"] == "SwinUNetR": - model = self.model_dict["class"].get_net( + elif model_name == "SwinUNetR": + model = model_class.get_net( img_size=[dims, dims, dims], use_checkpoint=False, ) else: - model = self.model_dict["class"].get_net() - model = model.to(self.device) + model = model_class.get_net() + model = model.to(self.config.device) self.log_parameters() - model.to(self.device) + model.to(self.config.device) # load_transforms = Compose( # [ @@ -784,37 +804,35 @@ def inference(self): # ] # ) - if not self.transforms["thresh"][0]: + if not post_process_config.thresholding.enabled: post_process_transforms = EnsureType() else: - t = self.transforms["thresh"][1] + t = post_process_config.thresholding.threshold_value post_process_transforms = Compose( AsDiscrete(threshold=t), EnsureType() ) self.log("\nLoading weights...") - - if self.weights_dict["custom"]: - weights = self.weights_dict["path"] + if weights_config.custom: + weights = weights_config.path else: self.downloader.download_weights( - self.model_dict["name"], - self.model_dict["class"].get_weights_file(), + model_name, + model_class.get_weights_file(), ) - weights = os.path.join( - WEIGHTS_DIR, self.model_dict["class"].get_weights_file() + weights = str( + WEIGHTS_DIR / Path(model_class.get_weights_file()) ) - model.load_state_dict( torch.load( weights, - map_location=self.device, + map_location=self.config.device, ) ) self.log("Done") - is_folder = self.images_filepaths is not None - is_layer = self.layer is not None + is_folder = self.config.images_filepaths is not None + is_layer = self.config.layer is not None if is_layer and is_folder: raise ValueError( @@ -829,7 +847,7 @@ def inference(self): # # check_data = first(inference_loader) # image = check_data[0][0] - # print(image.shape) + # logger.debug(image.shape) ################## ################## elif is_layer: @@ -854,34 +872,29 @@ def inference(self): model.to("cpu") except Exception as e: - self.log(f"Error : {e}") + self.log(f"Error during inference : {e}") self.quit() finally: self.quit() +@dataclass +class TrainingReport: + show_plot: bool = True + epoch: int = 0 + loss_values: List = None + validation_metric: List = None + weights: np.array = None + images: List[np.array] = None + + class TrainingWorker(GeneratorWorker): """A custom worker to run training jobs in. Inherits from :py:class:`napari.qt.threading.GeneratorWorker`""" def __init__( self, - device, - model_dict, - weights_path, - data_dicts, - validation_percent, - max_epochs, - loss_function, - learning_rate, - val_interval, - batch_size, - results_path, - sampling, - num_samples, - sample_size, - do_augmentation, - deterministic, + config: config.TrainingWorkerConfig, ): """Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train` @@ -927,24 +940,7 @@ def __init__( self._weight_error = False ############################################# - self.device = device - self.model_dict = model_dict - self.weights_path = weights_path - self.data_dicts = data_dicts - self.validation_percent = validation_percent - self.max_epochs = max_epochs - self.loss_function = loss_function - self.learning_rate = learning_rate - self.val_interval = val_interval - self.batch_size = batch_size - self.results_path = results_path - - self.num_samples = num_samples - self.sampling = sampling - self.sample_size = sample_size - - self.do_augment = do_augmentation - self.seed_dict = deterministic + self.config = config self.train_files = [] self.val_files = [] @@ -972,43 +968,47 @@ def log_parameters(self): self.log("Parameters summary :\n") self.log( - f"Percentage of dataset used for validation : {self.validation_percent * 100}%" + f"Percentage of dataset used for validation : {self.config.validation_percent * 100}%" ) + self.log("-" * 10) self.log("Training files :\n") [ - self.log(f"{os.path.basename(str(train_file)[:-2])}\n") + self.log(f"{Path(train_file['image']).name}\n") for train_file in self.train_files ] self.log("-" * 10) self.log("Validation files :\n") [ - self.log(f"{os.path.basename(str(val_file)[:-2])}\n") + self.log(f"{Path(val_file['image']).name}\n") for val_file in self.val_files ] self.log("-" * 10) - if self.seed_dict["use deterministic"]: + + if self.config.deterministic_config.enabled: self.log(f"Deterministic training is enabled") - self.log(f"Seed is {self.seed_dict['seed']}") + self.log(f"Seed is {self.config.deterministic_config.seed}") - self.log(f"Training for {self.max_epochs} epochs") - self.log(f"Loss function is : {str(self.loss_function)}") - self.log(f"Validation is performed every {self.val_interval} epochs") - self.log(f"Batch size is {self.batch_size}") - self.log(f"Learning rate is {self.learning_rate}") + self.log(f"Training for {self.config.max_epochs} epochs") + self.log(f"Loss function is : {str(self.config.loss_function)}") + self.log( + f"Validation is performed every {self.config.validation_interval} epochs" + ) + self.log(f"Batch size is {self.config.batch_size}") + self.log(f"Learning rate is {self.config.learning_rate}") - if self.sampling: + if self.config.sampling: self.log( - f"Extracting {self.num_samples} patches of size {self.sample_size}" + f"Extracting {self.config.num_samples} patches of size {self.config.sample_size}" ) else: self.log("Using whole images as dataset") - if self.do_augment: + if self.config.do_augmentation: self.log("Data augmentation is enabled") - if self.weights_path is not None: - self.log(f"Using weights from : {self.weights_path}") + if not self.config.weights_info.use_pretrained: + self.log(f"Using weights from : {self.config.weights_info.path}") if self._weight_error: self.log( ">>>>>>>>>>>>>>>>>\n" @@ -1064,66 +1064,102 @@ def train(self): # error_log = open(results_path +"/error_log.log" % multiprocessing.current_process().name, 'x') # faulthandler.enable(file=error_log, all_threads=True) ######################### + model_config = self.config.model_info + weights_config = self.config.weights_info + deterministic_config = self.config.deterministic_config + try: - if self.seed_dict["use deterministic"]: + if deterministic_config.enabled: set_determinism( - seed=self.seed_dict["seed"] + seed=deterministic_config.seed ) # use_deterministic_algorithms = True causes cuda error sys = platform.system() - print(sys) + logger.debug(sys) if sys == "Darwin": # required for macOS ? torch.set_num_threads(1) self.log("Number of threads has been set to 1 for macOS") - model_name = self.model_dict["name"] - model_class = self.model_dict["class"] + self.log(f"config model : {self.config.model_info.name}") + model_name = model_config.name + model_class = model_config.get_model() - if not self.sampling: - data_check = LoadImaged(keys=["image"])(self.data_dicts[0]) + if not self.config.sampling: + data_check = LoadImaged(keys=["image"])( + self.config.train_data_dict[0] + ) check = data_check["image"].shape + do_sampling = self.config.sampling + if model_name == "SegResNet": - if self.sampling: - size = self.sample_size + if do_sampling: + size = self.config.sample_size else: size = check - print(f"Size of image : {size}") + logger.info(f"Size of image : {size}") model = model_class.get_net( input_image_size=utils.get_padding_dim(size), - out_channels=1, - dropout_prob=0.3, + # out_channels=1, + # dropout_prob=0.3, ) elif model_name == "SwinUNetR": - if self.sampling: + if do_sampling: size = self.sample_size else: size = check - print(f"Size of image : {size}") + logger.info(f"Size of image : {size}") model = model_class.get_net( img_size=utils.get_padding_dim(size), use_checkpoint=True, ) else: model = model_class.get_net() # get an instance of the model - model = model.to(self.device) + model = model.to(self.config.device) epoch_loss_values = [] val_metric_values = [] - self.train_files, self.val_files = ( - self.data_dicts[ - 0 : int(len(self.data_dicts) * self.validation_percent) - ], - self.data_dicts[ - int(len(self.data_dicts) * self.validation_percent) : - ], + if len(self.config.train_data_dict) > 1: + self.train_files, self.val_files = ( + self.config.train_data_dict[ + 0 : int( + len(self.config.train_data_dict) + * self.config.validation_percent + ) + ], + self.config.train_data_dict[ + int( + len(self.config.train_data_dict) + * self.config.validation_percent + ) : + ], + ) + else: + self.train_files = self.val_files = self.config.train_data_dict + msg = f"Only one image file was provided : {self.config.train_data_dict[0]['image']}.\n" + + logger.debug(f"SAMPLING is {self.config.sampling}") + if not self.config.sampling: + msg += f"Sampling is not in use, the only image provided will be used as the validation file." + self.warn(msg) + else: + msg += f"Samples for validation will be cropped for the same only volume that is being used for training" + + logger.warning(msg) + + logger.debug( + f"Data dict from config is {self.config.train_data_dict}" ) + logger.debug(f"Train files : {self.train_files}") + logger.debug(f"Val. files : {self.val_files}") - if self.train_files == [] or self.val_files == []: - self.log("ERROR : datasets are empty") + if len(self.train_files) == 0: + raise ValueError("Training dataset is empty") + if len(self.val_files) == 0: + raise ValueError("Validation dataset is empty") - if self.sampling: + if do_sampling: sample_loader = Compose( [ LoadImaged(keys=["image", "label"]), @@ -1131,24 +1167,24 @@ def train(self): RandSpatialCropSamplesd( keys=["image", "label"], roi_size=( - self.sample_size + self.config.sample_size ), # multiply by axis_stretch_factor if anisotropy # max_roi_size=(120, 120, 120), random_size=False, - num_samples=self.num_samples, + num_samples=self.config.num_samples, ), Orientationd(keys=["image", "label"], axcodes="PLI"), SpatialPadd( keys=["image", "label"], spatial_size=( - utils.get_padding_dim(self.sample_size) + utils.get_padding_dim(self.config.sample_size) ), ), EnsureTyped(keys=["image", "label"]), ] ) - if self.do_augment: + if self.config.do_augmentation: train_transforms = ( Compose( # TODO : figure out which ones and values ? [ @@ -1178,24 +1214,46 @@ def train(self): ] ) # self.log("Loading dataset...\n") - if self.sampling: - print("train_ds") + if do_sampling: + + # if there is only one volume, split samples + # TODO(cyril) : maybe implement something in user config to toggle this behavior + if len(self.config.train_data_dict) < 2: + num_train_samples = ceil( + self.config.num_samples + * self.config.validation_percent + ) + num_val_samples = ceil( + self.config.num_samples + * (1 - self.config.validation_percent) + ) + else: + num_train_samples = ( + num_val_samples + ) = self.config.num_samples + + logger.debug(f"AMOUNT of train samples : {num_train_samples}") + logger.debug( + f"AMOUNT of validation samples : {num_val_samples}" + ) + + logger.debug("train_ds") train_ds = PatchDataset( data=self.train_files, transform=train_transforms, patch_func=sample_loader, - samples_per_image=self.num_samples, + samples_per_image=num_train_samples, ) - print("val_ds") + logger.debug("val_ds") val_ds = PatchDataset( data=self.val_files, transform=val_transforms, patch_func=sample_loader, - samples_per_image=self.num_samples, + samples_per_image=num_val_samples, ) else: - load_single_images = Compose( + load_whole_images = Compose( [ LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys=["image", "label"]), @@ -1207,32 +1265,32 @@ def train(self): EnsureTyped(keys=["image", "label"]), ] ) - print("Cache dataset : train") + logger.debug("Cache dataset : train") train_ds = CacheDataset( data=self.train_files, - transform=Compose(load_single_images, train_transforms), + transform=Compose(load_whole_images, train_transforms), ) - print("Cache dataset : val") + logger.debug("Cache dataset : val") val_ds = CacheDataset( - data=self.val_files, transform=load_single_images + data=self.val_files, transform=load_whole_images ) - print("Dataloader") + logger.debug("Dataloader") train_loader = DataLoader( train_ds, - batch_size=self.batch_size, + batch_size=self.config.batch_size, shuffle=True, num_workers=2, collate_fn=pad_list_data_collate, ) val_loader = DataLoader( - val_ds, batch_size=self.batch_size, num_workers=2 + val_ds, batch_size=self.config.batch_size, num_workers=2 ) - print("\nDone") + logger.info("\nDone") - print("Optimizer") + logger.debug("Optimizer") optimizer = torch.optim.Adam( - model.parameters(), self.learning_rate + model.parameters(), self.config.learning_rate ) dice_metric = DiceMetric(include_background=True, reduction="mean") @@ -1240,25 +1298,26 @@ def train(self): best_metric_epoch = -1 # time = utils.get_date_time() - print("Weights") - if self.weights_path is not None: - if self.weights_path == "use_pretrained": + logger.debug("Weights") + + if weights_config.custom: + if weights_config.use_pretrained: weights_file = model_class.get_weights_file() self.downloader.download_weights(model_name, weights_file) - weights = os.path.join(WEIGHTS_DIR, weights_file) - self.weights_path = weights + weights = WEIGHTS_DIR / Path(weights_file) + weights_config.path = weights else: - weights = os.path.join(self.weights_path) + weights = str(Path(weights_config.path)) try: model.load_state_dict( torch.load( weights, - map_location=self.device, + map_location=self.config.device, ) ) except RuntimeError as e: - print(f"Error : {e}") + logger.error(f"Error when loading weights : {e}") warn = ( "WARNING:\nIt'd seem that the weights were incompatible with the model,\n" "the model will be trained from random weights" @@ -1267,7 +1326,7 @@ def train(self): self.warn(warn) self._weight_error = True - if self.device.type == "cuda": + if self.config.device.type == "cuda": self.log("\nUsing GPU :") self.log(torch.cuda.get_device_name(0)) else: @@ -1275,11 +1334,17 @@ def train(self): self.log_parameters() - for epoch in range(self.max_epochs): + device = self.config.device + + if model_name == "test": + self.quit() + yield TrainingReport(False) + + for epoch in range(self.config.max_epochs): # self.log("\n") self.log("-" * 10) - self.log(f"Epoch {epoch + 1}/{self.max_epochs}") - if self.device.type == "cuda": + self.log(f"Epoch {epoch + 1}/{self.config.max_epochs}") + if device.type == "cuda": self.log("Memory Usage:") alloc_mem = round( torch.cuda.memory_allocated(0) / 1024**3, 1 @@ -1296,13 +1361,13 @@ def train(self): for batch_data in train_loader: step += 1 inputs, labels = ( - batch_data["image"].to(self.device), - batch_data["label"].to(self.device), + batch_data["image"].to(device), + batch_data["label"].to(device), ) optimizer.zero_grad() outputs = model_class.get_output(model, inputs) - # print(f"OUT : {outputs.shape}") - loss = self.loss_function(outputs, labels) + # self.log(f"Output dimensions : {outputs.shape}") + loss = self.config.loss_function(outputs, labels) loss.backward() optimizer.step() epoch_loss += loss.detach().item() @@ -1310,19 +1375,26 @@ def train(self): f"* {step}/{len(train_ds) // train_loader.batch_size}, " f"Train loss: {loss.detach().item():.4f}" ) - yield {"plot": False, "weights": model.state_dict()} + yield TrainingReport( + show_plot=False, weights=model.state_dict() + ) epoch_loss /= step epoch_loss_values.append(epoch_loss) self.log(f"Epoch: {epoch + 1}, Average loss: {epoch_loss:.4f}") - if (epoch + 1) % self.val_interval == 0: + checkpoint_output = [] + + if ( + (epoch + 1) % self.config.validation_interval == 0 + or epoch + 1 == self.config.max_epochs + ): model.eval() with torch.no_grad(): for val_data in val_loader: val_inputs, val_labels = ( - val_data["image"].to(self.device), - val_data["label"].to(self.device), + val_data["image"].to(device), + val_data["label"].to(device), ) val_outputs = model_class.get_validation( @@ -1347,23 +1419,34 @@ def train(self): post_label(res_tensor) for res_tensor in labs ] - # print(len(val_outputs)) - # print(len(val_labels)) + # logger.debug(len(val_outputs)) + # logger.debug(len(val_labels)) dice_metric(y_pred=val_outputs, y=val_labels) + checkpoint_output.append( + [res.detach().cpu() for res in val_outputs] + ) + + checkpoint_output = [ + item.numpy() + for batch in checkpoint_output + for item in batch + ] metric = dice_metric.aggregate().detach().item() dice_metric.reset() val_metric_values.append(metric) - train_report = { - "plot": True, - "epoch": epoch, - "losses": epoch_loss_values, - "val_metrics": val_metric_values, - "weights": model.state_dict(), - } + train_report = TrainingReport( + show_plot=True, + epoch=epoch, + loss_values=epoch_loss_values, + validation_metric=val_metric_values, + weights=model.state_dict(), + images=checkpoint_output, + ) + yield train_report weights_filename = ( @@ -1377,8 +1460,9 @@ def train(self): self.log("Saving best metric model") torch.save( model.state_dict(), - os.path.join( - self.results_path, weights_filename + Path(self.config.results_path_folder) + / Path( + weights_filename, ), ) self.log("Saving complete") @@ -1395,7 +1479,7 @@ def train(self): model.to("cpu") except Exception as e: - self.log(f"Error : {e}") + self.log(f"Error in training : {e}") self.quit() finally: self.quit() diff --git a/napari_cellseg3d/models/pretrained/__init__.py b/napari_cellseg3d/code_models/models/__init__.py similarity index 100% rename from napari_cellseg3d/models/pretrained/__init__.py rename to napari_cellseg3d/code_models/models/__init__.py diff --git a/napari_cellseg3d/models/model_SegResNet.py b/napari_cellseg3d/code_models/models/model_SegResNet.py similarity index 68% rename from napari_cellseg3d/models/model_SegResNet.py rename to napari_cellseg3d/code_models/models/model_SegResNet.py index ee1dc9a8..8856e18d 100644 --- a/napari_cellseg3d/models/model_SegResNet.py +++ b/napari_cellseg3d/code_models/models/model_SegResNet.py @@ -1,9 +1,9 @@ from monai.networks.nets import SegResNetVAE -def get_net(input_image_size, dropout_prob=None): +def get_net(input_image_size, out_channels=1, dropout_prob=0.3): return SegResNetVAE( - input_image_size, out_channels=1, dropout_prob=dropout_prob + input_image_size, out_channels=out_channels, dropout_prob=dropout_prob ) diff --git a/napari_cellseg3d/models/model_SwinUNetR.py b/napari_cellseg3d/code_models/models/model_SwinUNetR.py similarity index 100% rename from napari_cellseg3d/models/model_SwinUNetR.py rename to napari_cellseg3d/code_models/models/model_SwinUNetR.py diff --git a/napari_cellseg3d/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py similarity index 100% rename from napari_cellseg3d/models/model_TRAILMAP.py rename to napari_cellseg3d/code_models/models/model_TRAILMAP.py diff --git a/napari_cellseg3d/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py similarity index 85% rename from napari_cellseg3d/models/model_TRAILMAP_MS.py rename to napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py index 1ee50158..d62fee26 100644 --- a/napari_cellseg3d/models/model_TRAILMAP_MS.py +++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py @@ -1,4 +1,4 @@ -from napari_cellseg3d.models.unet.model import UNet3D +from napari_cellseg3d.code_models.models.unet.model import UNet3D def get_weights_file(): diff --git a/napari_cellseg3d/models/model_VNet.py b/napari_cellseg3d/code_models/models/model_VNet.py similarity index 100% rename from napari_cellseg3d/models/model_VNet.py rename to napari_cellseg3d/code_models/models/model_VNet.py diff --git a/napari_cellseg3d/code_models/models/model_test.py b/napari_cellseg3d/code_models/models/model_test.py new file mode 100644 index 00000000..5871c4a7 --- /dev/null +++ b/napari_cellseg3d/code_models/models/model_test.py @@ -0,0 +1,36 @@ +import torch +from torch import nn + + +def get_weights_file(): + return "test.pth" + + +class TestModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x): + return self.linear(torch.tensor(x, requires_grad=True)) + + def get_net(self): + return self + + def get_output(self, _, input): + return input + + def get_validation(self, val_inputs): + return val_inputs + + +# if __name__ == "__main__": +# +# model = TestModel() +# model.train() +# model.zero_grad() +# from napari_cellseg3d.config import WEIGHTS_DIR +# torch.save( +# model.state_dict(), +# WEIGHTS_DIR + f"/{get_weights_file()}" +# ) diff --git a/napari_cellseg3d/models/unet/__init__.py b/napari_cellseg3d/code_models/models/pretrained/__init__.py similarity index 100% rename from napari_cellseg3d/models/unet/__init__.py rename to napari_cellseg3d/code_models/models/pretrained/__init__.py diff --git a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json similarity index 76% rename from napari_cellseg3d/models/pretrained/pretrained_model_urls.json rename to napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json index c8c0b8e1..9331484b 100644 --- a/napari_cellseg3d/models/pretrained/pretrained_model_urls.json +++ b/napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json @@ -2,5 +2,6 @@ "TRAILMAP_MS": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/TRAILMAP_MS.tar.gz", "SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet.tar.gz", "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet.tar.gz", - "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz" -} + "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz", + "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" +} \ No newline at end of file diff --git a/napari_cellseg3d/conftest.py b/napari_cellseg3d/code_models/models/unet/__init__.py similarity index 100% rename from napari_cellseg3d/conftest.py rename to napari_cellseg3d/code_models/models/unet/__init__.py diff --git a/napari_cellseg3d/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py similarity index 100% rename from napari_cellseg3d/models/unet/buildingblocks.py rename to napari_cellseg3d/code_models/models/unet/buildingblocks.py diff --git a/napari_cellseg3d/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py similarity index 95% rename from napari_cellseg3d/models/unet/model.py rename to napari_cellseg3d/code_models/models/unet/model.py index 1be580c2..a31cc580 100644 --- a/napari_cellseg3d/models/unet/model.py +++ b/napari_cellseg3d/code_models/models/unet/model.py @@ -1,8 +1,12 @@ import torch.nn as nn -from napari_cellseg3d.models.unet.buildingblocks import create_decoders -from napari_cellseg3d.models.unet.buildingblocks import create_encoders -from napari_cellseg3d.models.unet.buildingblocks import DoubleConv +from napari_cellseg3d.code_models.models.unet.buildingblocks import ( + create_decoders, +) +from napari_cellseg3d.code_models.models.unet.buildingblocks import ( + create_encoders, +) +from napari_cellseg3d.code_models.models.unet.buildingblocks import DoubleConv def number_of_features_per_level(init_channel_number, num_levels): diff --git a/napari_cellseg3d/code_plugins/__init__.py b/napari_cellseg3d/code_plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/napari_cellseg3d/code_plugins/plugin_base.py b/napari_cellseg3d/code_plugins/plugin_base.py new file mode 100644 index 00000000..8e0fab3c --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_base.py @@ -0,0 +1,454 @@ +import warnings +from functools import partial +from pathlib import Path + +import napari + +# Qt +from qtpy.QtCore import qInstallMessageHandler +from qtpy.QtWidgets import QTabWidget +from qtpy.QtWidgets import QWidget + +# local +from napari_cellseg3d import interface as ui +from napari_cellseg3d import utils + +logger = utils.LOGGER + + +class BasePluginSingleImage(QTabWidget): + """A basic plugin template for working with **single images**""" + + def __init__( + self, + viewer: "napari.viewer.Viewer", + parent=None, + loads_images=True, + loads_labels=True, + has_results=True, + ): + """ + Creates a Base plugin with several buttons pre-defined + + Args: + viewer: napari viewer to display in + parent: parent QWidget. Defaults to None + loads_images: whether to show image IO widgets + loads_labels: whether to show labels IO widgets + has_results: whether to show results IO widgets + + """ + super().__init__(parent) + """Parent widget""" + self._viewer = viewer + """napari.viewer.Viewer: viewer in which the widget is displayed""" + + self.docked_widgets = [] + self.container_docked = False + + self.image_path = None + """str: path to image folder""" + self.show_image_io = loads_images + + self.label_path = None + """str: path to label folder""" + self.show_label_io = loads_labels + + self.results_path = None + """str: path to results folder""" + self.show_results_io = has_results + + self._default_path = [self.image_path, self.label_path] + + ################ + self.layer_choice = ui.RadioButton("Layer", parent=self) + self.folder_choice = ui.RadioButton("Folder", parent=self) + self.radio_buttons = ui.combine_blocks( + self.folder_choice, self.layer_choice + ) + self.io_panel = None # call self._build_io_panel to build + ################ + # Image widgets + self.image_filewidget = ui.FilePathWidget( + "Image path", self._show_dialog_images, self + ) + + self.image_layer_loader: ui.LayerSelecter = ui.LayerSelecter( + self._viewer, + name="Image :", + layer_type=napari.layers.Image, + parent=self, + ) + """LayerSelecter for images""" + ################ + # Label widgets + self.labels_filewidget = ui.FilePathWidget( + "Label path", self._show_dialog_labels, parent=self + ) + + self.label_layer_loader: ui.LayerSelecter = ui.LayerSelecter( + self._viewer, + name="Labels :", + layer_type=napari.layers.Labels, + parent=self, + ) + """LayerSelecter for labels""" + ################ + # Results widget + self.results_filewidget = ui.FilePathWidget( + "Saving path", self._load_results_path, parent=self + ) + + self.filetype_choice = ui.DropdownMenu( + [".tif", ".tiff"], label="File format" + ) + ######## + qInstallMessageHandler(ui.handle_adjust_errors_wrapper(self)) + + def enable_utils_menu(self): + """ + Enables the usage of the CTRL+right-click shortcut to the utilities. + Should only be used in "high-level" widgets (provided in napari Plugins menu) to avoid multiple activation + """ + viewer = self._viewer + + @viewer.mouse_drag_callbacks.append + def show_menu(_, event): + return ui.UtilsDropdown().dropdown_menu_call(self, event) + + def _build_io_panel(self): + self.io_panel = ui.GroupedWidget("Data") + + # self.io_panel.setToolTip("IO Panel") + + ui.add_widgets( + self.io_panel.layout, + [ + self.radio_buttons, + self.image_layer_loader, + self.label_layer_loader, + self.filetype_choice, + self.image_filewidget, + self.labels_filewidget, + self.results_filewidget, + ], + ) + self.io_panel.setLayout(self.io_panel.layout) + + # self._set_io_visibility() + return self.io_panel + + def _remove_unused(self): + if not self.show_label_io: + self.labels_filewidget = None + self.label_layer_loader = None + + if not self.show_image_io: + self.image_layer_loader = None + self.image_filewidget = None + + if not self.show_results_io: + self.results_filewidget = None + + def _set_io_visibility(self): + ################## + # Show when layer is selected + if self.show_image_io: + self._show_io_element(self.image_layer_loader, self.layer_choice) + else: + self._hide_io_element(self.image_layer_loader) + if self.show_label_io: + self._show_io_element(self.label_layer_loader, self.layer_choice) + else: + self._hide_io_element(self.label_layer_loader) + + ################## + # Show when folder is selected + f = self.folder_choice + + self._show_io_element(self.filetype_choice, f) + if self.show_image_io: + self._show_io_element(self.image_filewidget, f) + else: + self._hide_io_element(self.image_filewidget) + if self.show_label_io: + self._show_io_element(self.labels_filewidget, f) + else: + self._hide_io_element(self.labels_filewidget) + if not self.show_results_io: + self._hide_io_element(self.results_filewidget) + + self.folder_choice.toggle() + self.layer_choice.toggle() + + @staticmethod + def _show_io_element(widget: QWidget, toggle: QWidget = None): + """ + Args: + widget: Widget to be shown or hidden + toggle: Toggle to be used to determine whether widget should be shown (Checkbox or RadioButton) + """ + widget.setVisible(True) + + if toggle is not None: + toggle.toggled.connect( + partial(ui.toggle_visibility, toggle, widget) + ) + + @staticmethod + def _hide_io_element(widget: QWidget, toggle: QWidget = None): + """ + Attempts to disconnect widget from toggle and hide it. + Args: + widget: Widget to be hidden + toggle: Toggle to be disconnected from widget, if any + """ + + if toggle is not None: + try: + toggle.toggled.disconnect() + except TypeError: + logger.warning( + "Warning: no method was found to disconnect from widget visibility" + ) + + widget.setVisible(False) + + def _build(self): + """Method to be defined by children classes""" + raise NotImplementedError("To be defined in child classes") + + def _show_filetype_choice(self): + """Method to show/hide the filetype choice when "loading as folder" is (de)selected""" + show = self.load_as_stack_choice.isChecked() + if show is not None: + self.filetype_choice.setVisible(show) + # self.lbl_ft.setVisible(show) + + def _show_file_dialog(self): + """Open file dialog and process path depending on single file/folder loading behaviour""" + if self.load_as_stack_choice.isChecked(): + folder = ui.open_folder_dialog( + self, + self._default_path, + filetype=f"Image file (*{self.filetype_choice.currentText()})", + ) + return folder + else: + f_name = ui.open_file_dialog(self, self._default_path) + f_name = str(f_name[0]) + self.filetype = str(Path(f_name).suffix) + return f_name + + def _show_dialog_images(self): + """Show file dialog and set image path""" + f_name = self._show_file_dialog() + if type(f_name) is str and f_name != "": + self.image_path = f_name + self.image_filewidget.text_field.setText(self.image_path) + self._update_default() + + def _show_dialog_labels(self): + """Show file dialog and set label path""" + f_name = self._show_file_dialog() + if isinstance(f_name, str) and f_name != "": + self.label_path = f_name + self.labels_filewidget.text_field.setText(self.label_path) + self._update_default() + + def _check_results_path(self, folder): + if folder != "" and isinstance(folder, str): + if not Path(folder).is_dir(): + Path(folder).mkdir(parents=True, exist_ok=True) + if not Path(folder).is_dir(): + return False + logger.info(f"Created missing results folder : {folder}") + return True + return False + + def _load_results_path(self): + """Show file dialog to set :py:attr:`~results_path`""" + folder = ui.open_folder_dialog(self, self._default_path) + + if self._check_results_path(folder): + self.results_path = folder + # logger.debug(f"Results path : {self.results_path}") + self.results_filewidget.text_field.setText(self.results_path) + self._update_default() + + def _update_default(self): + """Updates default path for smoother navigation when opening file dialogs""" + self._default_path = [ + self.image_path, + self.label_path, + self.results_path, + ] + + def _make_close_button(self): + btn = ui.Button("Close", self.remove_from_viewer) + btn.setToolTip( + "Close the window and all docked widgets. Make sure to save your work !" + ) + return btn + + def _make_prev_button(self): + btn = ui.Button( + "Previous", lambda: self.setCurrentIndex(self.currentIndex() - 1) + ) + return btn + + def _make_next_button(self): + btn = ui.Button( + "Next", lambda: self.setCurrentIndex(self.currentIndex() + 1) + ) + return btn + + def remove_from_viewer(self): + """Removes the widget from the napari window. + Can be re-implemented in children classes if needed""" + + self.remove_docked_widgets() + self._viewer.window.remove_dock_widget(self) + + def remove_docked_widgets(self): + """Removes all docked widgets from napari window""" + try: + if len(self.docked_widgets) != 0: + [ + self._viewer.window.remove_dock_widget(w) + for w in self.docked_widgets + if w is not None + ] + self.docked_widgets = [] + self.container_docked = False + return True + except LookupError: + return False + + +class BasePluginFolder(BasePluginSingleImage): + """A basic plugin template for working with **folders of images**""" + + def __init__( + self, + viewer: "napari.viewer.Viewer", + parent=None, + loads_images=True, + loads_labels=True, + has_results=True, + ): + """Creates a plugin template with the following widgets defined but not added in a layout : + + * A button to load a folder of images + + * A button to load a folder of labels + + * A button to set a results folder + + * A dropdown menu to select the file extension to be loaded from the folders""" + super().__init__( + viewer, parent, loads_images, loads_labels, has_results + ) + + self.images_filepaths = [] + """array(str): paths to images for training or inference""" + self.labels_filepaths = [] + """array(str): paths to labels for training""" + self.results_path = None + """str: path to output folder,to save results in""" + + self._default_folders = [None] + """Update defaults from PluginBaseFolder with model_path""" + + self.docked_widgets = [] + """List of docked widgets (returned by :py:func:`viewer.window.add_dock_widget())`, + can be used to remove docked widgets""" + + ####################################################### + # interface + # self.image_filewidget = ui.FilePathWidget( + # "Images directory", self.load_image_dataset, self + # ) + self.image_filewidget.text_field = "Images directory" + self.image_filewidget.button.clicked.disconnect( + self._show_dialog_images + ) + self.image_filewidget.button.clicked.connect(self.load_image_dataset) + + # self.labels_filewidget = ui.FilePathWidget( + # "Labels directory", self.load_label_dataset, self + # ) + self.labels_filewidget.text_field = "Labels directory" + self.labels_filewidget.button.clicked.disconnect( + self._show_dialog_labels + ) + self.labels_filewidget.button.clicked.connect(self.load_label_dataset) + + # self.filetype_choice = ui.DropdownMenu( + # [".tif", ".tiff"], label="File format" + # ) + """Allows to choose which file will be loaded from folder""" + ####################################################### + # self._set_io_visibility() + + def load_dataset_paths(self): + """Loads all image paths (as str) in a given folder for which the extension matches the set filetype + + Returns: + array(str): all loaded file paths + """ + filetype = self.filetype_choice.currentText() + directory = ui.open_folder_dialog(self, self._default_folders) + + file_paths = sorted(Path(directory).glob("*" + filetype)) + if len(file_paths) == 0: + warnings.warn( + f"The folder does not contain any compatible {filetype} files.\n" + f"Please check the validity of the folder and images." + ) + + return file_paths + + def load_image_dataset(self): + """Show file dialog to set :py:attr:`~images_filepaths`""" + filenames = self.load_dataset_paths() + logger.debug(f"image filenames : {filenames}") + if filenames: + self.images_filepaths = [str(path) for path in sorted(filenames)] + path = str(Path(filenames[0]).parent) + self.image_filewidget.text_field.setText(path) + self.image_filewidget.check_ready() + self._update_default() + + def load_label_dataset(self): + """Show file dialog to set :py:attr:`~labels_filepaths`""" + filenames = self.load_dataset_paths() + logger.debug(f"labels filenames : {filenames}") + if filenames: + self.labels_filepaths = [str(path) for path in sorted(filenames)] + path = str(Path(filenames[0]).parent) + self.labels_filewidget.text_field.setText(path) + self.labels_filewidget.check_ready() + self._update_default() + + def _update_default(self): + """Update default path for smoother file dialogs""" + if len(self.images_filepaths) != 0: + from_images = str(Path(self.images_filepaths[0]).parent) + else: + from_images = None + + if len(self.labels_filepaths) != 0: + from_labels = str(Path(self.labels_filepaths[0]).parent) + else: + from_labels = None + + self._default_folders = [ + path + for path in [ + from_images, + from_labels, + self.results_path, + ] + if (path != [] and path is not None) + ] diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py new file mode 100644 index 00000000..37be03c8 --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -0,0 +1,721 @@ +import warnings +from pathlib import Path + +import napari +import numpy as np +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QWidget +from tifffile import imread +from tifffile import imwrite + +import napari_cellseg3d.interface as ui +from napari_cellseg3d import config +from napari_cellseg3d import utils +from napari_cellseg3d.code_models.model_instance_seg import clear_small_objects +from napari_cellseg3d.code_models.model_instance_seg import threshold +from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder + +# TODO break down into multiple mini-widgets +# TODO create parent class for utils modules to avoid duplicates + +MAX_W = 200 +MAX_H = 1000 + +logger = utils.LOGGER + + +def save_folder(results_path, folder_name, images, image_paths): + """ + Saves a list of images in a folder + + Args: + results_path: Path to the folder containing results + folder_name: Name of the folder containing results + images: List of images to save + image_paths: list of filenames of images + """ + results_folder = results_path / Path(folder_name) + results_folder.mkdir(exist_ok=False) + + for file, image in zip(image_paths, images): + path = results_folder / Path(file).name + + imwrite( + path, + image, + ) + logger.info(f"Saved processed folder as : {results_folder}") + + +def save_layer(results_path, image_name, image): + """ + Saves an image layer at the specified path + + Args: + results_path: path to folder containing result + image_name: image name for saving + image: data array containing image + + Returns: + + """ + path = str(results_path / Path(image_name)) # TODO flexible filetype + logger.info(f"Saved as : {path}") + imwrite(path, image) + + +def show_result(viewer, layer, image, name): + """ + Adds layers to a viewer to show result to user + + Args: + viewer: viewer to add layer in + layer: type of the original layer the operation was run on, to determine whether it should be an Image or Labels layer + image: the data array containing the image + name: name of the added layer + + Returns: + + """ + if isinstance(layer, napari.layers.Image): + logger.debug("Added resulting image layer") + viewer.add_image(image, name=name) + elif isinstance(layer, napari.layers.Labels): + logger.debug("Added resulting label layer") + viewer.add_labels(image, name=name) + else: + warnings.warn( + f"Results not shown, unsupported layer type {type(layer)}" + ) + + +class AnisoUtils(BasePluginFolder): + """Class to correct anisotropy in images""" + + def __init__(self, viewer: "napari.Viewer.viewer", parent=None): + """ + Creates a AnisoUtils widget + + Args: + viewer: viewer in which to process data + parent: parent widget + """ + super().__init__( + viewer, + parent, + loads_labels=False, + ) + + self.data_panel = self._build_io_panel() + + self.image_layer_loader.layer_list.label.setText("Layer :") + self.image_layer_loader.set_layer_type(napari.layers.Layer) + + self.aniso_widgets = ui.AnisotropyWidgets(self, always_visible=True) + self.start_btn = ui.Button("Start", self._start) + + self.results_path = Path.home() / Path("cellseg3d/anisotropy") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._build() + + def _build(self): + + container = ui.ContainerWidget() + + ui.add_widgets( + container.layout, + [ + self.data_panel, + self.aniso_widgets, + self.start_btn, + ], + ) + + ui.ScrollArea.make_scrollable( + container.layout, + self, + max_wh=[MAX_W, MAX_H], # , min_wh=[100, 200], base_wh=[100, 200] + ) + + self._set_io_visibility() + self.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + + def _start(self): + + self.results_path.mkdir(exist_ok=True) + zoom = self.aniso_widgets.scaling_zyx() + + if self.layer_choice.isChecked(): + if self.image_layer_loader.layer_data() is not None: + layer = self.image_layer_loader.layer() + + data = np.array(layer.data, dtype=np.int16) + isotropic_image = utils.resize(data, zoom) + + save_layer( + self.results_path, + f"isotropic_{layer.name}_{utils.get_date_time()}.tif", + isotropic_image, + ) + show_result( + self._viewer, + layer, + isotropic_image, + f"isotropic_{layer.name}", + ) + + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + utils.resize(np.array(imread(file), dtype=np.int16), zoom) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"isotropic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) + + +class RemoveSmallUtils(BasePluginFolder): + """ + Widget to remove small objects + """ + + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): + """ + Creates a RemoveSmallUtils widget + + Args: + viewer: viewer in which to process data + parent: parent widget + """ + super().__init__( + viewer, + parent, + loads_labels=False, + ) + + self.data_panel = self._build_io_panel() + + self.image_layer_loader.layer_list.label.setText("Layer :") + self.image_layer_loader.set_layer_type(napari.layers.Layer) + + self.start_btn = ui.Button("Start", self._start) + self.size_for_removal_counter = ui.IntIncrementCounter( + lower=1, + upper=100000, + default=10, + label="Remove all smaller than (pxs):", + ) + + self.results_path = Path.home() / Path("cellseg3d/small_removed") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self.container = self._build() + + self.function = clear_small_objects + + def _build(self): + + container = ui.ContainerWidget() + + ui.add_widgets( + self.data_panel.layout, + [ + self.size_for_removal_counter.label, + self.size_for_removal_counter, + self.start_btn, + ], + ) + container.layout.addWidget(self.data_panel) + + ui.ScrollArea.make_scrollable( + container.layout, self, max_wh=[MAX_W, MAX_H] + ) + self._set_io_visibility() + container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + return container + + def _start(self): + self.results_path.mkdir(exist_ok=True) + remove_size = self.size_for_removal_counter.value() + + if self.layer_choice: + if self.image_layer_loader.layer_data() is not None: + layer = self.image_layer_loader.layer() + + data = np.array(layer.data, dtype=np.int16) + removed = self.function(data, remove_size) + + save_layer( + self.results_path, + f"cleared_{layer.name}_{utils.get_date_time()}.tif", + removed, + ) + show_result( + self._viewer, layer, removed, f"cleared_{layer.name}" + ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + clear_small_objects(file, remove_size, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"small_removed_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) + return + + +class ToSemanticUtils(BasePluginFolder): + """ + Widget to create semantic labels from instance labels + """ + + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): + """ + Creates a ToSemanticUtils widget + + Args: + viewer: viewer in which to process data + parent: parent widget + """ + super().__init__( + viewer, + parent, + loads_images=False, + ) + + self.data_panel = self._build_io_panel() + + self.start_btn = ui.Button("Start", self._start) + + self.results_path = Path.home() / Path("cellseg3d/threshold") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._build() + + def _build(self): + + container = ui.ContainerWidget() + + ui.add_widgets( + self.data_panel.layout, + [ + self.start_btn, + ], + ) + container.layout.addWidget(self.data_panel) + + ui.ScrollArea.make_scrollable( + container.layout, self, max_wh=[MAX_W, MAX_H] + ) + self._set_io_visibility() + container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + + def _start(self): + Path(self.results_path).mkdir(exist_ok=True) + + if self.layer_choice: + if self.label_layer_loader.layer_data() is not None: + layer = self.label_layer_loader.layer() + + data = np.array(layer.data, dtype=np.int16) + semantic = to_semantic(data) + + save_layer( + self.results_path, + f"semantic_{layer.name}_{utils.get_date_time()}.tif", + semantic, + ) + show_result( + self._viewer, layer, semantic, f"semantic_{layer.name}" + ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + to_semantic(file, is_file_path=True) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"semantic_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) + + +class InstanceWidgets(QWidget): + """ + Base widget with several sliders, for use in instance segmentation parameters + """ + + def __init__(self, parent=None): + """ + Creates an InstanceWidgets widget + + Args: + parent: parent widget + """ + super().__init__(parent) + + self.method_choice = ui.DropdownMenu( + config.INSTANCE_SEGMENTATION_METHOD_LIST.keys() + ) + self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[ + self.method_choice.currentText() + ] + + self.method_choice.currentTextChanged.connect(self._show_connected) + self.method_choice.currentTextChanged.connect(self._show_watershed) + + self.threshold_slider1 = ui.Slider( + lower=0, + upper=100, + default=50, + divide_factor=100.0, + step=5, + text_label="Probability threshold :", + ) + """Base prob. threshold""" + self.threshold_slider2 = ui.Slider( + lower=0, + upper=100, + default=90, + divide_factor=100.0, + step=5, + text_label="Probability threshold (seeding) :", + ) + """Second prob. thresh. (seeding)""" + + self.counter1 = ui.IntIncrementCounter( + upper=100, + default=10, + step=5, + label="Small object removal (pxs) :", + ) + """Small obj. rem.""" + + self.counter2 = ui.IntIncrementCounter( + upper=100, + default=3, + step=5, + label="Small seed removal (pxs) :", + ) + """Small seed rem.""" + + self._build() + + def run_method(self, volume): + """ + Calls instance function with chosen parameters + Args: + volume: image data to run method on + + Returns: processed image from self._method + """ + return self._method( + volume, + self.threshold_slider1.slider_value, + self.counter1.value(), + self.threshold_slider2.slider_value, + self.counter2.value(), + ) + + def _build(self): + + group = ui.GroupedWidget("Instance segmentation") + + ui.add_widgets( + group.layout, + [ + self.method_choice, + self.threshold_slider1.container, + self.threshold_slider2.container, + self.counter1.label, + self.counter1, + self.counter2.label, + self.counter2, + ], + ) + + self.setLayout(group.layout) + self._set_tooltips() + + def _set_tooltips(self): + + self.method_choice.setToolTip( + "Choose which method to use for instance segmentation" + "\nConnected components : all separated objects will be assigned an unique ID. " + "Robust but will not work correctly with adjacent/touching objects\n" + "Watershed : assigns objects ID based on the probability gradient surrounding an object. " + "Requires the model to surround objects in a gradient;" + " can possibly correctly separate unique but touching/adjacent objects." + ) + self.threshold_slider1.tooltips = ( + "All objects below this probability will be ignored (set to 0)" + ) + self.counter1.setToolTip( + "Will remove all objects smaller (in volume) than the specified number of pixels" + ) + self.threshold_slider2.tooltips = ( + "All seeds below this probability will be ignored (set to 0)" + ) + self.counter2.setToolTip( + "Will remove all seeds smaller (in volume) than the specified number of pixels" + ) + + def _show_watershed(self): + name = "Watershed" + if self.method_choice.currentText() == name: + + self._show_slider1() + self._show_slider2() + self._show_counter1() + self._show_counter2() + + self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] + + def _show_connected(self): + name = "Connected components" + if self.method_choice.currentText() == name: + + self._show_slider1() + self._show_slider2(False) + self._show_counter1() + self._show_counter2(False) + + self._method = config.INSTANCE_SEGMENTATION_METHOD_LIST[name] + + def _show_slider1(self, is_visible: bool = True): + self.threshold_slider1.container.setVisible(is_visible) + + def _show_slider2(self, is_visible: bool = True): + self.threshold_slider2.container.setVisible(is_visible) + + def _show_counter1(self, is_visible: bool = True): + self.counter1.setVisible(is_visible) + self.counter1.label.setVisible(is_visible) + + def _show_counter2(self, is_visible: bool = True): + self.counter2.setVisible(is_visible) + self.counter2.label.setVisible(is_visible) + + +class ToInstanceUtils(BasePluginFolder): + """ + Widget to convert semantic labels to instance labels + """ + + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): + """ + Creates a ToInstanceUtils widget + + Args: + viewer: viewer in which to process data + parent: parent widget + """ + super().__init__( + viewer, + parent, + loads_images=False, + ) + + self.data_panel = self._build_io_panel() + self.label_layer_loader.set_layer_type(napari.layers.Layer) + + self.instance_widgets = InstanceWidgets() + + self.start_btn = ui.Button("Start", self._start) + + self.results_path = Path.home() / Path("cellseg3d/instance") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self._build() + + def _build(self): + + container = ui.ContainerWidget() + + ui.add_widgets( + container.layout, + [ + self.data_panel, + self.instance_widgets, + ], + ) + + ui.add_widgets(self.instance_widgets.layout(), [self.start_btn]) + + ui.ScrollArea.make_scrollable( + container.layout, self, max_wh=[MAX_W, MAX_H] + ) + self._set_io_visibility() + container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + + def _start(self): + self.results_path.mkdir(exist_ok=True) + + if self.layer_choice: + if self.label_layer_loader.layer_data() is not None: + layer = self.label_layer_loader.layer() + + data = np.array(layer.data, dtype=np.int16) + instance = self.instance_widgets.run_method(data) + + save_layer( + self.results_path, + f"instance_{layer.name}_{utils.get_date_time()}.tif", + instance, + ) + self._viewer.add_labels( + instance, name=f"instance_{layer.name}" + ) + + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.instance_widgets.run_method(imread(file)) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"instance_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) + + +class ThresholdUtils(BasePluginFolder): + """ + Creates a ThresholdUtils widget + Args: + viewer: viewer in which to process data + parent: parent widget + """ + + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): + + super().__init__( + viewer, + parent, + loads_labels=False, + ) + + self.data_panel = self._build_io_panel() + self._set_io_visibility() + + self.image_layer_loader.layer_list.label.setText("Layer :") + self.image_layer_loader.set_layer_type(napari.layers.Layer) + + self.start_btn = ui.Button("Start", self._start) + self.binarize_counter = ui.DoubleIncrementCounter( + lower=0.0, + upper=100000.0, + step=0.5, + default=10.0, + label="Remove all smaller than (value):", + ) + + self.results_path = Path.home() / Path("cellseg3d/threshold") + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self.container = self._build() + + self.function = threshold + + def _build(self): + + container = ui.ContainerWidget() + + ui.add_widgets( + self.data_panel.layout, + [ + self.binarize_counter.label, + self.binarize_counter, + self.start_btn, + ], + ) + container.layout.addWidget(self.data_panel) + + ui.ScrollArea.make_scrollable( + container.layout, self, max_wh=[MAX_W, MAX_H] + ) + self._set_io_visibility() + container.setSizePolicy( + QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding + ) + + return container + + def _start(self): + self.results_path.mkdir(exist_ok=True) + remove_size = self.binarize_counter.value() + + if self.layer_choice: + if self.image_layer_loader.layer_data() is not None: + layer = self.image_layer_loader.layer() + + data = np.array(layer.data, dtype=np.int16) + removed = self.function(data, remove_size) + + save_layer( + self.results_path, + f"threshold_{layer.name}_{utils.get_date_time()}.tif", + removed, + ) + show_result( + self._viewer, layer, removed, f"threshold{layer.name}" + ) + elif self.folder_choice.isChecked(): + if len(self.images_filepaths) != 0: + images = [ + self.function(imread(file), remove_size) + for file in self.images_filepaths + ] + save_folder( + self.results_path, + f"threshold_results_{utils.get_date_time()}", + images, + self.images_filepaths, + ) + + +# class ConvertUtils(BasePluginFolder): +# """Utility widget that allows to convert labels from instance to semantic and the reverse.""" +# +# def __init__(self, viewer: "napari.viewer.Viewer", parent): +# """Builds a ConvertUtils widget with the following buttons: +# +# * A button to convert a folder of labels to semantic labels +# +# * A button to convert a folder of labels to instance labels +# +# * A button to convert a currently selected layer to semantic labels +# +# * A button to convert a currently selected layer to instance labels +# """ +# +# super().__init__(viewer, parent) +# self._viewer = viewer +# pass diff --git a/napari_cellseg3d/code_plugins/plugin_crop.py b/napari_cellseg3d/code_plugins/plugin_crop.py new file mode 100644 index 00000000..23ba190c --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_crop.py @@ -0,0 +1,643 @@ +import warnings +from pathlib import Path + +import napari +import numpy as np +from magicgui import magicgui + +# Qt +from qtpy.QtWidgets import QSizePolicy + +# local +from napari_cellseg3d import interface as ui +from napari_cellseg3d import utils +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage + +DEFAULT_CROP_SIZE = 64 +logger = utils.LOGGER + + +class Cropping(BasePluginSingleImage): + """A utility plugin for cropping 3D volumes.""" + + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): + """Creates a Cropping plugin with several buttons : + + * Open file prompt to select volumes directory + + * Open file prompt to select labels directory + + * A dropdown menu with a choice of png or tif filetypes + + * Three spinboxes to choose the dimensions of the cropped volume in x, y, z + + * A button to launch the cropping process (see :doc:`plugin_crop`) + + * A button to close the widget + """ + + super().__init__(viewer) + self.docked_widgets = [] # TODO add remove on close + self.results_path = Path.home() / Path("cellseg3d/cropped") + + self.btn_start = ui.Button("Start", self._start) + + self.image_layer_loader.set_layer_type(napari.layers.Layer) + self.image_layer_loader.layer_list.label.setText("Image 1") + # ui.LayerSelecter(self._viewer, "Image 1") + # self.layer_selection2 = ui.LayerSelecter(self._viewer, "Image 2") + self.label_layer_loader.set_layer_type(napari.layers.Layer) + self.label_layer_loader.layer_list.label.setText("Image 2") + + self.crop_second_image_choice = ui.CheckBox( + "Crop another\nimage simultaneously", + ) + self.crop_second_image_choice.toggled.connect( + self._toggle_second_image_io_visibility + ) + self.crop_second_image_choice.toggled.connect(self._check_image_list) + + self.create_new_layer = ui.CheckBox("Create new layers") + self.create_new_layer.setToolTip( + 'Use this to create a new layer everytime you start cropping, so you can "zoom in" your volume' + ) + + self._viewer.layers.events.inserted.connect(self._check_image_list) + # TODO(cyril) : fix layer removal (issue with cropping layer? ) + self.folder_choice.clicked.connect( + self._toggle_second_image_io_visibility + ) + self.layer_choice.clicked.connect( + self._toggle_second_image_io_visibility + ) + + # self.results_filewidget = ui.FilePathWidget( + # "Results path", + # self._load_results_path, + # default=str(self.results_path), + # ) + # self.results_filewidget.tooltips = str(self.results_path) + self.results_filewidget.text_field.setText(str(self.results_path)) + self.results_filewidget.check_ready() + + self.crop_size_widgets = ui.IntIncrementCounter.make_n( + 3, 1, 1000, DEFAULT_CROP_SIZE + ) + self.crop_size_labels = [ + ui.make_label("Size in " + axis + " of cropped volume :", self) + for axis in "zyx" + ] + + self.aniso_widgets = ui.AnisotropyWidgets(self) + ########### + for box in self.crop_size_widgets: + box.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + self._x = 0 + self._y = 0 + self._z = 0 + + self._crop_size_x = DEFAULT_CROP_SIZE + self._crop_size_y = DEFAULT_CROP_SIZE + self._crop_size_z = DEFAULT_CROP_SIZE + + self.aniso_factors = [1, 1, 1] + + self.image_layer1 = None + self.image_layer2 = None + + self.im1_crop_layer = None + self.im2_crop_layer = None + + self.crop_second_image = False + + self._build() + self._toggle_second_image_io_visibility() + + def _toggle_second_image_io_visibility(self): + crop_2nd = self.crop_second_image_choice.isChecked() + if self.layer_choice.isChecked(): + self.label_layer_loader.setVisible(crop_2nd) + elif self.folder_choice.isChecked(): + self.labels_filewidget.setVisible(crop_2nd) + + def _check_image_list(self): + + l1 = self.image_layer_loader.layer_list + l2 = self.label_layer_loader.layer_list + + if l1.currentText() == l2.currentText(): + try: + for i in range(l1.count()): + if l1.itemText(i) != l2.currentText(): + l2.setCurrentIndex(i) + except IndexError: + return + + def _build(self): + """Build buttons in a layout and add them to the napari Viewer""" + + container = ui.ContainerWidget(0, 0, 1, 11) + layout = container.layout + + io_panel = self._build_io_panel() + + ui.add_widgets( + layout, + [io_panel, self.crop_second_image_choice], + ) + self.label_layer_loader.setVisible(False) + self.radio_buttons.setVisible( + False + ) # TODO(cyril) : remove code related to folders as it is deprecated here + ###################### + ui.add_blank(self, layout) + ###################### + dim_group_w, dim_group_l = ui.make_group("Dimensions") + + dim_group_l.addWidget(self.create_new_layer) + dim_group_l.addWidget(self.aniso_widgets) + [ + dim_group_l.addWidget(widget, alignment=ui.ABS_AL) + for list in zip(self.crop_size_labels, self.crop_size_widgets) + for widget in list + ] + dim_group_w.setLayout(dim_group_l) + layout.addWidget(dim_group_w) + ##################### + ##################### + ui.add_blank(self, layout) + ##################### + ##################### + ui.add_widgets( + layout, + [ + self.btn_start, + ], + ) + + ui.ScrollArea.make_scrollable(layout, self, min_wh=[200, 400]) + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Expanding) + self._set_io_visibility() + + # def _check_results_path(self, folder): + # if folder != "" and isinstance(folder, str): + # if not Path(folder).is_dir(): + # Path(folder).mkdir(parents=True, exist_ok=True) + # if not Path(folder).is_dir(): + # return False + # logger.debug(f"Created missing results folder : {folder}") + # return True + # return False + + # def _load_results_path(self): + # """Show file dialog to set :py:attr:`~results_path`""" + # folder = ui.open_folder_dialog(self, str(self.results_path)) + # + # if self._check_results_path(folder): + # self.results_path = Path(folder) + # logger.debug(f"Results path : {self.results_path}") + # self.results_filewidget.text_field.setText(str(self.results_path)) + + def quicksave(self): + """Quicksaves the cropped volume in the folder from which they originate, with their original file extension. + + * If images are present, saves the cropped version as a single file or image stacks folder depending on what was loaded. + + * If labels are present, saves the cropped version as a single file or 2D stacks folder depending on what was loaded. + """ + + viewer = self._viewer + + self._check_results_path(str(self.results_path)) + time = utils.get_date_time() + + im1_path = str( + self.results_path + / Path("cropped_" + self.image_layer1.name + time) + ) + + viewer.layers[f"cropped_{self.image_layer1.name}"].save(im1_path) + + logger.info(f"Image 1 saved as: {im1_path}") + + if self.crop_second_image: + im2_path = str( + self.results_path + / Path("cropped_" + self.image_layer2.name + time) + ) + + viewer.layers[f"cropped_{self.image_layer2.name}"].save(im2_path) + + logger.info(f"Image 2 saved as: {im2_path}") + + def _check_ready(self): + + if self.image_layer_loader.layer_data() is not None: + if self.crop_second_image: + if self.label_layer_loader.layer_data() is not None: + return True + else: + return False + return True + return False + + def _start(self): + """Launches cropping process by loading the files from the chosen folders, + and adds control widgets to the napari Viewer for moving the cropped volume. + """ + # TODO maybe implement proper reset function so multiple runs can be done without closing napari + # maybe use singletons or make docked widgets attributes that are hidden upon opening + + if not self._check_ready(): + warnings.warn("Please select at least one valid layer !") + return + + # self._viewer.window.remove_dock_widget(self.parent()) # no need to close utils ? + self.remove_docked_widgets() + + if not self.create_new_layer.isChecked(): + try: + if self.im1_crop_layer is not None: + self._viewer.layers.remove(self.im1_crop_layer) + if self.im2_crop_layer is not None: + self._viewer.layers.remove(self.im2_crop_layer) + except ValueError as e: + logger.warning(e) + logger.warning( + "Could not remove cropping layer programmatically!" + ) + logger.warning("Maybe layer has been removed by user?") + + self.results_path = Path(self.results_filewidget.text_field.text()) + + self.crop_second_image = self.crop_second_image_choice.isChecked() + + if self.aniso_widgets.enabled(): + self.aniso_factors = self.aniso_widgets.scaling_zyx() + + self.image_layer1 = self.image_layer_loader.layer() + + if len(self.image_layer1.data) > 3: + self.image_layer1.data = np.squeeze(self.image_layer1.data) + + if self.crop_second_image: + self.image_layer2 = self.label_layer_loader.layer() + + if len(self.image_layer2.data.shape) > 3: + self.image_layer2.data = np.squeeze( + self.image_layer2.data + ) # if channel/batch remnants from MONAI + + vw = self._viewer + + vw.dims.ndisplay = 3 + vw.scale_bar.visible = True + + if self.aniso_widgets.enabled(): + for layer in vw.layers: + layer.visible = False + # hide other layers, because of anisotropy + + self.image_layer1 = self.add_isotropic_layer(self.image_layer1) + + if self.crop_second_image: + self.image_layer2 = self.add_isotropic_layer( + self.image_layer2, visible=False + ) + else: + self.image_layer1.opacity = 0.7 + self.image_layer1.colormap = "inferno" + self.image_layer1.contrast_limits = [200, 1000] # TODO generalize + + self.image_layer1.refresh() + + if self.crop_second_image: + self.image_layer2.opacity = 0.7 + self.image_layer2.visible = False + + self.image_layer2.refresh() + + @magicgui(call_button="Quicksave") # TODO move to Qt + def save_widget(): + return self.quicksave() + + save = self._viewer.window.add_dock_widget( + save_widget, name="Quicksave", area="left" + ) + save._close_btn = False + self.docked_widgets.append(save) + + self._add_crop_sliders() + + def add_isotropic_layer( + self, + layer, + colormap="inferno", + contrast_lim=[200, 1000], # TODO generalize ? + opacity=0.7, + visible=True, + ): + logger.debug(layer.name) + + if isinstance(layer, napari.layers.Image): + layer = self._viewer.add_image( + layer.data, + name=f"Scaled_{layer.name}", + colormap=colormap, + contrast_limits=contrast_lim, + opacity=opacity, + scale=self.aniso_factors, + visible=visible, + ) + logger.debug("image") + elif isinstance(layer, napari.layers.Labels): + layer = self._viewer.add_labels( + layer.data, + name=f"Scaled_{layer.name}", + opacity=opacity, + scale=self.aniso_factors, + visible=visible, + ) + logger.debug("label") + else: + raise ValueError( + f"Please select a valid layer type, {type(layer)} is not compatible" + ) + return layer + + def _check_for_empty_layer(self, layer, volume_data): + + if layer.data.all() == np.zeros_like(layer.data).all(): + layer.colormap = "red" + layer.data = np.random.random(layer.data.shape) + layer.refresh() + else: + layer.colormap = "twilight_shifted" + layer.data = volume_data + layer.refresh() + + def _add_crop_layer(self, layer, cropx, cropy, cropz): + + crop_data = layer.data[:cropx, :cropy, :cropz] + + if isinstance(layer, napari.layers.Image): + new_layer = self._viewer.add_image( + crop_data, + name=f"cropped_{layer.name}", + blending="additive", + colormap="twilight_shifted", + scale=self.aniso_factors, + ) + # self._check_for_empty_layer(new_layer, crop_data) + + elif isinstance(layer, napari.layers.Labels): + new_layer = self._viewer.add_labels( + crop_data, + name=f"cropped_{layer.name}", + scale=self.aniso_factors, + ) + else: + raise ValueError( + f"Please select a valid layer type, {type(layer)} is not compatible" + ) + return new_layer + + # def _reset_dim(self, dim): + # dim = 0 + + def _add_crop_sliders( + self, + # x, y, z + ): + # modified version of code posted by Juan Nunez Iglesias here : + # https://forum.image.sc/t/napari-viewing-3d-image-of-large-tif-stack-cropping-image-w-general-shape/55500/2 + vw = self._viewer + + im1_stack = self.image_layer1.data + + self._crop_size_x, self._crop_size_y, self._crop_size_z = [ + box.value() for box in self.crop_size_widgets + ] + ############# + dims = [self._x, self._y, self._z] + [logger.debug(f"{dim}") for dim in dims] + logger.debug("SET DIMS ATTEMPT") + # if not self.create_new_layer.isChecked(): + # self._x = x + # self._y = y + # self._z = z + # [logger.debug(f"{dim}") for dim in dims] + # else: + # [self._reset_dim(dim) for dim in dims] + # [logger.debug(f"{dim}") for dim in dims] + ############# + + # logger.debug(f"Crop variables") + # logger.debug(im1_stack.shape) + + # define crop sizes and boundaries for the image + crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z] + + for i in range(len(crop_sizes)): + if crop_sizes[i] > im1_stack.shape[i]: + crop_sizes[i] = im1_stack.shape[i] + warnings.warn( + f"WARNING : Crop dimension in axis {i} was too large at {crop_sizes[i]}, it was set to {im1_stack.shape[i]}" + ) + + cropx, cropy, cropz = crop_sizes + ends = np.asarray(im1_stack.shape) - np.asarray(crop_sizes) + 1 + + stepsizes = ends // 100 + + # logger.debug(crop_sizes) + # logger.debug(ends) + # logger.debug(stepsizes) + if ( + self.im1_crop_layer is not None + and self.create_new_layer.isChecked() + ): + self.im1_crop_layer.translate = [0, 0, 0] + if self.im2_crop_layer is not None: + self.im2_crop_layer.translate = [0, 0, 0] + + self.im1_crop_layer = self._add_crop_layer( + self.image_layer1, cropx, cropy, cropz + ) + + if self.crop_second_image: + im2_stack = self.image_layer2.data + self.im2_crop_layer = self._add_crop_layer( + self.image_layer2, cropx, cropy, cropz + ) + + def set_slice( + axis, + value, + highres_crop_layer, + labels_crop_layer=None, + crop_lbls=False, + ): + """ "Update cropped volume position""" + # self._check_for_empty_layer(highres_crop_layer, highres_crop_layer.data) + + logger.debug(f"axis : {axis}") + logger.debug(f"value : {value}") + + idx = int(value) + scale = np.asarray(highres_crop_layer.scale) + translate = np.asarray(highres_crop_layer.translate) + izyx = translate // scale + izyx[axis] = idx + izyx = [int(var) for var in izyx] + i, j, k = izyx + + cropx = self._crop_size_x + cropy = self._crop_size_y + cropz = self._crop_size_z + + highres_crop_layer.data = im1_stack[ + i : i + cropx, j : j + cropy, k : k + cropz + ] + highres_crop_layer.translate = scale * izyx + highres_crop_layer.refresh() + + # self._check_for_empty_layer( + # highres_crop_layer, highres_crop_layer.data + # ) + + if crop_lbls and labels_crop_layer is not None: + labels_crop_layer.data = im2_stack[ + i : i + cropx, j : j + cropy, k : k + cropz + ] + labels_crop_layer.translate = scale * izyx + labels_crop_layer.refresh() + + self._x = i + self._y = j + self._z = k + + # spinbox = SpinBox(name="crop_dims", min=1, value=self._crop_size, max=max(im1_stack.shape), step=1) + # spinbox.changed.connect(lambda event : change_size(event)) + + sliders = [ + ui.Slider(text_label=axis, lower=0, upper=end, step=step) + for axis, end, step in zip("zyx", ends, stepsizes) + ] + for axis, slider in enumerate(sliders): + slider.valueChanged.connect( + lambda event, axis=axis: set_slice( + axis, + event, + self.im1_crop_layer, + self.im2_crop_layer, + self.crop_second_image, + ) + ) + container_widget = ui.ContainerWidget(parent=self) + # Container(layout="vertical") + # container_widget.extend(sliders) + ui.add_widgets( + container_widget.layout, + [ui.combine_blocks(s, s.text_label) for s in sliders], + ) + # vw.window.add_dock_widget([spinbox, container_widget], area="right") + wdgts = vw.window.add_dock_widget( + container_widget, area="right", name="Sliders" + ) + wdgts._close_btn = False + + self.docked_widgets.append(wdgts) + # TEST : trying to dynamically change the size of the cropped volume + # BROKEN for now + # @spinbox.changed.connect + # def change_size(value: int): + # + # logger.debug(value) + # i = self._x + # j = self._y + # k = self._z + # + # self._crop_size = value + # + # cropx = value + # cropy = value + # cropz = value + # highres_crop_layer.data = im1_stack[ + # i : i + cropz, j : j + cropy, k : k + cropx + # ] + # highres_crop_layer.refresh() + # labels_crop_layer.data = im2_stack[ + # i : i + cropz, j : j + cropy, k : k + cropx + # ] + # labels_crop_layer.refresh() + # + + +################################# +################################# +################################# +# code for dynamically changing cropped volume with sliders, one for each dim +# WARNING : broken for now + +# def change_size(axis, value) : + +# logger.debug(value) +# logger.debug(axis) +# index = int(value) +# scale = np.asarray(highres_crop_layer.scale) +# translate = np.asarray(highres_crop_layer.translate) +# izyx = translate // scale +# izyx[axis] = index +# izyx = [int(el) for el in izyx] + +# cropz,cropy,cropx = izyx + +# i = self._x +# j = self._y +# k = self._z + +# self._crop_size_x = cropx +# self._crop_size_y = cropy +# self._crop_size_z = cropz + + +# highres_crop_layer.data = im1_stack[ +# i : i + cropz, j : j + cropy, k : k + cropx +# ] +# highres_crop_layer.refresh() +# labels_crop_layer.data = im2_stack[ +# i : i + cropz, j : j + cropy, k : k + cropx +# ] +# labels_crop_layer.refresh() + + +# # @spinbox.changed.connect +# # spinbox = SpinBox(name=crop_dims, min=1, max=max(im1_stack.shape), step=1) +# # spinbox.changed.connect(lambda event : change_size(event)) + + +# sliders = [ +# Slider(name=axis, min=0, max=end, step=step) +# for axis, end, step in zip("zyx", ends, stepsizes) +# ] +# for axis, slider in enumerate(sliders): +# slider.changed.connect( +# lambda event, axis=axis: set_slice(axis, event) +# ) + +# spinboxes = [ +# SpinBox(name=axes+" crop size", min=1, value=self._crop_size_init, max=end, step=1) +# for axes, end in zip("zyx", im1_stack.shape) +# ] +# for axes, box in enumerate(spinboxes): +# box.changed.connect( +# lambda event, axes=axes : change_size(axis, event) +# ) + + +# container_widget = Container(layout="vertical") +# container_widget.extend(sliders) +# container_widget.extend(spinboxes) +# vw.window.add_dock_widget(container_widget, area="right") diff --git a/napari_cellseg3d/plugin_helper.py b/napari_cellseg3d/code_plugins/plugin_helper.py similarity index 84% rename from napari_cellseg3d/plugin_helper.py rename to napari_cellseg3d/code_plugins/plugin_helper.py index ba090f6f..9a83e0d8 100644 --- a/napari_cellseg3d/plugin_helper.py +++ b/napari_cellseg3d/code_plugins/plugin_helper.py @@ -1,11 +1,11 @@ import pathlib import napari + +# Qt from qtpy.QtCore import QSize from qtpy.QtGui import QIcon from qtpy.QtGui import QPixmap - -# Qt from qtpy.QtWidgets import QVBoxLayout from qtpy.QtWidgets import QWidget @@ -13,21 +13,18 @@ from napari_cellseg3d import interface as ui -class Helper(QWidget): - # widget testing +class Helper(QWidget, metaclass=ui.QWidgetSingleton): def __init__(self, viewer: "napari.viewer.Viewer"): super().__init__() - self.help_url = ( - "https://adaptivemotorcontrollab.github.io/cellseg3d-docs/" - ) + self.help_url = "https://adaptivemotorcontrollab.github.io/CellSeg3d/" self.about_url = "https://wysscenter.ch/advances/3d-computer-vision-for-brain-analysis" self.repo_url = "https://github.com/AdaptiveMotorControlLab/CellSeg3d" self._viewer = viewer path = pathlib.Path(__file__).parent.resolve() - url = str(path) + "/res/logo_alpha.png" + url = str(path) + "../res/logo_alpha.png" image = QPixmap(url) self.logo_label = ui.Button(func=lambda: ui.open_url(self.repo_url)) @@ -40,7 +37,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.logo_label.setToolTip("Open Github page") self.info_label = ui.make_label( - f"You are using napari-cellseg3d v.{'0.0.1rc4'}\n\n" + f"You are using napari-cellseg3d v.{'0.0.2rc1'}\n\n" f"Plugin for cell segmentation developed\n" f"by the Mathis Lab of Adaptive Motor Control\n\n" f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" @@ -69,8 +66,6 @@ def build(self): ] ui.add_widgets(vbox, widgets) self.setLayout(vbox) - # self.show() - # self._viewer.window.add_dock_widget(self, name="Help/About...", area="right") def remove_from_viewer(self): self._viewer.window.remove_dock_widget(self) diff --git a/napari_cellseg3d/plugin_metrics.py b/napari_cellseg3d/code_plugins/plugin_metrics.py similarity index 89% rename from napari_cellseg3d/plugin_metrics.py rename to napari_cellseg3d/code_plugins/plugin_metrics.py index 86a3fb98..42c2d89e 100644 --- a/napari_cellseg3d/plugin_metrics.py +++ b/napari_cellseg3d/code_plugins/plugin_metrics.py @@ -11,8 +11,8 @@ from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.model_instance_seg import to_semantic -from napari_cellseg3d.plugin_base import BasePluginFolder +from napari_cellseg3d.code_models.model_instance_seg import to_semantic +from napari_cellseg3d.code_plugins.plugin_base import BasePluginFolder DEFAULT_THRESHOLD = 0.5 @@ -26,7 +26,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent): viewer: viewer to display the widget in parent : parent widget """ - super().__init__(viewer, parent) + super().__init__(viewer, parent, has_results=False) self._viewer = viewer """Viewer to display widget in""" @@ -40,24 +40,25 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent): ###################################### # interface + self.io_panel = self._build_io_panel() # set new descriptions for Filewidgets - self.image_filewidget.set_description("Ground truth") - self.label_filewidget.set_description("Prediction") + self.image_filewidget.description = "Ground truth" + self.labels_filewidget.description = "Prediction" self.btn_compute_dice = ui.Button("Compute Dice", self.compute_dice) - self.rotate_choice = ui.make_checkbox("Find best orientation") + self.rotate_choice = ui.CheckBox("Find best orientation") self.btn_reset_plot = ui.Button("Clear plots", self.remove_plots) self.lbl_threshold_box = ui.make_label("Score threshold", self) self.threshold_box = ui.DoubleIncrementCounter( - min=0.1, max=1, default=DEFAULT_THRESHOLD, step=0.1 + lower=0.1, upper=1, default=DEFAULT_THRESHOLD, step=0.1 ) - self.btn_result_path.setVisible(False) - self.lbl_result_path.setVisible(False) + self.results_filewidget.button.setVisible(False) + self.results_filewidget.text_field.setVisible(False) self.rotate_choice.setToolTip( "This will rotate and flip your images to find the orientation with the best Dice coefficient.\n" @@ -68,14 +69,15 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent): ) self.btn_reset_plot.setToolTip("Erase all plots") - self.build() + self._build() - def build(self): + def _build(self): """Builds the layout of the widget.""" - self.lbl_filetype.setVisible(False) + self.filetype_choice.label.setVisible(False) - w, self.layout = ui.make_container() + w = ui.ContainerWidget() + self.layout = w.layout metrics_group_w, metrics_group_l = ui.make_group("Data") @@ -83,13 +85,13 @@ def build(self): metrics_group_l, [ ui.combine_blocks( - right_or_below=self.btn_image_files, - left_or_above=self.lbl_image_files, + right_or_below=self.image_filewidget.button, + left_or_above=self.image_filewidget.text_field, min_spacing=70, ), # images -> ground truth ui.combine_blocks( - right_or_below=self.btn_label_files, - left_or_above=self.lbl_label_files, + right_or_below=self.labels_filewidget.button, + left_or_above=self.labels_filewidget.text_field, min_spacing=70, ), # labels -> prediction ], @@ -117,7 +119,7 @@ def build(self): metrics_group_w, param_group_w, self.btn_compute_dice, - self.make_close_button(), + self._make_close_button(), self.btn_reset_plot, ], ) diff --git a/napari_cellseg3d/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py similarity index 50% rename from napari_cellseg3d/plugin_model_inference.py rename to napari_cellseg3d/code_plugins/plugin_model_inference.py index f9cd5615..99733936 100644 --- a/napari_cellseg3d/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -1,27 +1,25 @@ import warnings +from functools import partial import napari import numpy as np import pandas as pd -# Qt -from qtpy.QtWidgets import QSizePolicy - # local +from napari_cellseg3d import config from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.model_framework import ModelFramework -from napari_cellseg3d.model_workers import InferenceWorker - - -# TODO for layer inference : button behaviour/visibility, error if no layer selected, test all funcs +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.model_workers import InferenceResult +from napari_cellseg3d.code_models.model_workers import InferenceWorker +from napari_cellseg3d.code_plugins.plugin_convert import InstanceWidgets -class Inferer(ModelFramework): +class Inferer(ModelFramework, metaclass=ui.QWidgetSingleton): """A plugin to run already trained models in evaluation mode to preform inference and output a label on all given volumes.""" - def __init__(self, viewer: "napari.viewer.Viewer"): + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): """ Creates an Inference loader plugin with the following widgets : @@ -57,54 +55,62 @@ def __init__(self, viewer: "napari.viewer.Viewer"): Args: viewer (napari.viewer.Viewer): napari viewer to display the widget in """ - super().__init__(viewer) + super().__init__( + viewer, + parent, + loads_labels=False, + ) self._viewer = viewer """Viewer to display the widget in""" + self.enable_utils_menu() - self.worker = None - """Worker for inference, should be an InferenceWorker instance from :doc:model_workers.py""" - - self.transforms = None + self.worker: InferenceWorker = None + """Worker for inference, should be an InferenceWorker instance from model_workers.py""" - self.show_res = False - self.show_res_nbr = 1 - self.show_original = True - self.zoom = [1, 1, 1] + self.model_info: config.ModelInfo = None + """ModelInfo class from config.py""" - self.instance_params = None - self.stats_to_csv = False - - self.keep_on_cpu = False - self.use_window_inference = False - self.window_inference_size = None - self.window_overlap = 0.25 + self.config = config.InfererConfig() + """InfererConfig class from config.py""" + self.worker_config: config.InferenceWorkerConfig = ( + config.InferenceWorkerConfig() + ) + """InferenceWorkerConfig class from config.py""" + self.instance_config: config.InstanceSegConfig = ( + config.InstanceSegConfig() + ) + """InstanceSegConfig class from config.py""" + self.post_process_config: config.PostProcessConfig = ( + config.PostProcessConfig() + ) + """PostProcessConfig class from config.py""" ########################### # interface + self.data_panel = self._build_io_panel() # None - ( - self.view_results_container, - self.view_results_layout, - ) = ui.make_container(T=7, B=0, parent=self) + self.view_results_container = ui.ContainerWidget(t=7, b=0, parent=self) + self.view_results_panel = None - self.view_checkbox = ui.make_checkbox( - "View results in napari", self.toggle_display_number + self.view_checkbox = ui.CheckBox( + "View results in napari", self._toggle_display_number ) - self.display_number_choice = ui.IntIncrementCounter(min=1, default=5) - self.lbl_display_number = ui.make_label("How many ? (max. 10)", self) + self.display_number_choice_slider = ui.Slider( + lower=1, upper=10, default=5, text_label="How many ? " + ) - self.show_original_checkbox = ui.make_checkbox("Show originals") + self.show_original_checkbox = ui.CheckBox("Show originals") ###################### ###################### # TODO : better way to handle SegResNet size reqs ? self.model_input_size = ui.IntIncrementCounter( - min=1, max=1024, default=128 + lower=1, upper=1024, default=128, label="\nModel input size" ) self.model_choice.currentIndexChanged.connect( - self.toggle_display_model_input_size + self._toggle_display_model_input_size ) self.model_choice.setCurrentIndex(0) @@ -115,31 +121,35 @@ def __init__(self, viewer: "napari.viewer.Viewer"): default_z=5, # TODO change default ) - self.aniso_resolutions = [1, 1, 1] + # self.worker_config.post_process_config.zoom.zoom_values = [ + # 1.0, + # 1.0, + # 1.0, + # ] # ui.add_blank(self.aniso_container, aniso_layout) ###################### ###################### - self.thresholding_checkbox = ui.make_checkbox( - "Perform thresholding", self.toggle_display_thresh - ) - - self.thresholding_count = ui.DoubleIncrementCounter( - max=1, default=0.7, step=0.05 + self.thresholding_checkbox = ui.CheckBox( + "Perform thresholding", self._toggle_display_thresh ) - self.thresholding_container, self.thresh_layout = ui.make_container( - T=7, parent=self + self.thresholding_slider = ui.Slider( + lower=1, + default=config.PostProcessConfig().thresholding.threshold_value + * 100, + divide_factor=100.0, + parent=self, ) - self.window_infer_box = ui.CheckBox(title="Use window inference") - self.window_infer_box.clicked.connect(self.toggle_display_window_size) + self.window_infer_box = ui.CheckBox("Use window inference") + self.window_infer_box.clicked.connect(self._toggle_display_window_size) sizes_window = ["8", "16", "32", "64", "128", "256", "512"] # ( # self.window_size_choice, - # self.lbl_window_size_choice, + # self.window_size_choice.label, # ) = ui.make_combobox(sizes_window, label="Window size and overlap") # self.window_overlap = ui.make_n_spinboxes( # max=1, @@ -151,98 +161,72 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_size_choice = ui.DropdownMenu( sizes_window, label="Window size" ) - self.lbl_window_size_choice = self.window_size_choice.label - self.window_overlap_counter = ui.DoubleIncrementCounter( - min=0, - max=1, - default=0.25, - step=0.05, + self.window_overlap_slider = ui.Slider( + default=config.SlidingWindowConfig.window_overlap * 100, + divide_factor=100.0, parent=self, - label="Overlap %", + text_label="Overlap %", ) - self.keep_data_on_cpu_box = ui.CheckBox(title="Keep data on CPU") + self.keep_data_on_cpu_box = ui.CheckBox("Keep data on CPU") window_size_widgets = ui.combine_blocks( self.window_size_choice, - self.lbl_window_size_choice, + self.window_size_choice.label, horizontal=False, ) - # self.window_infer_params = ui.combine_blocks( - # self.window_overlap, - # self.window_infer_params, - # horizontal=False, - # ) - self.window_infer_params = ui.combine_blocks( - window_size_widgets, - self.window_overlap_counter.get_with_label(horizontal=False), - horizontal=False, + self.window_infer_params = ui.ContainerWidget(parent=self) + ui.add_widgets( + self.window_infer_params.layout, + [ + window_size_widgets, + self.window_overlap_slider.container, + ], ) ################## ################## # instance segmentation widgets - self.instance_box = ui.make_checkbox( - "Run instance segmentation", func=self.toggle_display_instance - ) - - self.instance_method_choice = ui.DropdownMenu( - ["Connected components", "Watershed"] - ) + self.instance_widgets = InstanceWidgets(self) - self.instance_prob_thresh = ui.DoubleIncrementCounter( - max=0.99, default=0.7, step=0.05 - ) - self.instance_prob_thresh_lbl = ui.make_label( - "Probability threshold :", self - ) - self.instance_prob_t_container = ui.combine_blocks( - right_or_below=self.instance_prob_thresh, - left_or_above=self.instance_prob_thresh_lbl, - horizontal=False, + self.use_instance_choice = ui.CheckBox( + "Run instance segmentation", func=self._toggle_display_instance ) - self.instance_small_object_thresh = ui.IntIncrementCounter( - max=100, default=10, step=5 - ) - self.instance_small_object_thresh_lbl = ui.make_label( - "Small object removal threshold :", self - ) - self.instance_small_object_t_container = ui.combine_blocks( - right_or_below=self.instance_small_object_thresh, - left_or_above=self.instance_small_object_thresh_lbl, - horizontal=False, - ) - self.save_stats_to_csv_box = ui.make_checkbox( + self.save_stats_to_csv_box = ui.CheckBox( "Save stats to csv", parent=self ) - ( - self.instance_param_container, - self.instance_layout, - ) = ui.make_container(T=7, B=0, parent=self) - ################## ################## - self.btn_start = ui.Button("Start on folder", self.start) - self.btn_start_layer = ui.Button( - "Start on selected layer", - lambda: self.start(on_layer=True), + self.btn_start = ui.Button("Start", self.start) + self.btn_close = self._make_close_button() + + self._set_tooltips() + + self._build() + self._set_io_visibility() + self.folder_choice.toggled.connect( + partial( + self._show_io_element, + self.view_results_panel, + self.folder_choice, + ) ) - self.btn_close = self.make_close_button() + self.folder_choice.toggle() + self.layer_choice.toggle() - # hide unused widgets from parent class - self.label_filewidget.setVisible(False) - self.model_filewidget.setVisible(False) + self._remove_unused() + def _set_tooltips(self): ################## ################## # tooltips self.view_checkbox.setToolTip("Show results in the napari viewer") - self.display_number_choice.setToolTip( + self.display_number_choice_slider.tooltips = ( "Choose how many results to display once the work is done.\n" "Maximum is 10 for clarity" ) @@ -250,7 +234,8 @@ def __init__(self, viewer: "napari.viewer.Viewer"): "Displays the image used for inference in the viewer" ) self.model_input_size.setToolTip( - "Image size on which the model has been trained (default : 128)" + "Image size on which the model has been trained (default : 128)\n" + "DO NOT CHANGE if you are using the provided pre-trained weights" ) thresh_desc = ( @@ -259,7 +244,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ) self.thresholding_checkbox.setToolTip(thresh_desc) - self.thresholding_count.setToolTip(thresh_desc) + self.thresholding_slider.tooltips = thresh_desc self.window_infer_box.setToolTip( "Sliding window inference runs the model on parts of the image" "\nrather than the whole image, to reduce memory requirements." @@ -268,36 +253,16 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.window_size_choice.setToolTip( "Size of the window to run inference with (in pixels)" ) - - self.window_overlap_counter.setToolTip( - "Percentage of overlap between windows to use when using sliding window" - ) - - # self.window_overlap.setToolTip( - # "Amount of overlap between sliding windows" - # ) + self.window_overlap_slider.tooltips = "Percentage of overlap between windows to use when using sliding window" self.keep_data_on_cpu_box.setToolTip( "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" ) - self.instance_box.setToolTip( + self.use_instance_choice.setToolTip( "Instance segmentation will convert instance (0/1) labels to labels" " that attempt to assign an unique ID to each cell." ) - self.instance_method_choice.setToolTip( - "Choose which method to use for instance segmentation" - "\nConnected components : all separated objects will be assigned an unique ID. " - "Robust but will not work correctly with adjacent/touching objects\n" - "Watershed : assigns objects ID based on the probability gradient surrounding an object. " - "Requires the model to surround objects in a gradient;" - " can possibly correctly separate unique but touching/adjacent objects." - ) - self.instance_prob_thresh.setToolTip( - "All objects below this probability will be ignored (set to 0)" - ) - self.instance_small_object_thresh.setToolTip( - "Will remove all objects smaller (in volume) than the specified number of pixels" - ) + self.save_stats_to_csv_box.setToolTip( "Will save several statistics for each object to a csv in the results folder. Stats include : " "volume, centroid coordinates, sphericity" @@ -305,126 +270,87 @@ def __init__(self, viewer: "napari.viewer.Viewer"): ################## ################## - self.build() - def check_ready(self): """Checks if the paths to the files are properly set""" - if ( - self.images_filepaths != [""] - and self.images_filepaths != [] - and self.results_path != "" - ) or ( - self.results_path != "" - and self._viewer.layers.selection.active is not None - ): - return True - else: - warnings.formatwarning = utils.format_Warning - warnings.warn("Image and label paths are not correctly set") - return False - def toggle_display_model_input_size(self): + if self.layer_choice.isChecked(): + if self.image_layer_loader.layer_data() is not None: + return True + elif self.folder_choice.isChecked(): + if self.image_filewidget.check_ready(): + return True + return False + + def _toggle_display_model_input_size(self): if ( self.model_choice.currentText() == "SegResNet" or self.model_choice.currentText() == "SwinUNetR" ): self.model_input_size.setVisible(True) + self.model_input_size.label.setVisible(True) else: self.model_input_size.setVisible(False) + self.model_input_size.label.setVisible(False) - def toggle_display_number(self): + def _toggle_display_number(self): """Shows the choices for viewing results depending on whether :py:attr:`self.view_checkbox` is checked""" ui.toggle_visibility(self.view_checkbox, self.view_results_container) - def toggle_display_thresh(self): + def _toggle_display_thresh(self): """Shows the choices for thresholding results depending on whether :py:attr:`self.thresholding_checkbox` is checked""" ui.toggle_visibility( - self.thresholding_checkbox, self.thresholding_container + self.thresholding_checkbox, self.thresholding_slider.container ) - def toggle_display_instance(self): + def _toggle_display_instance(self): """Shows or hides the options for instance segmentation based on current user selection""" - ui.toggle_visibility(self.instance_box, self.instance_param_container) + ui.toggle_visibility(self.use_instance_choice, self.instance_widgets) - def toggle_display_window_size(self): + def _toggle_display_window_size(self): """Show or hide window size choice depending on status of self.window_infer_box""" ui.toggle_visibility(self.window_infer_box, self.window_infer_params) - def build(self): + def _build(self): """Puts all widgets in a layout and adds them to the napari Viewer""" # ui.add_blank(self.view_results_container, view_results_layout) ui.add_widgets( - self.view_results_layout, + self.view_results_container.layout, [ self.view_checkbox, - self.lbl_display_number, - self.display_number_choice, + self.display_number_choice_slider.container, self.show_original_checkbox, ], alignment=None, ) - self.view_results_container.setLayout(self.view_results_layout) - - self.anisotropy_wdgt.build() - - self.thresh_layout.addWidget( - self.thresholding_count, alignment=ui.LEFT_AL - ) - # ui.add_blank(self.thresholding_container, thresh_layout) - self.thresholding_container.setLayout( - self.thresh_layout - ) # thresholding - self.thresholding_container.setVisible(False) - - ui.add_widgets( - self.instance_layout, - [ - self.instance_method_choice, - self.instance_prob_t_container, - self.instance_small_object_t_container, - self.save_stats_to_csv_box, - ], + self.view_results_container.setLayout( + self.view_results_container.layout ) - self.instance_param_container.setLayout(self.instance_layout) + self.anisotropy_wdgt.build() - self.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.MinimumExpanding) ###### ############ ################## - tab, tab_layout = ui.make_container( - B=1, parent=self + tab = ui.ContainerWidget( + b=1, parent=self ) # tab that will contain all widgets L, T, R, B = 7, 20, 7, 11 # margins for group boxes ################################# ################################# - io_group, io_layout = ui.make_group("Data", L, T, R, B, parent=self) + # self.image_filewidget.update_field_color("black") - ui.add_widgets( - io_layout, - [ - ui.combine_blocks( - self.filetype_choice, self.lbl_filetype - ), # file extension - ui.combine_blocks( - self.btn_image_files, self.lbl_image_files - ), # in folder - ui.combine_blocks( - self.btn_result_path, self.lbl_result_path - ), # out folder - ], + self.results_filewidget.text_field.setText( + self.worker_config.results_path ) - self.image_filewidget.set_required(False) - self.image_filewidget.update_field_color("black") + self.results_filewidget.check_ready() - io_group.setLayout(io_layout) - tab_layout.addWidget(io_group) + tab.layout.addWidget(self.data_panel) ################################# ################################# - ui.add_blank(tab, tab_layout) + ui.add_blank(tab, tab.layout) ################################# ################################# # model group @@ -438,19 +364,21 @@ def build(self): [ self.model_choice, self.custom_weights_choice, - self.weights_path_container, + self.weights_filewidget, + self.model_input_size.label, self.model_input_size, ], ) - self.weights_path_container.setVisible(False) - self.lbl_model_choice.setVisible(False) # TODO remove (?) + self.weights_filewidget.setVisible(False) + self.model_choice.label.setVisible( + False + ) # TODO reminder for adding custom model model_group_w.setLayout(model_group_l) - tab_layout.addWidget(model_group_w) - + tab.layout.addWidget(model_group_w) ################################# ################################# - ui.add_blank(tab, tab_layout) + ui.add_blank(tab, tab.layout) ################################# ################################# inference_param_group_w, inference_param_group_l = ui.make_group( @@ -469,11 +397,11 @@ def build(self): inference_param_group_w.setLayout(inference_param_group_l) - tab_layout.addWidget(inference_param_group_w) + tab.layout.addWidget(inference_param_group_w) ################################# ################################# - ui.add_blank(tab, tab_layout) + ui.add_blank(tab, tab.layout) ################################# ################################# # post proc group @@ -481,26 +409,34 @@ def build(self): "Post-processing", parent=self ) + self.thresholding_slider.container.setVisible(False) + ui.add_widgets( post_proc_layout, [ self.anisotropy_wdgt, # anisotropy self.thresholding_checkbox, - self.thresholding_container, # thresholding - self.instance_box, - self.instance_param_container, # instance segmentation + self.thresholding_slider.container, # thresholding + self.use_instance_choice, + self.instance_widgets, + self.save_stats_to_csv_box, + # self.instance_param_container, # instance segmentation ], ) + ModelFramework._show_io_element( + self.save_stats_to_csv_box, self.use_instance_choice + ) self.anisotropy_wdgt.container.setVisible(False) - self.thresholding_container.setVisible(False) - self.instance_param_container.setVisible(False) + self.thresholding_slider.container.setVisible(False) + self.instance_widgets.setVisible(False) + self.save_stats_to_csv_box.setVisible(False) post_proc_group.setLayout(post_proc_layout) - tab_layout.addWidget(post_proc_group, alignment=ui.LEFT_AL) + tab.layout.addWidget(post_proc_group, alignment=ui.LEFT_AL) ################################### ################################### - ui.add_blank(tab, tab_layout) + ui.add_blank(tab, tab.layout) ################################### ################################### display_opt_group, display_opt_layout = ui.make_group( @@ -519,22 +455,22 @@ def build(self): self.view_results_container.setVisible(False) self.view_checkbox.toggle() - self.toggle_display_number() + self._toggle_display_number() # TODO : add custom model handling ? - # self.lbl_label.setText("model.pth directory :") + # self.label_filewidget.text_field.setText("model.pth directory :") display_opt_group.setLayout(display_opt_layout) - tab_layout.addWidget(display_opt_group) + self.view_results_panel = display_opt_group + tab.layout.addWidget(display_opt_group) ################################### - ui.add_blank(self, tab_layout) + ui.add_blank(self, tab.layout) ################################### ################################### ui.add_widgets( - tab_layout, + tab.layout, [ self.btn_start, - self.btn_start_layer, self.btn_close, ], ) @@ -544,15 +480,15 @@ def build(self): # end of tabs, combine into scrollable ui.ScrollArea.make_scrollable( parent=tab, - contained_layout=tab_layout, + contained_layout=tab.layout, min_wh=[200, 100], ) self.addTab(tab, "Inference") - - self.setMinimumSize(180, 100) + tab.adjustSize() + # self.setMinimumSize(180, 100) # self.setBaseSize(210, 400) - def start(self, on_layer=False): + def start(self): """Start the inference process, enables :py:attr:`~self.worker` and does the following: * Checks if the output and input folders are correctly set @@ -576,7 +512,7 @@ def start(self, on_layer=False): """ if not self.check_ready(): - err = "Aborting, please choose correct paths" + err = "Aborting, please choose valid inputs" self.log.print_and_log(err) raise ValueError(err) @@ -585,117 +521,99 @@ def start(self, on_layer=False): pass else: self.worker.start() - self.btn_start_layer.setVisible(False) self.btn_start.setText("Running... Click to stop") else: self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - device = self.get_device() + self.model_info = config.ModelInfo( + name=self.model_choice.currentText(), + model_input_size=self.model_input_size.value(), + ) - model_key = self.model_choice.currentText() - model_dict = { # gather model info - "name": model_key, - "class": self.get_model(model_key), - "model_input_size": self.model_input_size.value(), - } + self.weights_config.custom = self.custom_weights_choice.isChecked() - if self.custom_weights_choice.isChecked(): - weights_dict = {"custom": True, "path": self.weights_path} + save_path = self.results_filewidget.text_field.text() + if not self._check_results_path(save_path): + msg = f"ERROR: please set valid results path. Current path is {save_path}" + self.log.print_and_log(msg) + warnings.warn(msg) else: - weights_dict = { - "custom": False, - } - - if self.anisotropy_wdgt.is_enabled(): - self.aniso_resolutions = ( - self.anisotropy_wdgt.get_anisotropy_resolution_xyz( - as_factors=False - ) - ) - self.zoom = ( - self.anisotropy_wdgt.get_anisotropy_resolution_xyz() + if self.results_path is None: + self.results_path = save_path + + zoom_config = config.Zoom( + enabled=self.anisotropy_wdgt.enabled(), + zoom_values=self.anisotropy_wdgt.scaling_xyz(), + ) + thresholding_config = config.Thresholding( + enabled=self.thresholding_checkbox.isChecked(), + threshold_value=self.thresholding_slider.slider_value, + ) + + instance_thresh_config = config.Thresholding( + threshold_value=self.instance_widgets.threshold_slider1.slider_value + ) + instance_small_object_thresh_config = config.Thresholding( + threshold_value=self.instance_widgets.counter1.value() + ) + self.instance_config = config.InstanceSegConfig( + enabled=self.use_instance_choice.isChecked(), + method=self.instance_widgets.method_choice.currentText(), + threshold=instance_thresh_config, + small_object_removal_threshold=instance_small_object_thresh_config, + ) + + self.post_process_config = config.PostProcessConfig( + zoom=zoom_config, + thresholding=thresholding_config, + instance=self.instance_config, + ) + + if self.window_infer_box.isChecked(): + size = int(self.window_size_choice.currentText()) + window_config = config.SlidingWindowConfig( + window_size=size, + window_overlap=self.window_overlap_slider.slider_value, ) else: - self.zoom = [1, 1, 1] - - self.transforms = { # TODO figure out a better way ? - "thresh": [ - self.thresholding_checkbox.isChecked(), - self.thresholding_count.value(), - ], - "zoom": [ - self.anisotropy_wdgt.checkbox.isChecked(), - self.zoom, - ], - } - - self.instance_params = { - "do_instance": self.instance_box.isChecked(), - "method": self.instance_method_choice.currentText(), - "threshold": self.instance_prob_thresh.value(), - "size_small": self.instance_small_object_thresh.value(), - } - self.stats_to_csv = self.save_stats_to_csv_box.isChecked() - # print(f"METHOD : {self.instance_method_choice.currentText()}") - - self.show_res_nbr = self.display_number_choice.value() - - self.keep_on_cpu = self.keep_data_on_cpu_box.isChecked() - self.use_window_inference = self.window_infer_box.isChecked() - self.window_inference_size = int( - self.window_size_choice.currentText() + window_config = config.SlidingWindowConfig() + + self.worker_config = config.InferenceWorkerConfig( + device=self.get_device(), + model_info=self.model_info, + weights_config=self.weights_config, + results_path=self.results_path, + filetype=self.filetype_choice.currentText(), + keep_on_cpu=self.keep_data_on_cpu_box.isChecked(), + compute_stats=self.save_stats_to_csv_box.isChecked(), + post_process_config=self.post_process_config, + sliding_window_config=window_config, ) - self.window_overlap = self.window_overlap_counter.value() - - if not on_layer: - self.worker = InferenceWorker( - device=device, - model_dict=model_dict, - weights_dict=weights_dict, - images_filepaths=self.images_filepaths, - results_path=self.results_path, - filetype=self.filetype_choice.currentText(), - transforms=self.transforms, - instance=self.instance_params, - use_window=self.use_window_inference, - window_infer_size=self.window_inference_size, - window_overlap=self.window_overlap, - keep_on_cpu=self.keep_on_cpu, - stats_csv=self.stats_to_csv, - ) + ##################### + ##################### + ##################### + + if self.folder_choice.isChecked(): + + self.worker_config.images_filepaths = self.images_filepaths + self.worker = InferenceWorker(worker_config=self.worker_config) + + elif self.layer_choice.isChecked(): + + self.worker_config.layer = self.image_layer_loader.layer_data() + self.worker = InferenceWorker(worker_config=self.worker_config) + else: - layer = self._viewer.layers.selection.active - self.worker = InferenceWorker( - device=device, - model_dict=model_dict, - weights_dict=weights_dict, - results_path=self.results_path, - filetype=self.filetype_choice.currentText(), - transforms=self.transforms, - instance=self.instance_params, - use_window=self.use_window_inference, - window_infer_size=self.window_inference_size, - keep_on_cpu=self.keep_on_cpu, - window_overlap=self.window_overlap, - stats_csv=self.stats_to_csv, - layer=layer, - ) + raise ValueError("Please select to load a layer or folder") self.worker.set_download_log(self.log) - yield_connect_show_res = lambda data: self.on_yield( - data, - widget=self, - ) - self.worker.started.connect(self.on_start) self.worker.log_signal.connect(self.log.print_and_log) self.worker.warn_signal.connect(self.log.warn) - self.worker.yielded.connect(yield_connect_show_res) - self.worker.errored.connect( - yield_connect_show_res - ) # TODO fix showing errors from thread + self.worker.yielded.connect(partial(self.on_yield)) # + self.worker.errored.connect(partial(self.on_yield)) self.worker.finished.connect(self.on_finish) if self.get_device(show=False) == "cuda": @@ -711,14 +629,23 @@ def start(self, on_layer=False): else: # once worker is started, update buttons self.worker.start() self.btn_start.setText("Running... Click to stop") - self.btn_start_layer.setVisible(False) def on_start(self): """Catches start signal from worker to call :py:func:`~display_status_report`""" self.display_status_report() - self.show_res = self.view_checkbox.isChecked() - self.show_original = self.show_original_checkbox.isChecked() + self.config = config.InfererConfig( + model_info=self.model_info, + show_results=self.view_checkbox.isChecked(), + show_results_count=self.display_number_choice_slider.slider_value, + show_original=self.show_original_checkbox.isChecked(), + anisotropy_resolution=self.anisotropy_wdgt.resolution_xyz, + ) + if self.layer_choice.isChecked(): + self.config.show_results = True + self.config.show_results_count = 5 + self.config.show_original = False + self.log.print_and_log(f"Worker started at {utils.get_time()}") self.log.print_and_log(f"Saving results to : {self.results_path}") self.log.print_and_log("Worker is running...") @@ -727,25 +654,25 @@ def on_error(self): """Catches errors and tries to clean up. TODO : upgrade""" self.log.print_and_log("Worker errored...") self.log.print_and_log("Trying to clean up...") - self.btn_start.setText("Start on folder") + self.btn_start.setText("Start") self.btn_close.setVisible(True) self.worker = None + self.worker_config = None self.empty_cuda_cache() def on_finish(self): """Catches finished signal from worker, resets workspace for next run.""" self.log.print_and_log(f"\nWorker finished at {utils.get_time()}") self.log.print_and_log("*" * 20) - self.btn_start.setText("Start on folder") - self.btn_start_layer.setVisible(True) + self.btn_start.setText("Start") self.btn_close.setVisible(True) self.worker = None + self.worker_config = None self.empty_cuda_cache() - @staticmethod - def on_yield(data, widget): + def on_yield(self, result: InferenceResult): """ Displays the inference results in napari as long as data["image_id"] is lower than nbr_to_show, and updates the status report docked widget (namely the progress bar) @@ -757,28 +684,36 @@ def on_yield(data, widget): # viewer, progress, show_res, show_res_number, zoon, show_original # check that viewer checkbox is on and that max number of displays has not been reached. - image_id = data["image_id"] - model_name = data["model_name"] - total = len(widget.images_filepaths) + # widget.log.print_and_log(result) + + image_id = result.image_id + model_name = result.model_name + if self.worker_config.images_filepaths is not None: + total = len(self.worker_config.images_filepaths) + else: + total = 1 - viewer = widget._viewer + viewer = self._viewer pbar_value = image_id // total - if image_id == 0: + if pbar_value == 0: pbar_value = 1 - widget.progress.setValue(100 * pbar_value) + self.progress.setValue(100 * pbar_value) - if widget.show_res and image_id <= widget.show_res_nbr: + if ( + self.config.show_results + and image_id <= self.config.show_results_count + ): - zoom = widget.zoom + zoom = self.worker_config.post_process_config.zoom.zoom_values viewer.dims.ndisplay = 3 viewer.scale_bar.visible = True - if widget.show_original and data["original"] is not None: + if self.config.show_original and result.original is not None: original_layer = viewer.add_image( - data["original"], + result.original, colormap="inferno", name=f"original_{image_id}", scale=zoom, @@ -786,37 +721,46 @@ def on_yield(data, widget): ) out_colormap = "twilight" - if widget.transforms["thresh"][0]: + if self.worker_config.post_process_config.thresholding.enabled: out_colormap = "turbo" out_layer = viewer.add_image( - data["result"], + result.result, colormap=out_colormap, name=f"pred_{image_id}_{model_name}", opacity=0.8, ) - if data["instance_labels"] is not None: + if result.instance_labels is not None: + + labels = result.instance_labels + method = self.worker_config.post_process_config.instance.method - labels = data["instance_labels"] - method = widget.instance_params["method"] - number_cells = np.amax(labels) + number_cells = ( + np.unique(labels.flatten()).size - 1 + ) # remove background name = f"({number_cells} objects)_{method}_instance_labels_{image_id}" instance_layer = viewer.add_labels(labels, name=name) - data_dict = data["object stats"] + stats = result.stats - if widget.stats_to_csv and data_dict is not None: + if self.worker_config.compute_stats and stats is not None: - numeric_data = pd.DataFrame(data_dict) + stats_dict = stats.get_dict() + stats_df = pd.DataFrame(stats_dict) + + self.log.print_and_log( + f"Number of instances : {stats.number_objects}" + ) csv_name = f"/{method}_seg_results_{image_id}_{utils.get_date_time()}.csv" - numeric_data.to_csv( - widget.results_path + csv_name, index=False + stats_df.to_csv( + self.worker_config.results_path + csv_name, + index=False, ) - # widget.log.print_and_log( - # f"\nNUMBER OF CELLS : {number_cells}\n" + # self.log.print_and_log( + # f"OBJECTS DETECTED : {number_cells}\n" # ) diff --git a/napari_cellseg3d/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py similarity index 58% rename from napari_cellseg3d/plugin_model_training.py rename to napari_cellseg3d/code_plugins/plugin_model_training.py index 1e9deed6..ac8aefc3 100644 --- a/napari_cellseg3d/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1,6 +1,6 @@ -import os import shutil import warnings +from functools import partial from pathlib import Path import matplotlib.pyplot as plt @@ -25,16 +25,20 @@ from qtpy.QtWidgets import QSizePolicy # local +from napari_cellseg3d import config from napari_cellseg3d import interface as ui from napari_cellseg3d import utils -from napari_cellseg3d.model_framework import ModelFramework -from napari_cellseg3d.model_workers import TrainingWorker +from napari_cellseg3d.code_models.model_framework import ModelFramework +from napari_cellseg3d.code_models.model_workers import TrainingReport +from napari_cellseg3d.code_models.model_workers import TrainingWorker NUMBER_TABS = 3 DEFAULT_PATCH_SIZE = 64 +logger = utils.LOGGER -class Trainer(ModelFramework): + +class Trainer(ModelFramework, metaclass=ui.QWidgetSingleton): """A plugin to train pre-defined PyTorch models for one-channel segmentation directly in napari. Features parameter selection for training, dynamic loss plotting and automatic saving of the best weights during training through validation.""" @@ -42,15 +46,6 @@ class Trainer(ModelFramework): def __init__( self, viewer: "napari.viewer.Viewer", - data_path="", - label_path="", - results_path="", - model_index=0, - loss_index=0, - epochs=5, - samples=2, - batch=1, - val_interval=1, ): """Creates a Trainer tab widget with the following functionalities : @@ -90,8 +85,6 @@ def __init__( * A choice of using random or deterministic training TODO training plugin: - - * Custom model loading @@ -123,37 +116,24 @@ def __init__( # self.master = parent self._viewer = viewer """napari.viewer.Viewer: viewer in which the widget is displayed""" + self.enable_utils_menu() - self.data_path = "" - self.label_path = "" - self.results_path = "" - self.results_path_folder = "" + self.data_path = None + self.label_path = None + self.results_path = None """Path to the folder inside the results path that contains all results""" - self.save_as_zip = False - """Whether to zip results folder once done. Creates a zipped copy of the results folder.""" - - # recover default values - self.num_samples = samples - """Number of samples to extract""" - self.batch_size = batch - """Batch size""" - self.max_epochs = epochs - """Epochs""" - self.val_interval = val_interval - """At which epochs to perform validation. E.g. if 2, will run validation on epochs 2,4,6...""" - self.patch_size = [] - """The size of samples to be extracted from images""" - self.learning_rate = 1e-3 + self.config = config.TrainerConfig() self.model = None # TODO : custom model loading ? self.worker = None """Training worker for multithreading, should be a TrainingWorker instance from :doc:model_workers.py""" + self.worker_config = None self.data = None """Data dictionary containing file paths""" self.stop_requested = False """Whether the worker should stop or not""" - self.start_time = "" + self.start_time = None self.loss_dict = { "Dice loss": DiceLoss(sigmoid=True), @@ -174,52 +154,66 @@ def __init__( """Plot for dice metric""" self.plot_dock = None """Docked widget with plots""" + self.result_layers = [] + """Layers to display checkpoint""" self.df = None self.loss_values = [] self.validation_values = [] - self.model_choice.setCurrentIndex(model_index) + # self.model_choice.setCurrentIndex(0) ################################ # interface + default = config.TrainingWorkerConfig() - self.zip_choice = ui.make_checkbox("Compress results") + self.zip_choice = ui.CheckBox("Compress results") - self.validation_percent_choice = ui.IntIncrementCounter( - 10, 90, default=80, step=1, parent=self + self.validation_percent_choice = ui.Slider( + lower=10, + upper=90, + default=default.validation_percent * 100, + step=5, + parent=self, ) self.epoch_choice = ui.IntIncrementCounter( - min=2, max=1000, default=self.max_epochs + lower=2, + upper=200, + default=default.max_epochs, + label="Number of epochs : ", ) - self.lbl_epoch_choice = ui.make_label("Number of epochs : ", self) self.loss_choice = ui.DropdownMenu( sorted(self.loss_dict.keys()), label="Loss function" ) self.lbl_loss_choice = self.loss_choice.label - self.loss_choice.setCurrentIndex(loss_index) + self.loss_choice.setCurrentIndex(0) - self.sample_choice = ui.IntIncrementCounter( - min=2, max=50, default=self.num_samples - ) - self.lbl_sample_choice = ui.make_label( - "Number of patches per image : ", self + self.sample_choice_slider = ui.Slider( + lower=2, + upper=50, + default=default.num_samples, + text_label="Number of patches per image : ", ) - self.sample_choice.setVisible(False) - self.lbl_sample_choice.setVisible(False) - self.batch_choice = ui.IntIncrementCounter( - min=1, max=10, default=self.batch_size + self.sample_choice_slider.container.setVisible(False) + + self.batch_choice = ui.Slider( + lower=1, + upper=10, + default=default.batch_size, + text_label="Batch size : ", ) - self.lbl_batch_choice = ui.make_label("Batch size : ", self) self.val_interval_choice = ui.IntIncrementCounter( - default=self.val_interval + default=default.validation_interval, + label="Validation interval : ", ) - self.lbl_val_interv_choice = ui.make_label( - "Validation interval : ", self + + self.epoch_choice.valueChanged.connect(self._update_validation_choice) + self.val_interval_choice.valueChanged.connect( + self._update_validation_choice ) learning_rate_vals = [ @@ -237,10 +231,10 @@ def __init__( self.learning_rate_choice.setCurrentIndex(1) - self.augment_choice = ui.make_checkbox("Augment data") + self.augment_choice = ui.CheckBox("Augment data") self.close_buttons = [ - self.make_close_button() for i in range(NUMBER_TABS) + self._make_close_button() for i in range(NUMBER_TABS) ] """Close buttons list for each tab""" @@ -256,21 +250,23 @@ def __init__( w.setVisible(False) for l in self.patch_size_lbl: l.setVisible(False) - self.sampling_container, l = ui.make_container() + self.sampling_container = ui.ContainerWidget() - self.patch_choice = ui.make_checkbox( - "Extract patches from images", func=self.toggle_patch_dims + self.patch_choice = ui.CheckBox( + "Extract patches from images", func=self._toggle_patch_dims ) - self.patch_choice.clicked.connect(self.toggle_patch_dims) + self.patch_choice.clicked.connect(self._toggle_patch_dims) - self.use_transfer_choice = ui.make_checkbox( - "Transfer weights", self.toggle_transfer_param + self.use_transfer_choice = ui.CheckBox( + "Transfer weights", self._toggle_transfer_param ) - self.use_deterministic_choice = ui.make_checkbox( - "Deterministic training", func=self.toggle_deterministic_param + self.use_deterministic_choice = ui.CheckBox( + "Deterministic training", func=self._toggle_deterministic_param + ) + self.box_seed = ui.IntIncrementCounter( + upper=10000000, default=default.deterministic_config.seed ) - self.box_seed = ui.IntIncrementCounter(max=10000000, default=23498) self.lbl_seed = ui.make_label("Seed", self) self.container_seed = ui.combine_blocks( self.box_seed, self.lbl_seed, horizontal=False @@ -281,84 +277,105 @@ def __init__( self.btn_start = ui.Button("Start training", self.start) - self.btn_model_path.setVisible(False) - self.lbl_model_path.setVisible(False) + # self.btn_model_path.setVisible(False) + # self.lbl_model_path.setVisible(False) ############################ ############################ - # tooltips - self.zip_choice.setToolTip( - "Checking this will save a copy of the results as a zip folder" - ) - self.validation_percent_choice.setToolTip( - "Choose the proportion of images to retain for training.\nThe remaining images will be used for validation" - ) - self.epoch_choice.setToolTip( - "The number of epochs to train for.\nThe more you train, the better the model will fit the training data" - ) - self.loss_choice.setToolTip( - "The loss function to use for training.\nSee the list in the inference guide for more info" - ) - self.sample_choice.setToolTip( - "The number of samples to extract per image" - ) - self.batch_choice.setToolTip( - "The batch size to use for training.\n A larger value will feed more images per iteration to the model,\n" - " which is faster and possibly improves performance, but uses more memory" - ) - self.val_interval_choice.setToolTip( - "The number of epochs to perform before validating data.\n " - "The lower the value, the more often the score of the model will be computed and the more often the weights will be saved." - ) - self.learning_rate_choice.setToolTip( - "The learning rate to use in the optimizer. \nUse a lower value if you're using pre-trained weights" - ) - self.augment_choice.setToolTip( - "Check this to enable data augmentation, which will randomly deform, flip and shift the intensity in images" - " to provide a more general dataset. \nUse this if you're extracting more than 10 samples per image" - ) - [ - w.setToolTip("Size of the sample to extract") - for w in self.patch_size_widgets - ] - self.patch_choice.setToolTip( - "Check this to automatically crop your images in smaller, cubic images for training." - "\nShould be used if you have a small dataset (and large images)" - ) - self.use_deterministic_choice.setToolTip( - "Enable deterministic training for reproducibility." - "Using the same seed with all other parameters being similar should yield the exact same results between two runs." - ) - self.use_transfer_choice.setToolTip( - "Use this you want to initialize the model with pre-trained weights or use your own weights." - ) - self.box_seed.setToolTip("Seed to use for RNG") + def set_tooltips(): + # tooltips + self.zip_choice.setToolTip( + "Checking this will save a copy of the results as a zip folder" + ) + self.validation_percent_choice.tooltips = "Choose the proportion of images to retain for training.\nThe remaining images will be used for validation" + self.epoch_choice.tooltips = "The number of epochs to train for.\nThe more you train, the better the model will fit the training data" + self.loss_choice.setToolTip( + "The loss function to use for training.\nSee the list in the inference guide for more info" + ) + self.sample_choice_slider.tooltips = ( + "The number of samples to extract per image" + ) + self.batch_choice.tooltips = ( + "The batch size to use for training.\n A larger value will feed more images per iteration to the model,\n" + " which is faster and possibly improves performance, but uses more memory" + ) + self.val_interval_choice.tooltips = ( + "The number of epochs to perform before validating data.\n " + "The lower the value, the more often the score of the model will be computed and the more often the weights will be saved." + ) + self.learning_rate_choice.setToolTip( + "The learning rate to use in the optimizer. \nUse a lower value if you're using pre-trained weights" + ) + self.augment_choice.setToolTip( + "Check this to enable data augmentation, which will randomly deform, flip and shift the intensity in images" + " to provide a more general dataset. \nUse this if you're extracting more than 10 samples per image" + ) + [ + w.setToolTip("Size of the sample to extract") + for w in self.patch_size_widgets + ] + self.patch_choice.setToolTip( + "Check this to automatically crop your images in smaller, cubic images for training." + "\nShould be used if you have a small dataset (and large images)" + ) + self.use_deterministic_choice.setToolTip( + "Enable deterministic training for reproducibility." + "Using the same seed with all other parameters being similar should yield the exact same results between two runs." + ) + self.use_transfer_choice.setToolTip( + "Use this you want to initialize the model with pre-trained weights or use your own weights." + ) + self.box_seed.setToolTip("Seed to use for RNG") + ############################ ############################ + set_tooltips() + self._build() - self.build() + def _hide_unused(self): + [ + self._hide_io_element(w) + for w in [ + self.layer_choice, + self.folder_choice, + self.label_layer_loader, + self.image_layer_loader, + ] + ] - def toggle_patch_dims(self): + def _update_validation_choice(self): + validation = self.val_interval_choice + max_epoch = self.epoch_choice.value() + + if validation.value() > max_epoch: + validation.setValue(max_epoch) + validation.setMaximum(max_epoch) + elif validation.maximum() < max_epoch: + validation.setMaximum(max_epoch) + + def get_loss(self, key): + """Getter for loss function selected by user""" + return self.loss_dict[key] + + def _toggle_patch_dims(self): if self.patch_choice.isChecked(): [w.setVisible(True) for w in self.patch_size_widgets] [l.setVisible(True) for l in self.patch_size_lbl] - self.sample_choice.setVisible(True) - self.lbl_sample_choice.setVisible(True) + self.sample_choice_slider.container.setVisible(True) self.sampling_container.setVisible(True) else: [w.setVisible(False) for w in self.patch_size_widgets] [l.setVisible(False) for l in self.patch_size_lbl] - self.sample_choice.setVisible(False) - self.lbl_sample_choice.setVisible(False) + self.sample_choice_slider.container.setVisible(False) self.sampling_container.setVisible(False) - def toggle_transfer_param(self): + def _toggle_transfer_param(self): if self.use_transfer_choice.isChecked(): self.custom_weights_choice.setVisible(True) else: self.custom_weights_choice.setVisible(False) - def toggle_deterministic_param(self): + def _toggle_deterministic_param(self): if self.use_deterministic_choice.isChecked(): self.container_seed.setVisible(True) else: @@ -370,19 +387,19 @@ def check_ready(self): Returns: - * True if paths are set correctly (!=[""]) + * True if paths are set correctly * False and displays a warning if not """ - if self.images_filepaths != [""] and self.labels_filepaths != [""]: + if self.images_filepaths != [] and self.labels_filepaths != []: return True else: warnings.formatwarning = utils.format_Warning warnings.warn("Image and label paths are not correctly set") return False - def build(self): + def _build(self): """Builds the layout of the widget and creates the following tabs and prompts: * Model parameters : @@ -417,12 +434,15 @@ def build(self): * Start (see :py:func:`~start`)""" + # for w in self.children(): + # w.setToolTip(f"{w}") + self.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.MinimumExpanding) ######## ################ ######################## # first tab : model and dataset choices - data_tab, data_tab_layout = ui.make_container() + data_tab = ui.ContainerWidget() ################ # first group : Data data_group, data_layout = ui.make_group("Data") @@ -431,36 +451,38 @@ def build(self): data_layout, [ ui.combine_blocks( - self.filetype_choice, self.lbl_filetype + self.filetype_choice, self.filetype_choice.label ), # file extension - ui.combine_blocks( - self.btn_image_files, self.lbl_image_files - ), # volumes - ui.combine_blocks( - self.btn_label_files, self.lbl_label_files - ), # labels - ui.combine_blocks( - self.btn_result_path, self.lbl_result_path - ), # results folder - # ui.combine_blocks(self.model_choice, self.lbl_model_choice), # model choice # TODO : add custom model choice + self.image_filewidget, + self.labels_filewidget, + self.results_filewidget, + # ui.combine_blocks(self.model_choice, self.model_choice.label), # model choice + # TODO : add custom model choice self.zip_choice, # save as zip ], ) - if self.data_path != "": - self.lbl_image_files.setText(self.data_path) + for w in [ + self.image_filewidget, + self.labels_filewidget, + self.results_filewidget, + ]: + w.check_ready() - if self.label_path != "": - self.lbl_label_files.setText(self.label_path) + if self.data_path is not None: + self.image_filewidget.text_field.setText(self.data_path) - if self.results_path != "": - self.lbl_result_path.setText(self.results_path) + if self.label_path is not None: + self.labels_filewidget.text_field.setText(self.label_path) + + if self.results_path is not None: + self.results_filewidget.text_field.setText(self.results_path) data_group.setLayout(data_layout) - data_tab_layout.addWidget(data_group, alignment=ui.LEFT_AL) + data_tab.layout.addWidget(data_group, alignment=ui.LEFT_AL) # end of first group : Data ################ - ui.add_blank(widget=data_tab, layout=data_tab_layout) + ui.add_blank(widget=data_tab, layout=data_tab.layout) ################ transfer_group_w, transfer_group_l = ui.make_group("Transfer learning") @@ -469,30 +491,30 @@ def build(self): [ self.use_transfer_choice, self.custom_weights_choice, - self.weights_path_container, + self.weights_filewidget, ], ) - self.custom_weights_choice.setVisible(False) + self.weights_filewidget.setVisible(False) transfer_group_w.setLayout(transfer_group_l) - data_tab_layout.addWidget(transfer_group_w, alignment=ui.LEFT_AL) + data_tab.layout.addWidget(transfer_group_w, alignment=ui.LEFT_AL) ################ - ui.add_blank(self, data_tab_layout) + ui.add_blank(self, data_tab.layout) ################ - ui.add_to_group( + ui.GroupedWidget.create_single_widget_group( "Validation (%)", - self.validation_percent_choice, - data_tab_layout, + self.validation_percent_choice.container, + data_tab.layout, ) ################ - ui.add_blank(self, data_tab_layout) + ui.add_blank(self, data_tab.layout) ################ # buttons ui.add_widgets( - data_tab_layout, + data_tab.layout, [ - self.make_next_button(), # next + self._make_next_button(), # next ui.add_blank(self), self.close_buttons[0], # close ], @@ -504,11 +526,13 @@ def build(self): ###### ############ ################## - augment_tab_w, augment_tab_l = ui.make_container() + augment_tab_w = ui.ContainerWidget() + augment_tab_l = augment_tab_w.layout ################## # extract patches or not - patch_size_w, patch_size_l = ui.make_container() + patch_size_w = ui.ContainerWidget() + patch_size_l = patch_size_w.layout [ patch_size_l.addWidget(widget, alignment=ui.LEFT_AL) for widgts in zip(self.patch_size_lbl, self.patch_size_widgets) @@ -516,13 +540,13 @@ def build(self): ] # patch sizes patch_size_w.setLayout(patch_size_l) - sampling_w, sampling_l = ui.make_container() + sampling_w = ui.ContainerWidget() + sampling_l = sampling_w.layout ui.add_widgets( sampling_l, [ - self.lbl_sample_choice, - self.sample_choice, # number of samples + self.sample_choice_slider.container, # number of samples ], ) sampling_w.setLayout(sampling_l) @@ -537,13 +561,15 @@ def build(self): right_or_below=self.sampling_container, horizontal=False, ) - ui.add_to_group("Sampling", sampling, augment_tab_l, B=0, T=11) + ui.GroupedWidget.create_single_widget_group( + "Sampling", sampling, augment_tab_l, b=0, t=11 + ) ####################### ####################### ui.add_blank(augment_tab_w, augment_tab_l) ####################### ####################### - ui.add_to_group( + ui.GroupedWidget.create_single_widget_group( "Augmentation", self.augment_choice, augment_tab_l, @@ -558,8 +584,8 @@ def build(self): ####################### augment_tab_l.addWidget( ui.combine_blocks( - left_or_above=self.make_prev_button(), - right_or_below=self.make_next_button(), + left_or_above=self._make_prev_button(), + right_or_below=self._make_next_button(), l=1, ), alignment=ui.LEFT_AL, @@ -573,35 +599,34 @@ def build(self): ###### ############ ################## - train_tab, train_tab_layout = ui.make_container() + train_tab = ui.ContainerWidget() ################## # solo groups for loss and model - ui.add_blank(train_tab, train_tab_layout) + ui.add_blank(train_tab, train_tab.layout) - ui.add_to_group( + ui.GroupedWidget.create_single_widget_group( "Model", self.model_choice, - train_tab_layout, + train_tab.layout, ) # model choice - self.lbl_model_choice.setVisible(False) - - ui.add_blank(train_tab, train_tab_layout) + self.model_choice.label.setVisible(False) - ui.add_to_group( + ui.add_blank(train_tab, train_tab.layout) + ui.GroupedWidget.create_single_widget_group( "Loss", self.loss_choice, - train_tab_layout, + train_tab.layout, ) # loss choice self.lbl_loss_choice.setVisible(False) # end of solo groups for loss and model ################## - ui.add_blank(train_tab, train_tab_layout) + ui.add_blank(train_tab, train_tab.layout) ################## # training params group train_param_group_w, train_param_group_l = ui.make_group( - "Training parameters", R=1, B=5, T=11 + "Training parameters", r=1, b=5, t=11 ) spacing = 20 @@ -609,16 +634,7 @@ def build(self): ui.add_widgets( train_param_group_l, [ - ui.combine_blocks( - self.batch_choice, - self.lbl_batch_choice, - min_spacing=spacing, - horizontal=False, - l=5, - t=5, - r=5, - b=5, - ), # batch size + self.batch_choice.container, # batch size ui.combine_blocks( self.learning_rate_choice, self.lbl_learning_rate_choice, @@ -629,39 +645,23 @@ def build(self): r=5, b=5, ), # learning rate - ui.combine_blocks( - self.epoch_choice, - self.lbl_epoch_choice, - min_spacing=spacing, - horizontal=False, - l=5, - t=5, - r=5, - b=5, - ), # epochs - ui.combine_blocks( - self.val_interval_choice, - self.lbl_val_interv_choice, - min_spacing=spacing, - horizontal=False, - l=5, - t=5, - r=5, - b=5, - ), # validation interval + self.epoch_choice.label, # epochs + self.epoch_choice, + self.val_interval_choice.label, + self.val_interval_choice, # validation interval ], None, ) train_param_group_w.setLayout(train_param_group_l) - train_tab_layout.addWidget(train_param_group_w) + train_tab.layout.addWidget(train_param_group_w) # end of training params group ################## - ui.add_blank(self, train_tab_layout) + ui.add_blank(self, train_tab.layout) ################## # deterministic choice group seed_w, seed_l = ui.make_group( - "Deterministic training", R=1, B=5, T=11 + "Deterministic training", r=1, b=5, t=11 ) ui.add_widgets( seed_l, @@ -672,18 +672,18 @@ def build(self): self.container_seed.setVisible(False) seed_w.setLayout(seed_l) - train_tab_layout.addWidget(seed_w) + train_tab.layout.addWidget(seed_w) # end of deterministic choice group ################## # buttons - ui.add_blank(self, train_tab_layout) + ui.add_blank(self, train_tab.layout) ui.add_widgets( - train_tab_layout, + train_tab.layout, [ - self.make_prev_button(), # previous + self._make_prev_button(), # previous self.btn_start, # start ui.add_blank(self), self.close_buttons[2], @@ -695,7 +695,7 @@ def build(self): # end of tab layouts ui.ScrollArea.make_scrollable( - contained_layout=data_tab_layout, + contained_layout=data_tab.layout, parent=data_tab, min_wh=[200, 300], ) # , max_wh=[200,1000]) @@ -707,7 +707,7 @@ def build(self): ) ui.ScrollArea.make_scrollable( - contained_layout=train_tab_layout, + contained_layout=train_tab.layout, parent=train_tab, min_wh=[200, 300], ) @@ -716,25 +716,27 @@ def build(self): self.addTab(train_tab, "Training") self.setMinimumSize(220, 100) - def show_dialog_lab(self): - """Shows the dialog to load label files in a path, loads them (see :doc:model_framework) and changes the widget - label :py:attr:`self.lbl_label` accordingly""" - f_name = ui.open_file_dialog(self, self._default_path) - - if f_name: - self.label_path = f_name - self.lbl_label.setText(self.label_path) + self._hide_unused() - def show_dialog_dat(self): - """Shows the dialog to load images files in a path, loads them (see :doc:model_framework) and changes the - widget label :py:attr:`self.lbl_dat` accordingly""" - f_name = ui.open_file_dialog(self, self._default_path) - - if f_name: - self.data_path = f_name - self.lbl_dat.setText(self.data_path) + default_results_path = ( + config.TrainingWorkerConfig().results_path_folder + ) + self.results_filewidget.text_field.setText(default_results_path) + self.results_filewidget.check_ready() + self._check_results_path(default_results_path) + self.results_path = default_results_path + + # def _show_dialog_lab(self): + # """Shows the dialog to load label files in a path, loads them (see :doc:model_framework) and changes the widget + # label :py:attr:`self.label_filewidget.text_field` accordingly""" + # folder = ui.open_folder_dialog(self, self._default_path) + # + # if folder: + # self.label_path = folder + # self.labels_filewidget.text_field.setText(self.label_path) def send_log(self, text): + """Sends a message via the Log attribute""" self.log.print_and_log(text) def start(self): @@ -760,7 +762,6 @@ def start(self): """ self.start_time = utils.get_time_filepath() - self.save_as_zip = self.zip_choice.isChecked() if self.stop_requested: self.log.print_and_log("Worker is already stopping !") @@ -782,76 +783,72 @@ def start(self): self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - self.reset_loss_plot() - self.num_samples = self.sample_choice.value() - self.batch_size = self.batch_choice.value() - self.val_interval = self.val_interval_choice.value() + self._reset_loss_plot() + try: self.data = self.create_train_dataset_dict() except ValueError as err: self.data = None raise err - self.max_epochs = self.epoch_choice.value() - - validation_percent = self.validation_percent_choice.value() / 100 - print(f"val % : {validation_percent}") + model_config = config.ModelInfo( + name=self.model_choice.currentText() + ) - self.learning_rate = float(self.learning_rate_choice.currentText()) + self.weights_config.path = self.weights_config.path + self.weights_config.custom = self.custom_weights_choice.isChecked() + self.weights_config.use_pretrained = ( + not self.use_transfer_choice.isChecked() + ) - seed_dict = { - "use deterministic": self.use_deterministic_choice.isChecked(), - "seed": self.box_seed.value(), - } + deterministic_config = config.DeterministicConfig( + enabled=self.use_deterministic_choice.isChecked(), + seed=self.box_seed.value(), + ) - self.patch_size = [] - [ - self.patch_size.append(w.value()) - for w in self.patch_size_widgets - ] + validation_percent = ( + self.validation_percent_choice.slider_value / 100 + ) - model_dict = { - "class": self.get_model(self.model_choice.currentText()), - "name": self.model_choice.currentText(), - } - self.results_path_folder = ( + results_path_folder = Path( self.results_path - + f"/{model_dict['name']}_{utils.get_date_time()}" + + f"/{model_config.name}_{utils.get_date_time()}" ) - os.makedirs( - self.results_path_folder, exist_ok=False + Path(results_path_folder).mkdir( + parents=True, exist_ok=False ) # avoid overwrite where possible - if self.use_transfer_choice.isChecked(): - if self.custom_weights_choice.isChecked(): - weights_path = self.weights_path - else: - weights_path = "use_pretrained" - else: - weights_path = None - - self.log.print_and_log( - f"Saving results to : {self.results_path_folder}" - ) + patch_size = [w.value() for w in self.patch_size_widgets] - self.worker = TrainingWorker( + logger.debug("Loading config...") + self.worker_config = config.TrainingWorkerConfig( device=self.get_device(), - model_dict=model_dict, - weights_path=weights_path, - data_dicts=self.data, + model_info=model_config, + weights_info=self.weights_config, + train_data_dict=self.data, validation_percent=validation_percent, - max_epochs=self.max_epochs, + max_epochs=self.epoch_choice.value(), loss_function=self.get_loss(self.loss_choice.currentText()), - learning_rate=self.learning_rate, - val_interval=self.val_interval, - batch_size=self.batch_size, - results_path=self.results_path_folder, + learning_rate=float(self.learning_rate_choice.currentText()), + validation_interval=self.val_interval_choice.value(), + batch_size=self.batch_choice.slider_value, + results_path_folder=str(results_path_folder), sampling=self.patch_choice.isChecked(), - num_samples=self.num_samples, - sample_size=self.patch_size, + num_samples=self.sample_choice_slider.slider_value, + sample_size=patch_size, do_augmentation=self.augment_choice.isChecked(), - deterministic=seed_dict, + deterministic_config=deterministic_config, + ) # TODO(cyril) continue to put params in config + + self.config = config.TrainerConfig( + save_as_zip=self.zip_choice.isChecked() + ) + + self.log.print_and_log( + f"Saving results to : {results_path_folder}" ) + + self.worker = TrainingWorker(config=self.worker_config) self.worker.set_download_log(self.log) [btn.setVisible(False) for btn in self.close_buttons] @@ -861,9 +858,7 @@ def start(self): self.worker.started.connect(self.on_start) - self.worker.yielded.connect( - lambda data: self.on_yield(data, widget=self) - ) + self.worker.yielded.connect(partial(self.on_yield)) self.worker.finished.connect(self.on_finish) self.worker.errored.connect(self.on_error) @@ -895,38 +890,43 @@ def on_finish(self): self.log.print_and_log("*" * 20) self.log.print_and_log(f"\nWorker finished at {utils.get_time()}") - self.log.print_and_log(f"Saving in {self.results_path_folder}") + self.log.print_and_log( + f"Saving in {self.worker_config.results_path_folder}" + ) self.log.print_and_log(f"Saving last loss plot") + plot_name = self.worker_config.results_path_folder / Path( + f"final_metric_plots_{utils.get_time_filepath()}.png" + ) if self.canvas is not None: self.canvas.figure.savefig( - ( - self.results_path_folder - + f"/final_metric_plots_{utils.get_time_filepath()}.png" - ), + plot_name, format="png", ) self.log.print_and_log("Saving log") - self.save_log_to_path(self.results_path_folder) + self.save_log_to_path(self.worker_config.results_path_folder) self.log.print_and_log("Done") self.log.print_and_log("*" * 10) - self.make_csv() + self._make_csv() self.btn_start.setText("Start") [btn.setVisible(True) for btn in self.close_buttons] - del self.worker - self.worker = None - self.empty_cuda_cache() + # del self.worker - if self.save_as_zip: + # self.empty_cuda_cache() + + if self.config.save_as_zip: shutil.make_archive( - self.results_path_folder, "zip", self.results_path_folder + self.worker_config.results_path_folder, + "zip", + self.worker_config.results_path_folder, ) + self.worker = None # if zipfile.is_zipfile(self.results_path_folder+".zip"): # if not shutil.rmtree.avoids_symlink_attacks: @@ -934,7 +934,7 @@ def on_finish(self): # shutil.rmtree(self.results_path_folder) - self.results_path_folder = "" + # self.results_path_folder = "" # self.clean_cache() # trying to fix memory leak @@ -942,37 +942,70 @@ def on_error(self): """Catches errored signal from worker""" self.log.print_and_log(f"WORKER ERRORED at {utils.get_time()}") self.worker = None - self.empty_cuda_cache() + # self.empty_cuda_cache() # self.clean_cache() - @staticmethod - def on_yield(data, widget): - # print( + def on_yield(self, report: TrainingReport): + # logger.info( # f"\nCatching results : for epoch {data['epoch']}, # loss is {data['losses']} and validation is {data['val_metrics']}" # ) - if data["plot"]: - widget.progress.setValue( - 100 * (data["epoch"] + 1) // widget.max_epochs + if report == TrainingReport(): + return + + if report.show_plot: + + try: + layer_name = "Training_checkpoint_" + rge = range(len(report.images)) + + self.log.print_and_log(len(report.images)) + + if report.epoch + 1 == self.worker_config.validation_interval: + for i in rge: + layer = self._viewer.add_image( + report.images[i], + name=layer_name + str(i), + colormap="twilight", + ) + self.result_layers.append(layer) + else: + for i in rge: + if layer_name + str(i) not in [ + layer.name for layer in self.result_layers + ]: + new_layer = self._viewer.add_image( + report.images[i], + name=layer_name + str(i), + colormap="twilight", + ) + self.result_layers.append(new_layer) + self.result_layers[i].data = report.images[i] + self.result_layers[i].refresh() + except Exception as e: + logger.error(e) + + self.progress.setValue( + 100 * (report.epoch + 1) // self.worker_config.max_epochs ) - widget.update_loss_plot(data["losses"], data["val_metrics"]) - widget.loss_values = data["losses"] - widget.validation_values = data["val_metrics"] + self.update_loss_plot(report.loss_values, report.validation_metric) + self.loss_values = report.loss_values + self.validation_values = report.validation_metric - if widget.stop_requested: - widget.log.print_and_log( + if self.stop_requested: + self.log.print_and_log( "Saving weights from aborted training in results folder" ) torch.save( - data["weights"], - os.path.join( - widget.results_path_folder, + report.weights, + Path(self.worker_config.results_path_folder) + / Path( f"latest_weights_aborted_training_{utils.get_time_filepath()}.pth", ), ) - widget.log.print_and_log("Saving complete") - widget.stop_requested = False + self.log.print_and_log("Saving complete") + self.stop_requested = False # def clean_cache(self): # """Attempts to clear memory after training""" @@ -988,9 +1021,9 @@ def on_yield(data, widget): # if self.get_device(show=False).type == "cuda": # self.empty_cuda_cache() - def make_csv(self): + def _make_csv(self): - size_column = range(1, self.max_epochs + 1) + size_column = range(1, self.worker_config.max_epochs + 1) if len(self.loss_values) == 0 or self.loss_values is None: warnings.warn("No loss values to add to csv !") @@ -1001,11 +1034,15 @@ def make_csv(self): "epoch": size_column, "loss": self.loss_values, "validation": utils.fill_list_in_between( - self.validation_values, self.val_interval - 1, "" + self.validation_values, + self.worker_config.validation_interval - 1, + "", )[: len(size_column)], } ) - path = os.path.join(self.results_path_folder, "training.csv") + path = Path(self.worker_config.results_path_folder) / Path( + "training.csv" + ) self.df.to_csv(path, index=False) def plot_loss(self, loss, dice_metric): @@ -1021,10 +1058,15 @@ def plot_loss(self, loss, dice_metric): # self.train_loss_plot.set_ylim(0, 1) # update metrics - x = [self.val_interval * (i + 1) for i in range(len(dice_metric))] + x = [ + self.worker_config.validation_interval * (i + 1) + for i in range(len(dice_metric)) + ] y = dice_metric - epoch_min = (np.argmax(y) + 1) * self.val_interval + epoch_min = ( + np.argmax(y) + 1 + ) * self.worker_config.validation_interval dice_min = np.max(y) self.dice_metric_plot.plot(x, y, zorder=1) @@ -1047,14 +1089,16 @@ def plot_loss(self, loss, dice_metric): ) self.canvas.draw_idle() - plot_path = self.results_path_folder + "/Loss_plots" - os.makedirs(plot_path, exist_ok=True) + plot_path = self.worker_config.results_path_folder / Path( + "../Loss_plots" + ) + Path(plot_path).mkdir(parents=True, exist_ok=True) if self.canvas is not None: self.canvas.figure.savefig( - ( + str( plot_path - + f"/checkpoint_metric_plots_{utils.get_date_time()}.png" + / f"checkpoint_metric_plots_{utils.get_date_time()}.png" ), format="png", ) @@ -1070,9 +1114,9 @@ def update_loss_plot(self, loss, metric): """ epoch = len(loss) - if epoch < self.val_interval * 2: + if epoch < self.worker_config.validation_interval * 2: return - elif epoch == self.val_interval * 2: + elif epoch == self.worker_config.validation_interval * 2: bckgrd_color = (0, 0, 0, 0) # '#262930' with plt.style.context("dark_background"): @@ -1103,9 +1147,17 @@ def update_loss_plot(self, loss, metric): # tab_index = self.addTab(self.canvas, "Loss plot") # self.setCurrentIndex(tab_index) - self.plot_dock = self._viewer.window.add_dock_widget( - self.canvas, name="Loss plots", area="bottom" - ) + try: + self.plot_dock = self._viewer.window.add_dock_widget( + self.canvas, name="Loss plots", area="bottom" + ) + self.plot_dock._close_btn = False + except AttributeError as e: + logger.error(e) + logger.error( + "Plot dock widget could not be added. Should occur in testing only" + ) + self.docked_widgets.append(self.plot_dock) self.plot_loss(loss, metric) else: @@ -1116,7 +1168,7 @@ def update_loss_plot(self, loss, metric): self.plot_loss(loss, metric) - def reset_loss_plot(self): + def _reset_loss_plot(self): if ( self.train_loss_plot is not None and self.dice_metric_plot is not None diff --git a/napari_cellseg3d/code_plugins/plugin_review.py b/napari_cellseg3d/code_plugins/plugin_review.py new file mode 100644 index 00000000..1ad84667 --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_review.py @@ -0,0 +1,475 @@ +import warnings +from pathlib import Path + +import matplotlib.pyplot as plt +import napari +import numpy as np +from magicgui import magicgui +from matplotlib.backends.backend_qt5agg import ( + FigureCanvasQTAgg as FigureCanvas, +) +from matplotlib.figure import Figure + +# Qt +from qtpy.QtWidgets import QLineEdit +from qtpy.QtWidgets import QSizePolicy +from tifffile import imwrite + +# local +from napari_cellseg3d import config +from napari_cellseg3d import interface as ui +from napari_cellseg3d import utils +from napari_cellseg3d.code_plugins.plugin_base import BasePluginSingleImage +from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager + +warnings.formatwarning = utils.format_Warning +logger = utils.LOGGER + + +class Reviewer(BasePluginSingleImage, metaclass=ui.QWidgetSingleton): + """A plugin for selecting volumes and labels file and launching the review process. + Inherits from : :doc:`plugin_base`""" + + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): + """Creates a Reviewer plugin with several buttons : + + * Open file prompt to select volumes directory + + * Open file prompt to select labels directory + + * A dropdown menu with a choice of png or tif filetypes + + * A checkbox if you want to create a new status csv for the dataset + + * A button to launch the review process + """ + + super().__init__( + viewer, + parent, + loads_images=True, + loads_labels=True, + has_results=True, + ) + + # self._viewer = viewer # should not be needed + self.config = config.ReviewConfig() + self.enable_utils_menu() + + ####################### + # UI + self.io_panel = self._build_io_panel() + + self.layer_choice.setText("New review") + self.folder_choice.setText("Existing review") + + self.csv_textbox = QLineEdit(self) + self.csv_textbox.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + self.new_csv_choice = ui.CheckBox("Create new dataset ?") + + self.btn_start = ui.Button("Start reviewing", self.run_review, self) + + self.lbl_mod = ui.make_label("Name", self) + + self.warn_label = ui.make_label( + "WARNING : You already have a review session running.\n" + "Launching another will close the current one,\n" + " make sure to save your work beforehand", + None, + ) + + self.anisotropy_widgets = ui.AnisotropyWidgets( + self, default_x=1.5, default_y=1.5, default_z=5 + ) + + ########################### + # tooltips + self.csv_textbox.setToolTip("Name of the csv results file") + self.new_csv_choice.setToolTip( + "Ignore any pre-existing csv with the specified name and create a new one" + ) + ########################### + + self._build() + + self.image_filewidget.text_field.textChanged.connect( + self._update_results_path + ) + print(f"{self}") + + def _update_results_path(self): + p = self.image_filewidget.text_field.text() + if p is not None and p != "" and Path(p).is_file(): + self.results_filewidget.text_field.setText(str(Path(p).parent)) + + def _build(self): + """Build buttons in a layout and add them to the napari Viewer""" + + self.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.MinimumExpanding) + + tab = ui.ContainerWidget(0, 0, 1, 1) + layout = tab.layout + + # ui.add_blank(self, layout) + ########################### + self.filetype_choice.setVisible(False) + layout.addWidget(self.io_panel) + self._set_io_visibility() + self.layer_choice.toggle() + ########################### + ui.add_blank(self, layout) + ########################### + ui.GroupedWidget.create_single_widget_group( + "Image parameters", self.anisotropy_widgets, layout + ) + ########################### + ui.add_blank(self, layout) + ########################### + csv_param_w, csv_param_l = ui.make_group("CSV parameters") + + ui.add_widgets( + csv_param_l, + [ + ui.combine_blocks( + self.csv_textbox, + self.lbl_mod, + horizontal=False, + l=5, + t=0, + r=5, + b=5, + ), + self.new_csv_choice, + self.results_filewidget, + ], + ) + + # self._hide_io_element(self.results_filewidget, self.folder_choice) + # self._show_io_element(self.results_filewidget) + + self.results_filewidget.text_field.setText( + str( + Path.home() / Path("cellseg3d/review") + ) # TODO(cyril) : check proper behaviour + ) + + csv_param_w.setLayout(csv_param_l) + layout.addWidget(csv_param_w) + ########################### + ui.add_blank(self, layout) + ########################### + + ui.add_widgets(layout, [self.btn_start, self._make_close_button()]) + + ui.ScrollArea.make_scrollable( + contained_layout=layout, parent=tab, min_wh=[190, 300] + ) + + self.addTab(tab, "Review") + + self.setMinimumSize(180, 100) + # self.show() + # self._viewer.window.add_dock_widget(self, name="Reviewer", area="right") + self.results_filewidget.check_ready() + self.results_path = self.results_filewidget.text_field.text() + + def check_image_data(self): + """Checks that images are present and that sizes match""" + + cfg = self.config + + if cfg.image is None: + raise ValueError("Review requires at least one image") + + if cfg.labels is not None: + if cfg.image.shape != cfg.labels.shape: + warnings.warn( + "Image and label dimensions do not match ! Please load matching images" + ) + + def _prepare_data(self): + + if self.layer_choice.isChecked(): + self.config.image = self.image_layer_loader.layer_data() + self.config.labels = self.label_layer_loader.layer_data() + else: + self.config.image = utils.load_images( + self.image_filewidget.text_field.text() + ) + self.config.labels = utils.load_images( + self.labels_filewidget.text_field.text() + ) + + self.check_image_data() + self._check_results_path(self.results_filewidget.text_field.text()) + + self.config.csv_path = self.results_filewidget.text_field.text() + self.config.model_name = self.csv_textbox.text() + + self.config.new_csv = self.new_csv_choice.isChecked() + self.config.filetype = self.filetype_choice.currentText() + + if self.anisotropy_widgets.enabled: + zoom = self.anisotropy_widgets.scaling_zyx() + else: + zoom = [1, 1, 1] + self.config.zoom_factor = zoom + + def run_review(self): + + """Launches review process by loading the files from the chosen folders, + and adds several widgets to the napari Viewer. + If the review process has been launched once before, + closes the window entirely and launches the review process in a fresh window. + + TODO: + + * Save work done before leaving + + See launch_review + + Returns: + napari.viewer.Viewer: self.viewer + """ + + print("New review session\n" + "*" * 20) + previous_viewer = self._viewer + try: + + self._prepare_data() + + self._viewer, self.docked_widgets = self.launch_review() + self._reset() + previous_viewer.close() + except ValueError as e: + warnings.warn( + f"An exception occurred : {e}. Please ensure you have entered all required parameters." + ) + + def _reset(self): + self.remove_docked_widgets() + + def launch_review(self): + """Launch the review process, loading the original image, the labels & the raw labels (from prediction) + in the viewer. + + Adds several widgets to the viewer : + + * A data manager widget that lets the user choose a directory to save the labels in, and a save button to quickly + save. + + * A "checked/not checked" button to let the user confirm that a slice has been checked or not. + + + **This writes in a csv file under the corresponding slice the slice status (i.e. checked or not checked) + to allow tracking of the review process for a given dataset.** + + * A plot widget that, when shift-clicking on the volume or label, + displays the chosen location on several projections (x-y, y-z, x-z), + to allow the user to have a better all-around view of the object + and determine whether it should be labeled or not. + + Returns : list of all docked widgets + """ + images_original = self.config.image + if self.config.labels is not None: + base_label = self.config.labels + else: + base_label = np.zeros_like(images_original) + + viewer = napari.Viewer() + + viewer.scale_bar.visible = True + + viewer.add_image( + images_original, + name="volume", + colormap="inferno", + contrast_limits=[200, 1000], + scale=self.config.zoom_factor, + ) # anything bigger than 255 will get mapped to 255... they did it like this because it must have rgb images + viewer.add_labels( + base_label, name="labels", seed=0.6, scale=self.config.zoom_factor + ) + + @magicgui( + dirname={"mode": "d", "label": "Save labels in... "}, + call_button="Save", + # call_button_2="Save & quit", + ) + def file_widget( + dirname=Path(self.config.csv_path), + ): # file name where to save annotations + # """Take a filename and do something with it.""" + # logger.debug(("The filename is:", dirname) + + dirname = Path(self.config.csv_path) + # def saver(): + out_dir = file_widget.dirname.value + + # logger.debug("The directory is:", out_dir) + + def quicksave(): + # if not self.config.as_stack: + if viewer.layers["labels"] is not None: + name = str(Path(out_dir) / "labels_reviewed.tif") + dat = viewer.layers["labels"].data + imwrite(name, data=dat) + + # else: + # if viewer.layers["labels"] is not None: + # dir_name = os.path.join(str(out_dir), "labels_reviewed") + # dat = viewer.layers["labels"].data + # utils.save_stack( + # dat, dir_name, filetype=self.config.filetype + # ) + + return dirname, quicksave() + + file_widget_dock = viewer.window.add_dock_widget( + file_widget, name=" ", area="bottom" + ) + file_widget_dock._close_btn = False + + with plt.style.context("dark_background"): + canvas = FigureCanvas(Figure(figsize=(3, 15))) + + xy_axes = canvas.figure.add_subplot(3, 1, 1) + canvas.figure.suptitle( + "Shift-click on image for plot \n", fontsize=8 + ) + xy_axes.imshow(np.zeros((100, 100), np.int16)) + xy_axes.scatter(50, 50, s=10, c="green", alpha=0.25) + xy_axes.set_xlabel("x axis") + xy_axes.set_ylabel("y axis") + yz_axes = canvas.figure.add_subplot(3, 1, 2) + yz_axes.imshow(np.zeros((100, 100), np.int16)) + yz_axes.scatter(50, 50, s=10, c="green", alpha=0.25) + yz_axes.set_xlabel("y axis") + yz_axes.set_ylabel("z axis") + zx_axes = canvas.figure.add_subplot(3, 1, 3) + zx_axes.imshow(np.zeros((100, 100), np.int16)) + zx_axes.scatter(50, 50, s=10, c="green", alpha=0.25) + zx_axes.set_xlabel("x axis") + zx_axes.set_ylabel("z axis") + + # canvas.figure.tight_layout() + canvas.figure.subplots_adjust( + left=0.1, bottom=0.1, right=1, top=0.95, wspace=0, hspace=0.4 + ) + + canvas.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Maximum) + + canvas_dock = viewer.window.add_dock_widget( + canvas, name=" ", area="right" + ) + canvas_dock._close_btn = False + + @viewer.mouse_drag_callbacks.append + def update_canvas_canvas(viewer, event): + + if "shift" in event.modifiers: + try: + cursor_position = np.round(viewer.cursor.position).astype( + int + ) + logger.debug(f"plot @ {cursor_position}") + + cropped_volume = crop_volume_around_point( + [ + cursor_position[0], + cursor_position[1], + cursor_position[2], + ], + viewer.layers["volume"], + self.config.zoom_factor, + ) + + ########## + ########## + # DEBUG + # viewer.add_image(cropped_volume, name="DEBUG_crop_plot") + + xy_axes.imshow( + cropped_volume[50], cmap="inferno", vmin=200, vmax=2000 + ) + yz_axes.imshow( + cropped_volume.transpose(1, 0, 2)[50], + cmap="inferno", + vmin=200, + vmax=2000, + ) + zx_axes.imshow( + cropped_volume.transpose(2, 0, 1)[50], + cmap="inferno", + vmin=200, + vmax=2000, + ) + canvas.draw_idle() + except Exception as e: + logger.error(e) + + # Qt widget defined in docker.py + dmg = Datamanager(parent=viewer) + dmg.prepare( + self.config.csv_path, + self.config.filetype, + self.config.model_name, + self.config.new_csv, + ) + datamananger = viewer.window.add_dock_widget( + dmg, name=" ", area="left" + ) + datamananger._close_btn = False + + def update_button(axis_event): + + slice_num = axis_event.value[0] + logger.debug(f"slice num is {slice_num}") + dmg.update_dm(slice_num) + + viewer.dims.events.current_step.connect(update_button) + + def crop_volume_around_point(points, layer, zoom_factor): + if zoom_factor != [1, 1, 1]: + data = np.array(layer.data, dtype=np.int16) + volume = utils.resize(data, zoom_factor) + # image = ndimage.zoom(layer.data, zoom_factor, mode="nearest") # cleaner but much slower... + else: + volume = layer.data + + min_coordinates = [point - 50 for point in points] + max_coordinates = [point + 50 for point in points] + inferior_bound = [ + min_coordinate if min_coordinate < 0 else 0 + for min_coordinate in min_coordinates + ] + superior_bound = [ + max_coordinate - volume.shape[i] + if volume.shape[i] < max_coordinate + else 0 + for i, max_coordinate in enumerate(max_coordinates) + ] + + crop_slice = tuple( + slice(np.maximum(0, min_coordinate), max_coordinate) + for min_coordinate, max_coordinate in zip( + min_coordinates, max_coordinates + ) + ) + + # if self.config.as_stack: + # crop_temp = volume[crop_slice].persist().compute() + # else: + crop_temp = volume[crop_slice] + + cropped_volume = np.zeros((100, 100, 100), np.int16) + cropped_volume[ + -inferior_bound[0] : 100 - superior_bound[0], + -inferior_bound[1] : 100 - superior_bound[1], + -inferior_bound[2] : 100 - superior_bound[2], + ] = crop_temp + return cropped_volume + + return viewer, [file_widget, canvas, dmg] diff --git a/napari_cellseg3d/plugin_dock.py b/napari_cellseg3d/code_plugins/plugin_review_dock.py similarity index 87% rename from napari_cellseg3d/plugin_dock.py rename to napari_cellseg3d/code_plugins/plugin_review_dock.py index eefae7d6..02a1c474 100644 --- a/napari_cellseg3d/plugin_dock.py +++ b/napari_cellseg3d/code_plugins/plugin_review_dock.py @@ -1,6 +1,6 @@ -import os import warnings -from datetime import datetime, timedelta +from datetime import datetime +from datetime import timedelta from pathlib import Path import napari @@ -50,10 +50,11 @@ def __init__(self, parent: "napari.viewer.Viewer"): self.time_label.setVisible(False) self.pause_box = ui.CheckBox( - "Pause", self.pause_timer, parent=self, fixed=True + "Pause timer", self.pause_timer, parent=self, fixed=True ) - io_panel, io_layout = ui.make_container() + io_panel = ui.ContainerWidget() + io_layout = io_panel.layout io_layout.addWidget(self.button, alignment=ui.ABS_AL) io_layout.addWidget( ui.combine_blocks( @@ -72,8 +73,8 @@ def __init__(self, parent: "napari.viewer.Viewer"): # self.setMaximumHeight(GUI_MAXIMUM_HEIGHT) # self.setMaximumWidth(GUI_MAXIMUM_WIDTH) - self.df = "" - self.csv_path = "" + self.df = None + self.csv_path = None self.slice_num = 0 self.filetype = "" self.filename = None @@ -117,7 +118,7 @@ def update_time_csv(self): self.df.at[0, "time"] = str_time self.df.to_csv(self.csv_path) - def prepare(self, label_dir, filetype, model_type, checkbox, as_folder): + def prepare(self, label_dir, filetype, model_type, checkbox): """Initialize the Datamanager, which loads the csv file and updates it with the index of the current slice. @@ -133,11 +134,11 @@ def prepare(self, label_dir, filetype, model_type, checkbox, as_folder): print("csv path try :") print(label_dir) self.filetype = filetype - self.as_folder = as_folder if not self.as_folder: - self.filename = os.path.split(label_dir)[1] - label_dir = os.path.split(label_dir)[0] + p = Path(label_dir) + self.filename = p.name + label_dir = p.parent print("Loading single image") print(self.filename) print(label_dir) @@ -146,7 +147,7 @@ def prepare(self, label_dir, filetype, model_type, checkbox, as_folder): print(self.csv_path, checkbox) # print(self.viewer.dims.current_step[0]) - self.update(self.viewer.dims.current_step[0]) + self.update_dm(self.viewer.dims.current_step[0]) def load_csv(self, label_dir, model_type, checkbox): """ @@ -164,7 +165,7 @@ def load_csv(self, label_dir, model_type, checkbox): # label_dir = os.path.dirname(label_dir) print("label dir") print(label_dir) - csvs = sorted(list(Path(label_dir).glob(f"{model_type}*.csv"))) + csvs = sorted(list(Path(str(label_dir)).glob(f"{model_type}*.csv"))) if len(csvs) == 0: df, csv_path = self.create_csv( label_dir, model_type @@ -176,10 +177,7 @@ def load_csv(self, label_dir, model_type, checkbox): csv_path = ( csv_path.split("_train")[0] + "_train" - + str( - int(os.path.splitext(csv_path.split("_train")[1])[0]) - + 1 - ) + + str(int(Path(csv_path.split("_train")[1]).parent) + 1) + ".csv" ) # adds 1 to current csv name number df.to_csv(csv_path) @@ -210,7 +208,9 @@ def create_csv(self, label_dir, model_type, filename=None): labels = sorted( list( path.name - for path in Path(label_dir).glob("./*" + self.filetype) + for path in Path(str(label_dir)).glob( + "./*" + self.filetype + ) ) ) else: @@ -230,7 +230,7 @@ def create_csv(self, label_dir, model_type, filename=None): ) df.at[0, "time"] = "00:00:00" - csv_path = os.path.join(label_dir, f"{model_type}_train0.csv") + csv_path = str(Path(label_dir) / Path(f"{model_type}_train0.csv")) print("csv path for create") print(csv_path) df.to_csv(csv_path) @@ -243,13 +243,12 @@ def update_button(self): self.df.at[self.df.index[self.slice_num], "train"] ) # puts button values at value of 1st csv item - def update(self, slice_num): + def update_dm(self, slice_num): """Updates the Datamanager with the index of the current slice, and updates the text with the status contained in the csv (e.g. checked/not checked). Args: slice_num (int): index of the current slice - """ self.slice_num = slice_num self.update_time_csv() @@ -280,22 +279,6 @@ def button_func(self): # updates csv every time you press button... self.df.at[self.df.index[self.slice_num], "train"] = "Not checked" self.df.to_csv(self.csv_path) - # def move_data(self): - # shutil.copy( - # self.df.at[self.df.index[self.slice_num], "filename"], - # self.train_data_dir, - # ) - # - # def delete_data(self): - # os.remove( - # os.path.join( - # self.train_data_dir, - # os.path.basename( - # self.df.at[self.df.index[self.slice_num], "filename"] - # ), - # ) - # ) - # # def check_all_data_and_mod(self): # for i in range(len(self.df)): # if self.df.at[self.df.index[i], "train"] == "Checked": diff --git a/napari_cellseg3d/code_plugins/plugin_utilities.py b/napari_cellseg3d/code_plugins/plugin_utilities.py new file mode 100644 index 00000000..6c726c25 --- /dev/null +++ b/napari_cellseg3d/code_plugins/plugin_utilities.py @@ -0,0 +1,120 @@ +import napari + +# Qt +from qtpy.QtCore import qInstallMessageHandler +from qtpy.QtWidgets import QLayout +from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QVBoxLayout +from qtpy.QtWidgets import QWidget + +# local +import napari_cellseg3d.interface as ui +from napari_cellseg3d.code_plugins.plugin_convert import AnisoUtils +from napari_cellseg3d.code_plugins.plugin_convert import RemoveSmallUtils +from napari_cellseg3d.code_plugins.plugin_convert import ThresholdUtils +from napari_cellseg3d.code_plugins.plugin_convert import ToInstanceUtils +from napari_cellseg3d.code_plugins.plugin_convert import ToSemanticUtils +from napari_cellseg3d.code_plugins.plugin_crop import Cropping + +UTILITIES_WIDGETS = { + "Crop": Cropping, + "Correct anisotropy": AnisoUtils, + "Remove small objects": RemoveSmallUtils, + "Convert to instance labels": ToInstanceUtils, + "Convert to semantic labels": ToSemanticUtils, + "Threshold": ThresholdUtils, +} + + +class Utilities(QWidget, metaclass=ui.QWidgetSingleton): + def __init__(self, viewer: "napari.viewer.Viewer"): + super().__init__() + self._viewer = viewer + + attr_names = ["crop", "aniso", "small", "inst", "sem", "thresh"] + self._create_utils_widgets(attr_names) + + # self.crop = Cropping(self._viewer) + # self.sem = ToSemanticUtils(self._viewer) + # self.aniso = AnisoUtils(self._viewer) + # self.inst = ToInstanceUtils(self._viewer) + # self.thresh = ThresholdUtils(self._viewer) + # self.small = RemoveSmallUtils(self._viewer) + + self.utils_choice = ui.DropdownMenu( + UTILITIES_WIDGETS.keys(), label="Utilities" + ) + + self._build() + + self.utils_choice.currentIndexChanged.connect(self._update_visibility) + # self._dock_util() + self._update_visibility() + qInstallMessageHandler(ui.handle_adjust_errors_wrapper(self)) + + def _build(self): + + layout = QVBoxLayout() + ui.add_widgets(layout, self.utils_widgets) + layout.addWidget(self.utils_choice.label, alignment=ui.BOTT_AL) + layout.addWidget(self.utils_choice, alignment=ui.BOTT_AL) + + layout.setSizeConstraint(QLayout.SetFixedSize) + self.setLayout(layout) + self.setMinimumHeight(1000) + self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) + self._update_visibility() + + def _create_utils_widgets(self, names): + for key, name in zip(UTILITIES_WIDGETS, names): + setattr(self, name, UTILITIES_WIDGETS[key](self._viewer)) + + self.utils_widgets = [] + for n in names: + wid = getattr(self, n) + self.utils_widgets.append(wid) + + if len(self.utils_widgets) != len(UTILITIES_WIDGETS.keys()): + raise RuntimeError( + "One or several utility widgets are missing/erroneous" + ) + # TODO how to auto-update list based on UTILITIES_WIDGETS ? + + def _update_visibility(self): + widget_class = UTILITIES_WIDGETS[self.utils_choice.currentText()] + # print("vis. updated") + # print(self.utils_widgets) + self._hide_all() + for i, w in enumerate(self.utils_widgets): + if isinstance(w, widget_class): + w.setVisible(True) + w.adjustSize() + # else: + # self.utils_widgets[i].setVisible(False) + + def _hide_all(self): + for w in self.utils_widgets: + w.setVisible(False) + # self.setWindowState(Qt.WindowMaximized) + # if self.parent() is not None: + # print(self.parent().windowState()) + # print(int(self.parent().parent().windowState())) + # self.parent().parent().showNormal() + # self.parent().parent().showMaximized() + # state = int(self.parent().parent().windowState()) + # if state == 0: + # self.parent().parent().adjustSize() + # elif state == 2: + # self.parent().parent().showNormal() + # self.parent().parent().showMaximized() + # pass + + # def _dock_util(self): + # for i in range(len(self.utils_widgets)): + # docked = self._viewer.window.add_dock_widget( + # widget=self.utils_widgets[i] + # ) + # self.docked_widgets.append(docked) + + # def remove_from_viewer(self): + # self._viewer.window.remove_dock_widget(self) diff --git a/napari_cellseg3d/config.py b/napari_cellseg3d/config.py new file mode 100644 index 00000000..eba01b07 --- /dev/null +++ b/napari_cellseg3d/config.py @@ -0,0 +1,221 @@ +import datetime +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import List +from typing import Optional + +import napari +import numpy as np + +from napari_cellseg3d.code_models.model_instance_seg import binary_connected +from napari_cellseg3d.code_models.model_instance_seg import binary_watershed + +# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP +from napari_cellseg3d.code_models.models import model_SegResNet as SegResNet +from napari_cellseg3d.code_models.models import model_SwinUNetR as SwinUNetR +from napari_cellseg3d.code_models.models import ( + model_TRAILMAP_MS as TRAILMAP_MS, +) +from napari_cellseg3d.code_models.models import model_VNet as VNet +from napari_cellseg3d.utils import LOGGER + +logger = LOGGER + +# TODO(cyril) DOCUMENT !!! and add default values +# TODO(cyril) add JSON load/save + +MODEL_LIST = { + "SegResNet": SegResNet, + "VNet": VNet, + # "TRAILMAP": TRAILMAP, + "TRAILMAP_MS": TRAILMAP_MS, + "SwinUNetR": SwinUNetR, + # "test" : DO NOT USE, reserved for testing +} + +INSTANCE_SEGMENTATION_METHOD_LIST = { + "Watershed": binary_watershed, + "Connected components": binary_connected, +} + +WEIGHTS_DIR = str( + Path(__file__).parent.resolve() / Path("code_models/models/pretrained") +) + + +################ +# Review + + +@dataclass +class ReviewConfig: + image: np.array = None + labels: np.array = None + csv_path: str = Path.home() / Path("cellseg3d/review") + model_name: str = "" + new_csv: bool = True + filetype: str = ".tif" + zoom_factor: List[int] = None + + +@dataclass # TODO create custom reader for JSON to load project +class ReviewSession: + project_name: str + image_path: str + labels_path: str + csv_path: str + aniso_zoom: List[int] + time_taken: datetime.timedelta + + +################ +# Model & weights + + +@dataclass +class ModelInfo: + """Dataclass recording model info : + - name (str): name of the model""" + + name: str = next(iter(MODEL_LIST)) + model_input_size: Optional[List[int]] = None + + def get_model(self): + try: + return MODEL_LIST[self.name] + except KeyError as e: + msg = f"Model {self.name} is not defined" + warnings.warn(msg) + logger.warning(msg) + raise KeyError(e) + + @staticmethod + def get_model_name_list(): + logger.info( + f"Model list :\n" + str(f"{name}\n" for name in MODEL_LIST.keys()) + ) + return MODEL_LIST.keys() + + +@dataclass +class WeightsInfo: + path: str = WEIGHTS_DIR + custom: bool = False + use_pretrained: Optional[bool] = False + + +################ +# Post processing & instance segmentation + + +@dataclass +class Thresholding: + enabled: bool = True + threshold_value: float = 0.8 + + +@dataclass +class Zoom: + enabled: bool = True + zoom_values: List[float] = None + + +@dataclass +class InstanceSegConfig: + enabled: bool = False + method: str = None + threshold: Thresholding = Thresholding(enabled=False, threshold_value=0.85) + small_object_removal_threshold: Thresholding = Thresholding( + enabled=True, threshold_value=20 + ) + + +@dataclass +class PostProcessConfig: + zoom: Zoom = Zoom() + thresholding: Thresholding = Thresholding() + instance: InstanceSegConfig = InstanceSegConfig() + + +################ +# Inference configs + + +@dataclass +class SlidingWindowConfig: + window_size: int = None + window_overlap: float = 0.25 + + def is_enabled(self): + return self.window_size is not None + + +@dataclass +class InfererConfig: + """Class to record params for Inferer plugin""" + + model_info: ModelInfo = None + show_results: bool = False + show_results_count: int = 5 + show_original: bool = True + anisotropy_resolution: List[int] = None + + +@dataclass +class InferenceWorkerConfig: + """Class to record configuration for Inference job""" + + device: str = "cpu" + model_info: ModelInfo = ModelInfo() + weights_config: WeightsInfo = WeightsInfo() + results_path: str = str(Path.home() / Path("cellseg3d/inference")) + filetype: str = ".tif" + keep_on_cpu: bool = False + compute_stats: bool = False + post_process_config: PostProcessConfig = PostProcessConfig() + sliding_window_config: SlidingWindowConfig = SlidingWindowConfig() + + images_filepaths: str = None + layer: napari.layers.Layer = None + + +################ +# Training configs + + +@dataclass +class DeterministicConfig: + """Class to record deterministic config""" + + enabled: bool = False + seed: int = 23498 + + +@dataclass +class TrainerConfig: + """Class to record trainer plugin config""" + + save_as_zip: bool = False + + +@dataclass +class TrainingWorkerConfig: + """Class to record config for Trainer plugin""" + + device: str = "cpu" + model_info: ModelInfo = None + weights_info: WeightsInfo = None + train_data_dict: dict = None + validation_percent: float = 0.8 + max_epochs: int = 5 + loss_function: callable = None + learning_rate: np.float64 = 1e-3 + validation_interval: int = 2 + batch_size: int = 1 + results_path_folder: str = str(Path.home() / Path("cellseg3d/training")) + sampling: bool = False + num_samples: int = 2 + sample_size: List[int] = None + do_augmentation: bool = True + deterministic_config: DeterministicConfig = DeterministicConfig() diff --git a/napari_cellseg3d/drafts.py b/napari_cellseg3d/dev_scripts/drafts.py similarity index 100% rename from napari_cellseg3d/drafts.py rename to napari_cellseg3d/dev_scripts/drafts.py diff --git a/napari_cellseg3d/dev_scripts/weight_conversion.py b/napari_cellseg3d/dev_scripts/weight_conversion.py index cb9c3e47..6cdb9c43 100644 --- a/napari_cellseg3d/dev_scripts/weight_conversion.py +++ b/napari_cellseg3d/dev_scripts/weight_conversion.py @@ -3,8 +3,8 @@ import torch -from napari_cellseg3d.models.model_TRAILMAP import get_net -from napari_cellseg3d.models.unet.model import UNet3D +from napari_cellseg3d.code_models.models import get_net +from napari_cellseg3d.code_models.models.unet.model import UNet3D # not sure this actually works when put here diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index dfb3be31..5e340542 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,10 +1,21 @@ -from typing import Optional +import threading +import warnings +from functools import partial from typing import List +from typing import Optional +import napari +# Qt +from qtpy import QtCore +from qtpy.QtCore import QObject from qtpy.QtCore import Qt + +# from qtpy.QtCore import QtWarningMsg from qtpy.QtCore import QUrl +from qtpy.QtGui import QCursor from qtpy.QtGui import QDesktopServices +from qtpy.QtGui import QTextCursor from qtpy.QtWidgets import QCheckBox from qtpy.QtWidgets import QComboBox from qtpy.QtWidgets import QDoubleSpinBox @@ -15,19 +26,27 @@ from qtpy.QtWidgets import QLabel from qtpy.QtWidgets import QLayout from qtpy.QtWidgets import QLineEdit +from qtpy.QtWidgets import QMenu from qtpy.QtWidgets import QPushButton +from qtpy.QtWidgets import QRadioButton from qtpy.QtWidgets import QScrollArea from qtpy.QtWidgets import QSizePolicy +from qtpy.QtWidgets import QSlider from qtpy.QtWidgets import QSpinBox +from qtpy.QtWidgets import QTextEdit from qtpy.QtWidgets import QVBoxLayout from qtpy.QtWidgets import QWidget +# Local from napari_cellseg3d import utils """ User interface functions and aliases""" +############### +# show debug tooltips +SHOW_LABELS_DEBUG_TOOLTIP = False ############### # aliases LEFT_AL = Qt.AlignmentFlag.AlignLeft @@ -47,8 +66,241 @@ dark_red = "#72071d" # crimson red default_cyan = "#8dd3c7" # turquoise cyan (default matplotlib line color under dark background context) napari_grey = "#262930" # napari background color (grey) +napari_param_grey = "#414851" # napari parameters menu color (lighter gray) +napari_param_darkgrey = "#202228" # napari default LineEdit color ############### +logger = utils.LOGGER + +################## +# Singleton UI widgets +################## + + +class QWidgetSingleton(type(QObject)): + """ + To be used as a metaclass when making a singleton QWidget, + meaning only one instance exists at a time. + Avoids unnecessary memory overhead and keeps user parameters even when a widget is closed + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + """ + Ensure only one instance of a QWidget with QWidgetSingleton as a metaclass exists at a time + + """ + if cls not in cls._instances: + cls._instances[cls] = super(QWidgetSingleton, cls).__call__( + *args, **kwargs + ) + return cls._instances[cls] + + +################## +# Screen size adjustment error handler +################## + + +def handle_adjust_errors(widget, type, context, msg: str): + """Qt message handler that attempts to react to errors when setting the window size + and resizes the main window""" + pass + # head = msg.split(": ")[0] + # if type == QtWarningMsg and head == "QWindowsWindow::setGeometry": + # logger.warning( + # f"Qt resize error : {msg}\nhas been handled by attempting to resize the window" + # ) + # try: + # if widget.parent() is not None: + # state = int(widget.parent().parent().windowState()) + # if state == 0: # normal state + # widget.parent().parent().adjustSize() + # logger.debug("Non-max. size adjust attempt") + # logger.debug(f"{widget.parent().parent()}") + # elif state == 2: # maximized state + # widget.parent().parent().showNormal() + # widget.parent().parent().showMaximized() + # logger.debug("Maximized size adjust attempt") + # except RuntimeError: + # pass + + +def handle_adjust_errors_wrapper(widget): + """Returns a callable that can be used with qInstallMessageHandler directly""" + return partial(handle_adjust_errors, widget) + + +################## +# Context menu for utilities +################## + + +class UtilsDropdown(metaclass=utils.Singleton): + """Singleton class for use in instantiating only one Utility dropdown menu that can be accessed from the plugin.""" + + caller_widget = None + + def dropdown_menu_call(self, widget, event): + """Calls the utility dropdown menu at the location of a CTRL+right-click""" + # ### DEBUG ### # + # print(event.modifiers) + # print("menu call") + # print(widget) + # print(self) + ################## + if self.caller_widget is None: + self.caller_widget = widget + + if event.button == 2 and "control" in event.modifiers: + dragged = False + yield + # on move + while event.type == "mouse_move": + # print(event.position) + dragged = True + yield + # on release + if dragged: + # print("drag end") + pass + else: + # print("clicked!") + if widget is self.caller_widget: + # print(f"authorized widget {widget} to show menu") + pos = QCursor.pos() + self.show_utils_menu(widget, pos) + # else: + # print(f"blocked widget {widget} from opening utils") + + def show_utils_menu(self, widget, event): + """ + Shows the context menu for utilities. Use with dropdown_menu_call. + Args: + widget: widget to show context menu in + event: mouse press event + """ + from napari_cellseg3d.code_plugins.plugin_utilities import ( + UTILITIES_WIDGETS, + ) + + menu = QMenu(widget.window()) + menu.setStyleSheet(f"background-color: {napari_grey}; color: white;") + + actions = [] + for title in UTILITIES_WIDGETS.keys(): + a = menu.addAction(f"Utilities : {title}") + actions.append(a) + + action = menu.exec_(event) + + for possible_action in actions: + if action == possible_action: + text = possible_action.text().split(": ")[1] + widget = UTILITIES_WIDGETS[text](widget._viewer) + widget._viewer.window.add_dock_widget(widget) + + +############## +# Log widget +############## + + +class Log(QTextEdit): + """Class to implement a log for important user info. Should be thread-safe.""" + + def __init__(self, parent=None): + """Creates a log with a lock for multithreading + + Args: + parent (QWidget): parent widget to add Log instance to. + """ + super().__init__(parent) + + # from qtpy.QtCore import QMetaType + # parent.qRegisterMetaType("QTextCursor") + + self.lock = threading.Lock() + + # def receive_log(self, text): + # self.print_and_log(text) + def write(self, message): + """ + Write message to log in a thread-safe manner + Args: + message: string to be printed + """ + self.lock.acquire() + try: + if not hasattr(self, "flag"): + self.flag = False + message = message.replace("\r", "").rstrip() + if message: + method = "replace_last_line" if self.flag else "append" + QtCore.QMetaObject.invokeMethod( + self, + method, + QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, message), + ) + self.flag = True + else: + self.flag = False + + finally: + self.lock.release() + + @QtCore.Slot(str) + def replace_last_line(self, text): + """Replace last line. For use in progress bar""" + self.lock.acquire() + try: + cursor = self.textCursor() + cursor.movePosition(QTextCursor.End) + cursor.select(QTextCursor.BlockUnderCursor) + cursor.removeSelectedText() + cursor.insertBlock() + self.setTextCursor(cursor) + self.insertPlainText(text) + finally: + self.lock.release() + + def print_and_log(self, text, printing=True): + """Utility used to both print to terminal and log text to a QTextEdit + item in a thread-safe manner. Use only for important user info. + + Args: + text (str): Text to be printed and logged + printing (bool): Whether to print the message as well or not using logger.info(). Defaults to True. + + """ + self.lock.acquire() + try: + if printing: + logger.info(text) + # causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows) + self.moveCursor(QTextCursor.End) + self.insertPlainText(f"\n{text}") + self.verticalScrollBar().setValue( + self.verticalScrollBar().maximum() + ) + finally: + self.lock.release() + + def warn(self, warning): + """Show warnings.warn from another thread""" + self.lock.acquire() + try: + warnings.warn(warning) + finally: + self.lock.release() + + +############## +# UI elements +############## + def toggle_visibility(checkbox, widget): """Toggles the visibility of a widget based on the status of a checkbox. @@ -67,6 +319,41 @@ def add_label(widget, label, label_before=True, horizontal=True): return combine_blocks(label, widget, horizontal=horizontal) +class ContainerWidget(QWidget): + def __init__( + self, l=0, t=0, r=1, b=11, vertical=True, parent=None, fixed=True + ): + """ + Creates a container widget that can contain other widgets + Args: + l: left margin in pixels + t: top margin in pixels + r: right margin in pixels + b: bottom margin in pixels + vertical: if True, renders vertically. Horizontal otherwise + parent: parent QWidget + fixed: uses QLayout.SetFixedSize if True + """ + + super().__init__(parent) + self.layout = None + + if vertical: + self.layout = QVBoxLayout(self) + else: + self.layout = QHBoxLayout(self) + + self.layout.setContentsMargins(l, t, r, b) + if fixed: + self.layout.setSizeConstraint(QLayout.SetFixedSize) + + +class RadioButton(QRadioButton): + def __init__(self, text: str = None, parent=None): + + super().__init__(text, parent) + + class Button(QPushButton): """Class for a button with a title and connected to a function when clicked. Inherits from QPushButton. @@ -96,6 +383,7 @@ def __init__( self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) def visibility_condition(self, checkbox): + """Provide a QCheckBox to use to determine whether to show the button or not""" toggle_visibility(checkbox, self) @@ -149,6 +437,158 @@ def __init__( self.toggled.connect(func) +class Slider(QSlider): + """Shortcut class to create a Slider widget""" + + def __init__( + self, + lower: int = 0, + upper: int = 100, + step: int = 1, + default: int = 0, + divide_factor: float = 1.0, + parent=None, + orientation=Qt.Horizontal, + text_label: str = None, + ): + + super().__init__(orientation, parent) + + self.setMaximum(upper) + self.setMinimum(lower) + self.setSingleStep(step) + + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + self.text_label = None + self.container = ContainerWidget( + # parent=self.parent + ) + + self._divide_factor = divide_factor + self._value_label = QLineEdit(self.value_text, parent=self) + + if self._divide_factor == 1: + self._value_label.setFixedWidth(20) + else: + self._value_label.setFixedWidth(30) + self._value_label.setAlignment(Qt.AlignCenter) + self._value_label.setSizePolicy( + QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed + ) + + self._value_label.setStyleSheet( + f"background-color: {napari_param_grey};" + f"border-radius: 5px;" + "min - height: 12px;" + "min - width: 12px;" + ) + + if text_label is not None: + self.text_label = make_label(text_label, parent=self) + + if default < lower: + self._warn_outside_bounds(default) + default = lower + elif default > upper: + self._warn_outside_bounds(default) + default = upper + + self.valueChanged.connect(self._update_value_label) + self._value_label.textChanged.connect(self._update_slider) + + self.slider_value = default + + self._build_container() + + def _build_container(self): + self.container.layout + + if self.text_label is not None: + add_widgets( + self.container.layout, + [ + self.text_label, + combine_blocks(self._value_label, self, b=0), + ], + ) + else: + add_widgets( + self.container.layout, + [combine_blocks(self._value_label, self, b=0)], + ) + + def _warn_outside_bounds(self, default): + warnings.warn( + f"Default value {default} was outside of the ({self.minimum()}:{self.maximum()}) range" + ) + + def _update_slider(self): + """Update slider when value is changed""" + if self._value_label.text() == "": + return + + value = float(self._value_label.text()) * self._divide_factor + + if value < self.minimum(): + self.slider_value = self.minimum() + return + if value > self.maximum(): + self.slider_value = self.maximum() + return + + self.slider_value = value + + def _update_value_label(self): + """Update label, to connect to when slider is dragged""" + self._value_label.setText(str(self.value_text)) + + @property + def tooltips(self): + return self.toolTip() + + @tooltips.setter + def tooltips(self, tooltip: str): + self.setToolTip(tooltip) + self._value_label.setToolTip(tooltip) + + if self.text_label is not None: + self.text_label.setToolTip(tooltip) + + @property + def slider_value(self): + """Get value of the slider divided by self._divide_factor to implement floats in Slider""" + if self._divide_factor == 1.0: + return self.value() + + try: + return self.value() / self._divide_factor + except ZeroDivisionError as e: + raise ZeroDivisionError( + f"Divide factor cannot be 0 for Slider : {e}" + ) + + @property + def value_text(self): + """Get value of the slide bar as string""" + return str(self.slider_value) + + @slider_value.setter + def slider_value(self, value: int): + """Set a value (int) divided by self._divide_factor""" + if value < self.minimum() or value > self.maximum(): + raise ValueError( + f"The value for the slider ({value}) cannot be out of ({self.minimum()};{self.maximum()}) " + ) + + self.setValue(value) + + divided = value / self._divide_factor + if self._divide_factor == 1.0: + divided = int(divided) + self._value_label.setText(str(divided)) + + class AnisotropyWidgets(QWidget): """Class that creates widgets for anisotropy handling. Includes : - A checkbox to hides or shows the controls @@ -175,13 +615,13 @@ def __init__( self._layout.setSpacing(0) self._layout.setContentsMargins(0, 0, 0, 0) - self.container, self._boxes_layout = make_container(T=7, parent=parent) - self.checkbox = make_checkbox( + self.container = ContainerWidget(t=7, parent=parent) + self.checkbox = CheckBox( "Anisotropic data", self._toggle_display_aniso, parent ) self.box_widgets = DoubleIncrementCounter.make_n( - n=3, min=1.0, max=1000, default=1, step=0.5 + n=3, lower=1.0, upper=1000.0, default=1.0, step=0.5 ) self.box_widgets[0].setValue(default_x) self.box_widgets[1].setValue(default_y) @@ -209,50 +649,43 @@ def __init__( self.build() if always_visible: - self.toggle_permanent_visibility() + self._toggle_permanent_visibility() def _toggle_display_aniso(self): - """Shows the choices for correcting anisotropy when viewing results depending on whether :py:attr:`self.checkbox` is checked""" + """Shows the choices for correcting anisotropy + when viewing results depending on whether :py:attr:`self.checkbox` is checked""" toggle_visibility(self.checkbox, self.container) def build(self): """Builds the layout of the widget""" [ - self._boxes_layout.addWidget(widget, alignment=HCENTER_AL) + self.container.layout.addWidget(widget, alignment=HCENTER_AL) for widgets in zip(self.box_widgets_lbl, self.box_widgets) for widget in widgets ] # anisotropy - self.container.setLayout(self._boxes_layout) + self.container.setLayout(self.container.layout) self.container.setVisible(False) add_widgets(self._layout, [self.checkbox, self.container]) self.setLayout(self._layout) - def get_anisotropy_resolution_xyz(self, as_factors=True): - """ - Args : - as_factors: if True, returns zoom factors, otherwise returns the input resolution - - Returns : the resolution in microns for each of the three dimensions. ZYX order suitable for napari scale""" + def resolution_xyz(self): + """The resolution selected for each of the three dimensions. XYZ order (for MONAI)""" + return [w.value() for w in self.box_widgets] - resolution = [w.value() for w in self.box_widgets] - if as_factors: - return self.anisotropy_zoom_factor(resolution) + def scaling_xyz(self): + """The scaling factors for each of the three dimensions. XYZ order (for MONAI)""" + return self.anisotropy_zoom_factor(self.resolution_xyz()) - return resolution + def resolution_zyx(self): + """The resolution selected for each of the three dimensions. ZYX order (for napari)""" + res = self.resolution_xyz() + return [res[2], res[1], res[0]] - def get_anisotropy_resolution_zyx(self, as_factors=True): - """ - Args : - as_factors: if True, returns zoom factors, otherwise returns the input resolution - - Returns : the resolution in microns for each of the three dimensions. XYZ order suitable for MONAI""" - resolution = [w.value() for w in self.box_widgets] - if as_factors: - resolution = self.anisotropy_zoom_factor(resolution) - - return [resolution[2], resolution[1], resolution[0]] + def scaling_zyx(self): + """The scaling factors for each of the three dimensions. ZYX order (for napari)""" + return self.anisotropy_zoom_factor(self.resolution_zyx()) @staticmethod def anisotropy_zoom_factor(aniso_res): @@ -269,19 +702,87 @@ def anisotropy_zoom_factor(aniso_res): zoom_factors = [base / res for res in aniso_res] return zoom_factors - def is_enabled(self): + def enabled(self): """Returns : whether anisotropy correction has been enabled or not""" return self.checkbox.isChecked() - def toggle_permanent_visibility(self): + def _toggle_permanent_visibility(self): """Hides the checkbox and always display resolution spinboxes""" self.checkbox.toggle() self.checkbox.setVisible(False) -class FilePathWidget( - QWidget -): # TODO upgrade logic, include load as folder, highlight if incorrect ? +class LayerSelecter(ContainerWidget): + def __init__( + self, viewer, name="Layer", layer_type=napari.layers.Layer, parent=None + ): + super().__init__(parent=parent, fixed=False) + self._viewer = viewer + + self.image = None + self.layer_type = layer_type + + self.layer_list = DropdownMenu(parent=self, label=name, fixed=False) + # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? + + self._viewer.layers.events.inserted.connect(partial(self._add_layer)) + self._viewer.layers.events.removed.connect(partial(self._remove_layer)) + + self.layer_list.currentIndexChanged.connect(self._update_tooltip) + + add_widgets(self.layout, [self.layer_list.label, self.layer_list]) + self._check_for_layers() + + def _check_for_layers(self): + + for layer in self._viewer.layers: + if isinstance(layer, self.layer_type): + self.layer_list.addItem(layer.name) + + def _update_tooltip(self): + + self.layer_list.setToolTip(self.layer_list.currentText()) + + def _add_layer(self, event): + + inserted_layer = event.value + + if isinstance(inserted_layer, self.layer_type): + self.layer_list.addItem(inserted_layer.name) + + def _remove_layer(self, event): + + removed_layer = event.value + + if isinstance( + removed_layer, self.layer_type + ) and removed_layer.name in [ + self.layer_list.itemText(i) for i in range(self.layer_list.count()) + ]: + + index = self.layer_list.findText(removed_layer.name) + self.layer_list.removeItem(index) + + def set_layer_type(self, type): # no @property due to Qt constraint + self.layer_type = type + [self.layer_list.removeItem(i) for i in range(self.layer_list.count())] + self._check_for_layers() + + def layer(self): + return self._viewer.layers[self.layer_name()] + + def layer_name(self): + return self.layer_list.currentText() + + def layer_data(self): + if self.layer_list.count() < 1: + warnings.warn("Please select a valid layer !") + return + + return self._viewer.layers[self.layer_name()].data + + +class FilePathWidget(QWidget): # TODO include load as folder """Widget to handle the choice of file paths for data throughout the plugin. Provides the following elements : - An "Open" button to show a file dialog (defined externally) - A QLineEdit in read only to display the chosen path/file""" @@ -292,6 +793,7 @@ def __init__( file_function: callable, parent: Optional[QWidget] = None, required: Optional[bool] = True, + default: Optional[str] = None, ): """Creates a FilePathWidget. Args: @@ -306,26 +808,53 @@ def __init__( self._layout.setContentsMargins(0, 0, 0, 0) self._initial_desc = description - self.text_field = QLineEdit(description, self) + self._text_field = QLineEdit(description, self) + + self._button = Button("Open", file_function, parent=self, fixed=True) - self.button = Button("Open", file_function, parent=self, fixed=True) + self._text_field.setReadOnly(True) # for user only + if default is not None: + self._text_field.setText(default) - self.text_field.setReadOnly(True) + self._required = required - self.set_required(required) + self.build() + self.check_ready() def build(self): """Builds the layout of the widget""" - add_widgets(self._layout, [self.text_field, self.button]) + add_widgets( + self._layout, + [combine_blocks(self.button, self.text_field, min_spacing=5, b=0)], + ABS_AL, + ) self.setLayout(self._layout) - def get_text_field(self): + @property + def tooltips(self): + return self._text_field.toolTip() + + @tooltips.setter + def tooltips(self, tooltip: str): + self._text_field.setToolTip(tooltip) + self._button.setToolTip(tooltip) + + @property + def text_field(self): """Get text field with file path""" - return self.text_field + return self._text_field - def get_button(self): + @text_field.setter + def text_field(self, text: str): + """Sets the initial description in the text field, makes it the new default path""" + self._initial_desc = text + self.tooltips = text + self._text_field.setText(text) + + @property + def button(self): """Get "Open" button""" - return self.button + return self._button def check_ready(self): """Check if a path is correctly set""" @@ -334,16 +863,25 @@ def check_ready(self): self.text_field.setToolTip("Mandatory field !") return False else: - self.update_field_color("black") + self.update_field_color(f"{napari_param_darkgrey}") return True - def set_required(self, is_required): + @property + def required(self): + return self._required + + @required.setter + def required(self, is_required): """If set to True, will be colored red if incorrectly set""" if is_required: self.text_field.textChanged.connect(self.check_ready) else: - self.text_field.textChanged.disconnect(self.check_ready) + try: + self.text_field.textChanged.disconnect(self.check_ready) + except TypeError: + return self.check_ready() + self._required = is_required def update_field_color(self, color: str): """Updates the background of the text field""" @@ -351,11 +889,6 @@ def update_field_color(self, color: str): self.text_field.style().unpolish(self.text_field) self.text_field.style().polish(self.text_field) - def set_description(self, text: str): - """Sets the initial description ins the text field""" - self._initial_desc = text - self.text_field.setText(text) - class ScrollArea(QScrollArea): """Creates a QScrollArea and sets it up, then adds the contained_layout to it.""" @@ -376,7 +909,6 @@ def __init__( base_wh (Optional[List[int]]): array of two ints for respectively the initial width and initial height of the scrollable area. Defaults to None, lets Qt decide if None parent (Optional[QWidget]): array of two ints for respectively the initial width and initial height of the scrollable area. Defaults to None, lets Qt decide if None """ - # TODO : optimize the number of created objects ? super().__init__(parent) self._container_widget = ( @@ -441,11 +973,11 @@ def set_spinbox( fixed: Optional[bool] = True, ): """Args: - class_ : QSpinBox or QDoubleSpinBox - min (Optional[int]): minimum value, defaults to 0 - max (Optional[int]): maximum value, defaults to 10 - default (Optional[int]): default value, defaults to 0 - step (Optional[int]): step value, defaults to 1 + box : QSpinBox or QDoubleSpinBox + min : minimum value, defaults to 0 + max : maximum value, defaults to 10 + default : default value, defaults to 0 + step : step value, defaults to 1 fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" box.setMinimum(min) @@ -494,17 +1026,17 @@ class DoubleIncrementCounter(QDoubleSpinBox): def __init__( self, - min=0, - max=10, - default=0, - step=1, + lower: Optional[float] = 0.0, + upper: Optional[float] = 10.0, + default: Optional[float] = 0.0, + step: Optional[float] = 1.0, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, label: Optional[str] = None, ): """Args: - min (Optional[float]): minimum value, defaults to 0 - max (Optional[float]): maximum value, defaults to 10 + lower (Optional[float]): minimum value, defaults to 0 + upper (Optional[float]): maximum value, defaults to 10 default (Optional[float]): default value, defaults to 0 step (Optional[float]): step value, defaults to 1 parent: parent widget, defaults to None @@ -512,20 +1044,30 @@ def __init__( label (Optional[str]): if provided, creates a label with the chosen title to use with the counter""" super().__init__(parent) - set_spinbox(self, min, max, default, step, fixed) + set_spinbox(self, lower, upper, default, step, fixed) + + self.layout = None if label is not None: self.label = make_label(name=label) - # def setToolTip(self, a0: str) -> None: - # self.setToolTip(a0) - # if self.label is not None: - # self.label.setToolTip(a0) + @property + def tooltips(self): + return self.toolTip() + + @tooltips.setter + def tooltips(self, tooltip: str): + """Sets the tooltip of both the DoubleIncrementCounter and its label""" + self.setToolTip(tooltip) + if self.label is not None: + self.label.setToolTip(tooltip) - def get_with_label(self, horizontal=True): - return add_label(self, self.label, horizontal=horizontal) + @property + def precision(self): + return self.decimals() - def set_precision(self, decimals): + @precision.setter + def precision(self, decimals: int): """Sets the precision of the box to the specified number of decimals""" self.setDecimals(decimals) @@ -533,14 +1075,16 @@ def set_precision(self, decimals): def make_n( cls, n: int = 2, - min=0, - max=10, - default=0, - step=1, + lower: float = 0, + upper: float = 10, + default: float = 0, + step: float = 1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, ): - return make_n_spinboxes(cls, n, min, max, default, step, parent, fixed) + return make_n_spinboxes( + cls, n, lower, upper, default, step, parent, fixed + ) class IntIncrementCounter(QSpinBox): @@ -548,8 +1092,8 @@ class IntIncrementCounter(QSpinBox): def __init__( self, - min=0, - max=10, + lower=0, + upper=10, default=0, step=1, parent: Optional[QWidget] = None, @@ -557,31 +1101,45 @@ def __init__( label: Optional[str] = None, ): """Args: - min (Optional[int]): minimum value, defaults to 0 - max (Optional[int]): maximum value, defaults to 10 + lower (Optional[int]): minimum value, defaults to 0 + upper (Optional[int]): maximum value, defaults to 10 default (Optional[int]): default value, defaults to 0 step (Optional[int]): step value, defaults to 1 parent: parent widget, defaults to None fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed""" super().__init__(parent) - set_spinbox(self, min, max, default, step, fixed) + set_spinbox(self, lower, upper, default, step, fixed) + self.label = None + self.container = None + if label is not None: - self.label = make_label(label, self) + self.label = make_label(name=label) + + @property + def tooltips(self): + return self.toolTip() + + @tooltips.setter + def tooltips(self, tooltip): + self.setToolTip(tooltip) + self.label.setToolTip(tooltip) @classmethod def make_n( cls, n: int = 2, - min=0, - max=10, - default=0, - step=1, + lower: int = 0, + upper: int = 10, + default: int = 0, + step: int = 1, parent: Optional[QWidget] = None, fixed: Optional[bool] = True, ): - return make_n_spinboxes(cls, n, min, max, default, step, parent, fixed) + return make_n_spinboxes( + cls, n, lower, upper, default, step, parent, fixed + ) def add_blank(widget, layout=None): @@ -603,8 +1161,7 @@ def add_blank(widget, layout=None): def open_file_dialog( widget, - possible_paths: list = [""], - load_as_folder: bool = False, + possible_paths: list = [], filetype: str = "Image file (*.tif *.tiff)", ): """Opens a window to choose a file directory using QFileDialog. @@ -619,17 +1176,24 @@ def open_file_dialog( """ default_path = utils.parse_default_path(possible_paths) - if not load_as_folder: - f_name = QFileDialog.getOpenFileName( - widget, "Choose file", default_path, filetype - ) - return f_name - else: - print(default_path) - filenames = QFileDialog.getExistingDirectory( - widget, "Open directory", default_path - ) - return filenames + + f_name = QFileDialog.getOpenFileName( + widget, "Choose file", default_path, filetype + ) + return f_name + + +def open_folder_dialog( + widget, + possible_paths: list = [], +): + default_path = utils.parse_default_path(possible_paths) + + logger.info(f"Default : {default_path}") + filenames = QFileDialog.getExistingDirectory( + widget, "Open directory", default_path + ) + return filenames def make_label(name, parent=None): # TODO update to child class @@ -643,99 +1207,58 @@ def make_label(name, parent=None): # TODO update to child class """ if parent is not None: - return QLabel(name, parent) + label = QLabel(name, parent) + if SHOW_LABELS_DEBUG_TOOLTIP: + label.setToolTip(f"{label}") + return label else: - return QLabel(name) - - -def add_to_group(title, widget, layout, L=7, T=20, R=7, B=11): - """Adds a single widget to a layout as a named group with margins specified. - - Args: - title: title of the group - widget: widget to add in the group - layout: layout to add the group in - L: left margin (in pixels) - T: top margin (in pixels) - R: right margin (in pixels) - B: bottom margin (in pixels) - - """ - group, layout_internal = make_group(title, L, T, R, B) - layout_internal.addWidget(widget) - group.setLayout(layout_internal) - layout.addWidget(group) + label = QLabel(name) + if SHOW_LABELS_DEBUG_TOOLTIP: + label.setToolTip(f"{label}") + return label -def make_group(title, L=7, T=20, R=7, B=11, parent=None): # TODO : child class +def make_group(title, l=7, t=20, r=7, b=11, parent=None): """Creates a group widget and layout, with a header (`title`) and content margins for top/left/right/bottom `L, T, R, B` (in pixels) Group widget and layout returned will have a Fixed size policy. Args: title (str): Title of the group - L (int): left margin - T (int): top margin - R (int): right margin - B (int): bottom margin + l (int): left margin + t (int): top margin + r (int): right margin + b (int): bottom margin parent (QWidget) : parent widget. If None, no parent is set """ - if parent is None: - group = QGroupBox(title) - else: - group = QGroupBox(title, parent=parent) - group.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - layout = QVBoxLayout() - layout.setContentsMargins(L, T, R, B) - layout.setSizeConstraint(QLayout.SetFixedSize) + group = GroupedWidget(title, l, t, r, b, parent=parent) + layout = group.layout return group, layout -def make_container( - L=0, T=0, R=1, B=11, vertical=True, parent=None -): # TODO child class? - """Creates a QWidget and a layout for the purpose of containing other modules, with a Fixed layout. - - Args: - parent : parent widget. If None, no widget is set - L (int): left margin of layout - T (int): top margin of layout - R (int): right margin of layout - B (int): bottom margin of layout - vertical (bool): if False, uses QHBoxLayout instead of QVboxLayout. Default: True - - Returns: - QWidget : widget that contains the other widgets. Fixed size. - QBoxLayout : H/V Box layout to add contained widgets in. Fixed size. - """ - if parent is None: - container_widget = QWidget() - else: - container_widget = QWidget(parent) +class GroupedWidget(QGroupBox): + """Subclass of QGroupBox designed to easily group widgets belonging to a same category""" - if vertical: - container_layout = QVBoxLayout() - else: - container_layout = QHBoxLayout() - container_layout.setContentsMargins(L, T, R, B) - container_layout.setSizeConstraint(QLayout.SetFixedSize) + def __init__(self, title, l=7, t=20, r=7, b=11, parent=None): + super().__init__(title, parent) - return container_widget, container_layout + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + self.layout = QVBoxLayout() + self.layout.setContentsMargins(l, t, r, b) + self.layout.setSizeConstraint(QLayout.SetFixedSize) -def make_combobox(): # TODO finish child class conversion - """Creates a dropdown menu with a title and adds specified entries to it + def set_layout(self): + self.setLayout(self.layout) - Args: - entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None - parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None - label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well - fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True. - - Returns: - QComboBox : created dropdown menu - """ - raise NotImplementedError + @classmethod + def create_single_widget_group( + cls, title, widget, layout, l=7, t=20, r=7, b=11 + ): + group = cls(title, l, t, r, b) + group.layout.addWidget(widget) + group.setLayout(group.layout) + layout.addWidget(group) def add_widgets(layout, widgets, alignment=LEFT_AL): @@ -754,28 +1277,7 @@ def add_widgets(layout, widgets, alignment=LEFT_AL): layout.addWidget(w, alignment=alignment) -def make_checkbox( # TODO update calls to class - title: str = None, - func: callable = None, - parent: QWidget = None, - fixed: bool = True, -): - """Creates a checkbox with a title and connects it to a function when clicked - - Args: - title (str-like): title of the checkbox. Defaults to None, if None no title is set - func (callable): function to execute when checkbox is toggled. Defaults to None, no binding is made if None - parent (QWidget): parent QWidget to add checkbox to. Defaults to None, no parent is set if None - fixed (bool): if True, will set the size policy of the checkbox to Fixed in h and w. Defaults to True. - - Returns: - QCheckBox : created widget - """ - - return CheckBox(title, func, parent, fixed) - - -def combine_blocks( +def combine_blocks( # TODO FIXME PLEASE this is a horrible design right_or_below, left_or_above, min_spacing=0, @@ -789,10 +1291,14 @@ def combine_blocks( Weird argument names due the initial implementation of it. # TODO maybe fix arg names Args: - horizontal (bool): whether to stack widgets vertically (False) or horizontally (True) left_or_above (QWidget): First widget, to be added on the left/above of "second" right_or_below (QWidget): Second widget, to be displayed right/below of "first" min_spacing (int): Minimum spacing between the two widgets (from the start of label to the start of button) + horizontal (bool): whether to stack widgets vertically (False) or horizontally (True) + l (int): left spacing in pixels + t (int): top spacing in pixels + r (int): right spacing in pixels + b (int): bottom spacing in pixels Returns: QWidget: new QWidget containing the merged widget and label diff --git a/napari_cellseg3d/launch_review.py b/napari_cellseg3d/launch_review.py deleted file mode 100644 index 3069e141..00000000 --- a/napari_cellseg3d/launch_review.py +++ /dev/null @@ -1,323 +0,0 @@ -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import napari -import numpy as np -from magicgui import magicgui -from matplotlib.backends.backend_qt5agg import ( - FigureCanvasQTAgg as FigureCanvas, -) -from matplotlib.figure import Figure -from qtpy.QtWidgets import QSizePolicy -from scipy import ndimage -from tifffile import imwrite - -from napari_cellseg3d import utils -from napari_cellseg3d.plugin_dock import Datamanager - - -def launch_review( - original, - base, - raw, - r_path, - model_type, - checkbox, - filetype, - as_folder, - zoom_factor, -): - """Launch the review process, loading the original image, the labels & the raw labels (from prediction) - in the viewer. - - Adds several widgets to the viewer : - - * A data manager widget that lets the user choose a directory to save the labels in, and a save button to quickly - save. - - * A "checked/not checked" button to let the user confirm that a slice has been checked or not. - - - **This writes in a csv file under the corresponding slice the slice status (i.e. checked or not checked) - to allow tracking of the review process for a given dataset.** - - * A plot widget that, when shift-clicking on the volume or label, - displays the chosen location on several projections (x-y, y-z, x-z), - to allow the user to have a better all-around view of the object - and determine whether it should be labeled or not. - - Args: - - original (dask.array.Array): The original images/volumes that have been labeled - - base (dask.array.Array): The labels for the volume - - raw (dask.array.Array): The raw labels from the prediction - - r_path (str): Path to the raw labels - - model_type (str): The name of the model to be displayed in csv filenames. - - checkbox (bool): Whether the "new model" checkbox has been checked or not, to create a new csv - - filetype (str): The file extension of the volumes and labels. - - as_folder (bool): Whether to load as folder or single file - - zoom_factor (array(int)): zoom factors for each axis - - Returns : list of all docked widgets - """ - images_original = original - base_label = base - - viewer = napari.Viewer() - - viewer.scale_bar.visible = True - - viewer.add_image( - images_original, - name="volume", - colormap="inferno", - contrast_limits=[200, 1000], - scale=zoom_factor, - ) # anything bigger than 255 will get mapped to 255... they did it like this because it must have rgb images - viewer.add_labels(base_label, name="labels", seed=0.6, scale=zoom_factor) - - if raw is not None: # raw labels is from the prediction - viewer.add_image( - ndimage.gaussian_filter(raw, sigma=3), - colormap="magenta", - name="low_confident", - blending="additive", - scale=zoom_factor, - ) - else: - pass - - # def label_and_sort(base_label): # assigns a different id for every different cell ? - # labeled = ndimage.label(base_label, structure=np.ones((3, 3, 3)))[0] - # - # mks, nums = np.unique(labeled, return_counts=True) - # - # idx_list = list(np.argsort(nums[1:])) - # nums = np.sort(nums[1:]) - # labeled_sorted = np.zeros_like(labeled) - # for i, idx in enumerate(idx_list): - # labeled_sorted = np.where(labeled == mks[1:][idx], i + 1, labeled_sorted) - # return labeled_sorted, nums - # - # def label_ct(labeled_array, nums, value): - # labeled_temp = copy.copy(labeled_array) - # idx = np.abs(nums - value).argmin() - # labeled_temp = np.where((labeled_temp < idx) & (labeled_temp != 0), 255, 0) - # return labeled_temp - - # def show_so_layer(args): - # labeled_c, labeled_sorted, nums = args - # so_layer = viewer.add_image(labeled_c, colormap='cyan', name='small_object', blending='additive') - # - # object_slider = QSlider(Qt.Horizontal) - # object_slider.setMinimum(0) - # object_slider.setMaximum(500) - # object_slider.setSingleStep(10) - # object_slider.setValue(10) - # - # object_slider.valueChanged[int].connect(lambda value=object_slider: calc_object_callback(so_layer, value, - # labeled_sorted, nums)) - # - # lbl = QLabel('object size') - # - # slider_widget = utils.combine_blocks(lbl, object_slider) - # - # viewer.window.add_dock_widget(slider_widget, name='object_size_slider', area='left') - # - # def calc_object_callback(t_layer, value, labeled_array, nums): - # t_layer.data = label_ct(labeled_array, nums, value) - - # @thread_worker(connect={"returned": show_so_layer}) - # def create_label(): - # labeled_sorted, nums = label_and_sort(base_label) - # labeled_c = label_ct(labeled_sorted, nums, 10) - # return labeled_c, labeled_sorted, nums - # - # worker = create_label() - # if not as_folder: - # r_path = os.path.dirname(r_path) - - @magicgui( - dirname={"mode": "d", "label": "Save labels in... "}, - call_button="Save", - # call_button_2="Save & quit", - ) - def file_widget( - dirname=Path(r_path), - ): # file name where to save annotations - # """Take a filename and do something with it.""" - # print("The filename is:", dirname) - - dirname = Path(r_path) - # def saver(): - out_dir = file_widget.dirname.value - - # print("The directory is:", out_dir) - - def quicksave(): - if not as_folder: - if viewer.layers["labels"] is not None: - name = os.path.join(str(out_dir), "labels_reviewed.tif") - dat = viewer.layers["labels"].data - imwrite(name, data=dat) - - else: - if viewer.layers["labels"] is not None: - dir_name = os.path.join(str(out_dir), "labels_reviewed") - dat = viewer.layers["labels"].data - utils.save_stack(dat, dir_name, filetype=filetype) - - # def quicksave_quit(): - # quicksave() - # viewer.window.close() - - return dirname, quicksave() # , quicksave_quit() - - # gui = file_widget.show(run=True) # dirpicker.show(run=True) - - viewer.window.add_dock_widget(file_widget, name=" ", area="bottom") - - # @magicgui(call_button="Save") - - # gui2 = saver.show(run=True) # saver.show(run=True) - # viewer.window.add_dock_widget(gui2, name=" ", area="bottom") - - # viewer.window._qt_window.tabifyDockWidget(gui, gui2) #not with FunctionGui ? - - # draw canvas - - with plt.style.context("dark_background"): - canvas = FigureCanvas(Figure(figsize=(3, 15))) - - xy_axes = canvas.figure.add_subplot(3, 1, 1) - canvas.figure.suptitle("Shift-click on image for plot \n", fontsize=8) - xy_axes.imshow(np.zeros((100, 100), np.int16)) - xy_axes.scatter(50, 50, s=10, c="green", alpha=0.25) - xy_axes.set_xlabel("x axis") - xy_axes.set_ylabel("y axis") - yz_axes = canvas.figure.add_subplot(3, 1, 2) - yz_axes.imshow(np.zeros((100, 100), np.int16)) - yz_axes.scatter(50, 50, s=10, c="green", alpha=0.25) - yz_axes.set_xlabel("y axis") - yz_axes.set_ylabel("z axis") - zx_axes = canvas.figure.add_subplot(3, 1, 3) - zx_axes.imshow(np.zeros((100, 100), np.int16)) - zx_axes.scatter(50, 50, s=10, c="green", alpha=0.25) - zx_axes.set_xlabel("x axis") - zx_axes.set_ylabel("z axis") - - # canvas.figure.tight_layout() - canvas.figure.subplots_adjust( - left=0.1, bottom=0.1, right=1, top=0.95, wspace=0, hspace=0.4 - ) - - canvas.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Maximum) - - viewer.window.add_dock_widget(canvas, name=" ", area="right") - - @viewer.mouse_drag_callbacks.append - def update_canvas_canvas(viewer, event): - - if "shift" in event.modifiers: - try: - cursor_position = np.round(viewer.cursor.position).astype(int) - print(f"plot @ {cursor_position}") - - cropped_volume = crop_volume_around_point( - [ - cursor_position[0], - cursor_position[1], - cursor_position[2], - ], - viewer.layers["volume"], - zoom_factor, - ) - - ########## - ########## - # DEBUG - # viewer.add_image(cropped_volume, name="DEBUG_crop_plot") - - xy_axes.imshow( - cropped_volume[50], cmap="inferno", vmin=200, vmax=2000 - ) - yz_axes.imshow( - cropped_volume.transpose(1, 0, 2)[50], - cmap="inferno", - vmin=200, - vmax=2000, - ) - zx_axes.imshow( - cropped_volume.transpose(2, 0, 1)[50], - cmap="inferno", - vmin=200, - vmax=2000, - ) - canvas.draw_idle() - except Exception as e: - print(e) - - # Qt widget defined in docker.py - dmg = Datamanager(parent=viewer) - dmg.prepare(r_path, filetype, model_type, checkbox, as_folder) - viewer.window.add_dock_widget(dmg, name=" ", area="left") - - def update_button(axis_event): - - slice_num = axis_event.value[0] - print(f"slice num is {slice_num}") - dmg.update(slice_num) - - viewer.dims.events.current_step.connect(update_button) - - def crop_volume_around_point(points, layer, zoom_factor): - if zoom_factor != [1, 1, 1]: - data = np.array(layer.data, dtype=np.int16) - volume = utils.resize(data, zoom_factor) - # image = ndimage.zoom(layer.data, zoom_factor, mode="nearest") # cleaner but much slower... - else: - volume = layer.data - - min_coordinates = [point - 50 for point in points] - max_coordinates = [point + 50 for point in points] - inferior_bound = [ - min_coordinate if min_coordinate < 0 else 0 - for min_coordinate in min_coordinates - ] - superior_bound = [ - max_coordinate - volume.shape[i] - if volume.shape[i] < max_coordinate - else 0 - for i, max_coordinate in enumerate(max_coordinates) - ] - - crop_slice = tuple( - slice(np.maximum(0, min_coordinate), max_coordinate) - for min_coordinate, max_coordinate in zip( - min_coordinates, max_coordinates - ) - ) - - if as_folder: - crop_temp = volume[crop_slice].persist().compute() - else: - crop_temp = volume[crop_slice] - - cropped_volume = np.zeros((100, 100, 100), np.int16) - cropped_volume[ - -inferior_bound[0] : 100 - superior_bound[0], - -inferior_bound[1] : 100 - superior_bound[1], - -inferior_bound[2] : 100 - superior_bound[2], - ] = crop_temp - return cropped_volume - - return viewer, [file_widget, canvas, dmg] diff --git a/napari_cellseg3d/log_utility.py b/napari_cellseg3d/log_utility.py deleted file mode 100644 index 1ae9b2a0..00000000 --- a/napari_cellseg3d/log_utility.py +++ /dev/null @@ -1,89 +0,0 @@ -import threading -import warnings - -from qtpy import QtCore -from qtpy.QtGui import QTextCursor -from qtpy.QtWidgets import QTextEdit - - -class Log(QTextEdit): - """Class to implement a log for important user info. Should be thread-safe.""" - - def __init__(self, parent): - """Creates a log with a lock for multithreading - - Args: - parent (QWidget): parent widget to add Log instance to. - """ - super().__init__(parent) - - # from qtpy.QtCore import QMetaType - # parent.qRegisterMetaType("QTextCursor") - - self.lock = threading.Lock() - - # def receive_log(self, text): - # self.print_and_log(text) - def write(self, message): - self.lock.acquire() - try: - if not hasattr(self, "flag"): - self.flag = False - message = message.replace("\r", "").rstrip() - if message: - method = "replace_last_line" if self.flag else "append" - QtCore.QMetaObject.invokeMethod( - self, - method, - QtCore.Qt.QueuedConnection, - QtCore.Q_ARG(str, message), - ) - self.flag = True - else: - self.flag = False - - finally: - self.lock.release() - - @QtCore.Slot(str) - def replace_last_line(self, text): - self.lock.acquire() - try: - cursor = self.textCursor() - cursor.movePosition(QTextCursor.End) - cursor.select(QTextCursor.BlockUnderCursor) - cursor.removeSelectedText() - cursor.insertBlock() - self.setTextCursor(cursor) - self.insertPlainText(text) - finally: - self.lock.release() - - def print_and_log(self, text, printing=True): - """Utility used to both print to terminal and log text to a QTextEdit - item in a thread-safe manner. Use only for important user info. - - Args: - text (str): Text to be printed and logged - printing (bool): Whether to print the message as well or not using print(). Defaults to True. - - """ - self.lock.acquire() - try: - if printing: - print(text) - # causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows) - self.moveCursor(QTextCursor.End) - self.insertPlainText(f"\n{text}") - self.verticalScrollBar().setValue( - self.verticalScrollBar().maximum() - ) - finally: - self.lock.release() - - def warn(self, warning): - self.lock.acquire() - try: - warnings.warn(warning) - finally: - self.lock.release() diff --git a/napari_cellseg3d/napari.yaml b/napari_cellseg3d/napari.yaml index 10af0cb2..82058b9e 100644 --- a/napari_cellseg3d/napari.yaml +++ b/napari_cellseg3d/napari.yaml @@ -17,11 +17,11 @@ contributions: python_name: napari_cellseg3d.plugins:Utilities - id: napari-cellseg3d.infer - title: Create Inferer + title: Create Inference widget python_name: napari_cellseg3d.plugins:Inferer - id: napari-cellseg3d.train - title: Create Trainer + title: Create Trainer widget python_name: napari_cellseg3d.plugins:Trainer diff --git a/napari_cellseg3d/plugin_base.py b/napari_cellseg3d/plugin_base.py deleted file mode 100644 index 8c5bd8e8..00000000 --- a/napari_cellseg3d/plugin_base.py +++ /dev/null @@ -1,316 +0,0 @@ -import glob -import os - -import napari -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QTabWidget - -from napari_cellseg3d import interface as ui - - -class BasePluginSingleImage(QTabWidget): - """A basic plugin template for working with **single images**""" - - def __init__(self, viewer: "napari.viewer.Viewer", parent=None): - """Creates a Base plugin with several buttons pre-defined but not added to a layout : - - * Open file prompt to select images directory - - * Open file prompt to select labels directory - - * A checkbox to choose whether to load a folder of images or a single 3D file - If toggled, shows a filetype option to select the extension - - * A close button that closes the widget - - * A dropdown menu with a choice of png or tif filetypes - """ - - super().__init__() - - self.parent = parent - """Parent widget""" - self._viewer = viewer - """napari.viewer.Viewer: viewer in which the widget is displayed""" - - self.docked_widgets = [] - - self.image_path = "" - """str: path to image folder""" - - self.label_path = "" - """str: path to label folder""" - - self.results_path = "" - """str: path to results folder""" - - self.filetype = "" - """str: filetype, .tif or .png""" - self.as_folder = False - """bool: Whether to load a single file or a folder as a stack""" - - self._default_path = [self.image_path, self.label_path] - - self.image_filewidget = ui.FilePathWidget( - "Image path", self.show_dialog_images, self - ) - self.btn_image = self.image_filewidget.get_button() - """Button to load image folder""" - self.lbl_image = self.image_filewidget.get_text_field() - - self.label_filewidget = ui.FilePathWidget( - "Label path", self.show_dialog_labels, self - ) - self.lbl_label = self.label_filewidget.get_text_field() - self.btn_label = self.label_filewidget.get_button() - """Button to load label folder""" - - self.filetype_choice = ui.DropdownMenu([".png", ".tif"]) - - self.file_handling_box = ui.make_checkbox( - "Load as folder ?", self.show_filetype_choice - ) - """Checkbox to choose single file or directory loader handling""" - - self.file_handling_box.setSizePolicy( - QSizePolicy.Fixed, QSizePolicy.Fixed - ) - - self.btn_close = ui.Button("Close", self.remove_from_viewer, self) - - # self.lbl_ft = QLabel("Filetype :", self) - # self.lbl_ft2 = QLabel("(Folders of .png or single .tif files)", self) - - def build(self): - """Method to be defined by children classes""" - raise NotImplementedError - - def show_filetype_choice(self): - """Method to show/hide the filetype choice when "loading as folder" is (de)selected""" - show = self.file_handling_box.isChecked() - if show is not None: - self.filetype_choice.setVisible(show) - # self.lbl_ft.setVisible(show) - - def show_file_dialog(self): - """Open file dialog and process path depending on single file/folder loading behaviour""" - f_or_dir_name = ui.open_file_dialog( - self, self._default_path, self.file_handling_box.isChecked() - ) - if not self.file_handling_box.isChecked(): - f_or_dir_name = str(f_or_dir_name[0]) - self.filetype = os.path.splitext(f_or_dir_name)[1] - - print(f_or_dir_name) - - return f_or_dir_name - - def show_dialog_images(self): - """Show file dialog and set image path""" - f_name = self.show_file_dialog() - if type(f_name) is str and f_name != "": - self.image_path = f_name - self.lbl_image.setText(self.image_path) - self.update_default() - - def show_dialog_labels(self): - """Show file dialog and set label path""" - f_name = self.show_file_dialog() - if type(f_name) is str and f_name != "": - self.label_path = f_name - self.lbl_label.setText(self.label_path) - self.update_default() - - def load_results_path(self): - """Show file dialog to set :py:attr:`~results_path`""" - dir = ui.open_file_dialog(self, self._default_path, True) - if dir != "" and type(dir) is str and os.path.isdir(dir): - self.results_path = dir - self.lbl_result_path.setText(self.results_path) - self.update_default() - - def update_default(self): - """Updates default path for smoother navigation when opening file dialogs""" - self._default_path = [self.image_path, self.label_path] - - def remove_from_viewer(self): - """Removes the widget from the napari window. - Can be re-implemented in children classes if needed""" - - self.remove_docked_widgets() - - if self.parent is not None: - self.parent.remove_from_viewer() # TODO keep this way ? - return - self._viewer.window.remove_dock_widget(self) - - def remove_docked_widgets(self): - """Removes all docked widgets from napari window""" - try: - if len(self.docked_widgets) != 0: - [ - self._viewer.window.remove_dock_widget(w) - for w in self.docked_widgets - if w is not None - ] - return True - except LookupError: - return False - - -class BasePluginFolder(QTabWidget): - """A basic plugin template for working with **folders of images**""" - - def __init__(self, viewer: "napari.viewer.Viewer", parent=None): - """Creates a plugin template with the following widgets defined but not added in a layout : - - * A button to load a folder of images - - * A button to load a folder of labels - - * A button to set a results folder - - * A dropdown menu to select the file extension to be loaded from the folders""" - super().__init__() - self.parent = parent - self._viewer = viewer - """Viewer to display the widget in""" - - self.images_filepaths = [""] - """array(str): paths to images for training or inference""" - self.labels_filepaths = [""] - """array(str): paths to labels for training""" - self.results_path = "" - """str: path to output folder,to save results in""" - - self._default_path = [ - self.images_filepaths, - self.labels_filepaths, - self.results_path, - ] - - self.docked_widgets = [] - """List of docked widgets (returned by :py:func:`viewer.window.add_dock_widget())`, - can be used to remove docked widgets""" - - ####################################################### - # interface - self.image_filewidget = ui.FilePathWidget( - "Images directory", self.load_image_dataset, self - ) - self.btn_image_files = self.image_filewidget.get_button() - self.lbl_image_files = self.image_filewidget.get_text_field() - - self.label_filewidget = ui.FilePathWidget( - "Labels directory", self.load_label_dataset, self - ) - self.btn_label_files = self.label_filewidget.get_button() - self.lbl_label_files = self.label_filewidget.get_text_field() - - self.filetype_choice = ui.DropdownMenu( - [".tif", ".tiff"], label="File format" - ) - self.lbl_filetype = self.filetype_choice.label - """Allows to choose which file will be loaded from folder""" - - self.results_filewidget = ui.FilePathWidget( - "Results directory", self.load_results_path, self - ) - self.btn_result_path = self.results_filewidget.get_button() - self.lbl_result_path = self.results_filewidget.get_text_field() - ####################################################### - - def make_close_button(self): - btn = ui.Button("Close", self.remove_from_viewer) - btn.setToolTip( - "Close the window and all docked widgets. Make sure to save your work !" - ) - return btn - - def make_prev_button(self): - btn = ui.Button( - "Previous", lambda: self.setCurrentIndex(self.currentIndex() - 1) - ) - return btn - - def make_next_button(self): - btn = ui.Button( - "Next", lambda: self.setCurrentIndex(self.currentIndex() + 1) - ) - return btn - - def load_dataset_paths(self): - """Loads all image paths (as str) in a given folder for which the extension matches the set filetype - - Returns: - array(str): all loaded file paths - """ - filetype = self.filetype_choice.currentText() - directory = ui.open_file_dialog(self, self._default_path, True) - # print(directory) - file_paths = sorted(glob.glob(os.path.join(directory, "*" + filetype))) - # print(file_paths) - return file_paths - - def load_image_dataset(self): - """Show file dialog to set :py:attr:`~images_filepaths`""" - filenames = self.load_dataset_paths() - # print(filenames) - if filenames != "" and filenames != [""] and filenames != []: - self.images_filepaths = sorted(filenames) - # print(filenames) - path = os.path.dirname(filenames[0]) - self.lbl_image_files.setText(path) - # print(path) - self._default_path[0] = path - - def load_label_dataset(self): - """Show file dialog to set :py:attr:`~labels_filepaths`""" - filenames = self.load_dataset_paths() - if filenames != "" and filenames != [""] and filenames != []: - self.labels_filepaths = sorted(filenames) - path = os.path.dirname(filenames[0]) - self.lbl_label_files.setText(path) - self.update_default() - - def load_results_path(self): - """Show file dialog to set :py:attr:`~results_path`""" - dir = ui.open_file_dialog(self, self._default_path, True) - if dir != "" and type(dir) is str and os.path.isdir(dir): - self.results_path = dir - self.lbl_result_path.setText(self.results_path) - self.update_default() - - def build(self): - raise NotImplementedError("Should be defined in children classes") - - def update_default(self): - """Update default path for smoother file dialogs""" - self._default_path = [ - path - for path in [ - os.path.dirname(self.images_filepaths[0]), - os.path.dirname(self.labels_filepaths[0]), - self.results_path, - ] - if (path != [""] and path != "") - ] - - def remove_docked_widgets(self): - """Removes docked widgets and resets checks for status report""" - if len(self.docked_widgets) != 0: - [ - self._viewer.window.remove_dock_widget(w) - for w in self.docked_widgets - if w is not None - ] - self.docked_widgets = [] - self.container_docked = False - - def remove_from_viewer(self): - """Close the widget and the docked widgets, if any""" - self.remove_docked_widgets() - if self.parent is not None: - self.parent.remove_from_viewer() - return - self._viewer.window.remove_dock_widget(self) diff --git a/napari_cellseg3d/plugin_convert.py b/napari_cellseg3d/plugin_convert.py deleted file mode 100644 index e61f3f4a..00000000 --- a/napari_cellseg3d/plugin_convert.py +++ /dev/null @@ -1,447 +0,0 @@ -import os - -import napari -import numpy as np -from tifffile import imwrite, imread - -import napari_cellseg3d.interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.model_instance_seg import clear_small_objects -from napari_cellseg3d.model_instance_seg import to_instance -from napari_cellseg3d.model_instance_seg import to_semantic -from napari_cellseg3d.plugin_base import BasePluginFolder - - -class ConvertUtils(BasePluginFolder): - """Utility widget that allows to convert labels from instance to semantic and the reverse.""" - - def __init__(self, viewer: "napari.viewer.Viewer", parent): - """Builds a ConvertUtils widget with the following buttons: - - * A button to convert a folder of labels to semantic labels - - * A button to convert a folder of labels to instance labels - - * A button to convert a currently selected layer to semantic labels - - * A button to convert a currently selected layer to instance labels - """ - - super().__init__(viewer, parent) - - self._viewer = viewer - - ######################## - # interface - - # label conversion - self.btn_convert_folder_semantic = ui.Button( - "Convert to semantic labels", func=self.folder_to_semantic - ) - self.btn_convert_layer_semantic = ui.Button( - "Convert to semantic labels", func=self.layer_to_semantic - ) - self.btn_convert_folder_instance = ui.Button( - "Convert to instance labels", func=self.folder_to_instance - ) - self.btn_convert_layer_instance = ui.Button( - "Convert to instance labels", func=self.layer_to_instance - ) - # remove small - self.btn_remove_small_folder = ui.Button( - "Remove small in folder", func=self.folder_remove_small - ) - self.btn_remove_small_layer = ui.Button( - "Remove small in layer", func=self.layer_remove_small - ) - self.small_object_thresh_choice = ui.IntIncrementCounter( - min=1, max=1000, default=15 - ) - - # convert anisotropy - self.anisotropy_converter = ui.AnisotropyWidgets( - parent=self, always_visible=True - ) - self.btn_aniso_folder = ui.Button( - "Correct anisotropy in folder", self.folder_anisotropy, self - ) - self.btn_aniso_layer = ui.Button( - "Correct anisotropy in layer", self.layer_anisotropy, self - ) - - self.lbl_error = ui.make_label("", self) - self.lbl_error.setVisible(False) - - self.btn_image_files.setVisible(False) - self.lbl_image_files.setVisible(False) - - # self.results_filewidget.set_required(True) - self.label_filewidget.set_required(False) - # TODO improve not ready check for labels since optional until using folder conversion - ############################### - # tooltips - self.btn_convert_folder_semantic.setToolTip( - "Convert specified folder to semantic (0/1)" - ) - self.btn_convert_folder_instance.setToolTip( - "Convert specified folder to instance (unique ID per object)" - ) - self.btn_convert_layer_instance.setToolTip( - "Convert currently selected layer to instance (unique ID per object)" - ) - self.btn_convert_layer_semantic.setToolTip( - "Convert currently selected layer to semantic (0/1)" - ) - - self.btn_remove_small_layer.setToolTip( - "Remove small objects on selected layer image" - ) - self.btn_remove_small_folder.setToolTip( - "Remove small objects in all images of selected folder" - ) - self.small_object_thresh_choice.setToolTip( - "All objects in the image smaller in volume than this number of pixels will be removed" - ) - self.btn_aniso_layer.setToolTip( - "Resize the selected layer to be isotropic, based on the chosen resolutions above." - "\nDOES NOT WORK WITH INSTANCE LABELS, CONVERT TO SEMANTIC FIRST" - ) - self.btn_aniso_folder.setToolTip( - "Resize the images in the selected folder to be isotropic, based on the chosen resolutions above." - "\nDOES NOT WORK WITH INSTANCE LABELS, CONVERT TO SEMANTIC FIRST" - ) - ############################### - - self.build() - - def build(self): - """Builds the layout of the widget with the following buttons : - - * Set path to results - - * Set path to labels - - * A button to convert a folder of labels to semantic labels - - * A button to convert a folder of labels to instance labels - - * A button to convert a currently selected layer to semantic labels - - * A button to convert a currently selected layer to instance labels - """ - - l, t, r, b = 7, 20, 7, 11 - - w, layout = ui.make_container() - - results_widget = ui.combine_blocks( - right_or_below=self.btn_result_path, - left_or_above=self.lbl_result_path, - min_spacing=70, - ) - - ui.add_to_group( - "Results", - results_widget, - layout, - L=3, - T=11, - R=3, - B=3, - ) - ############################### - ui.add_blank(layout=layout, widget=self) - ############################### - aniso_group_w, aniso_group_l = ui.make_group( - "Correct anisotropy", l, t, r, b, parent=None - ) - - ui.add_widgets( - aniso_group_l, - [ - self.anisotropy_converter, - ], - ui.LEFT_AL, - ) - - aniso_group_w.setLayout(aniso_group_l) - layout.addWidget(aniso_group_w) - - ############################### - ui.add_blank(layout=layout, widget=self) - ############################################################# - small_group_w, small_group_l = ui.make_group( - "Remove small objects", l, t, r, b, parent=None - ) - - ui.add_widgets( - small_group_l, - [ - self.small_object_thresh_choice, - ], - ui.HCENTER_AL, - ) - - small_group_w.setLayout(small_group_l) - layout.addWidget(small_group_w) - ######################################### - ui.add_blank(layout=layout, widget=self) - ############################################################# - layer_group_w, layer_group_l = ui.make_group( - "Convert selected layer", l, t, r, b, parent=None - ) - - ui.add_widgets( - layer_group_l, - [ - self.btn_convert_layer_instance, - self.btn_convert_layer_semantic, - self.btn_remove_small_layer, - self.btn_aniso_layer, - ], - ui.HCENTER_AL, - ) - - layer_group_w.setLayout(layer_group_l) - layout.addWidget(layer_group_w) - ############################### - ui.add_blank(layout=layout, widget=self) - ############################### - folder_group_w, folder_group_l = ui.make_group( - "Convert folder", l, t, r, b, parent=None - ) - - folder_group_l.addWidget( - ui.combine_blocks( - right_or_below=self.btn_label_files, - left_or_above=self.lbl_label_files, - min_spacing=70, - ) - ) - - ui.add_widgets( - folder_group_l, - [ - self.btn_convert_folder_instance, - self.btn_convert_folder_semantic, - self.btn_remove_small_folder, - self.btn_aniso_folder, - ], - ui.HCENTER_AL, - ) - - folder_group_w.setLayout(folder_group_l) - layout.addWidget(folder_group_w) - ############################### - ui.add_blank(layout=layout, widget=self) - - ui.add_widgets( - layout, - [ - ui.add_blank(self), - self.make_close_button(), - ui.add_blank(self), - self.lbl_error, - ], - ) - - ui.ScrollArea.make_scrollable( - layout, self, min_wh=[230, 400], base_wh=[230, 450] - ) - - def folder_to_semantic(self): - """Converts folder of labels to semantic labels""" - if not self.check_ready_folder(): - return - - folder_name = f"converted_to_semantic_labels_{utils.get_date_time()}" - - images = [ - to_semantic(file, is_file_path=True) - for file in self.labels_filepaths - ] - - self.save_folder(folder_name, images) - - def layer_to_semantic(self): - """Converts selected layer to semantic labels""" - if not self.check_ready_layer(): - return - - im = self._viewer.layers.selection.active.data - name = self._viewer.layers.selection.active.name - semantic_labels = to_semantic(im) - - self.save_layer( - f"{name}_semantic_{utils.get_time_filepath()}" - + self.filetype_choice.currentText(), - semantic_labels, - ) - - self._viewer.add_labels(semantic_labels, name=f"converted_semantic") - - def folder_to_instance(self): - """Converts the chosen folder to instance labels""" - if not self.check_ready_folder(): - return - - images = [ - to_instance(file, is_file_path=True) - for file in self.labels_filepaths - ] - - self.save_folder( - f"converted_to_instance_labels_{utils.get_date_time()}", images - ) - - def layer_to_instance(self): - """Converts the selected layer to instance labels""" - if not self.check_ready_layer(): - return - - im = [self._viewer.layers.selection.active.data] - name = self._viewer.layers.selection.active.name - instance_labels = to_instance(im) - - self.save_layer( - f"{name}_instance_{utils.get_time_filepath()}" - + self.filetype_choice.currentText(), - instance_labels, - ) - - self._viewer.add_labels(instance_labels, name=f"converted_instance") - - def layer_remove_small(self): - """Removes small objects in selected layer""" - if not self.check_ready_layer(): - return - - im = self._viewer.layers.selection.active.data - name = self._viewer.layers.selection.active.name - - cleared_labels = clear_small_objects( - im, self.small_object_thresh_choice.value() - ) - - self.save_layer( - f"{name}_cleared_{utils.get_time_filepath()}" - + self.filetype_choice.currentText(), - cleared_labels, - ) - - self._viewer.add_image(cleared_labels, name=f"small_cleared") - - def folder_remove_small(self): - """Removes small objects in folder of labels""" - if not self.check_ready_folder(): - return - - images = [ - clear_small_objects( - file, - self.small_object_thresh_choice.value(), - is_file_path=True, - ) - for file in self.labels_filepaths - ] - - self.save_folder(f"small_cleared_{utils.get_date_time()}", images) - - def layer_anisotropy(self): - """Corrects anisotropy in the currently selected image""" - if not self.check_ready_layer(): - return - - name = self._viewer.layers.selection.active.name - zoom_factor = self.anisotropy_converter.get_anisotropy_resolution_zyx() - - vol = np.array( - self._viewer.layers.selection.active.data, dtype=np.int16 - ) - isotropic_image = utils.resize(vol, zoom_factor) - - self.save_layer( - f"{name}_isotropic_{utils.get_time_filepath()}" - + self.filetype_choice.currentText(), - isotropic_image, - ) - - self._viewer.add_image(isotropic_image, name=f"isotropic") - - def folder_anisotropy(self): - """Removes anisotropy in folder of images or labels""" - if not self.check_ready_folder(): - return - - zoom_factor = self.anisotropy_converter.get_anisotropy_resolution_zyx() - images = [ - utils.resize(imread(file), zoom_factor) - for file in self.labels_filepaths - ] - - self.save_folder(f"isotropic_{utils.get_date_time()}", images) - - def check_ready_folder(self): # TODO add color change - """Check if results and source folders are correctly set""" - if self.results_path == "": - err = "ERROR : please set results folder" - print(err) - self.lbl_error.setText(err) - self.lbl_error.setVisible(True) - return False - if self.labels_filepaths != [""]: - self.lbl_error.setVisible(False) - return True - - err = "ERROR : please set valid source labels folder" - print(err) - self.lbl_error.setText(err) - self.lbl_error.setVisible(True) - return False - - def check_ready_layer(self): # TODO add color change - """Check if results and layer are selected""" - if self.results_path == "": - err = "ERROR : please set results folder" - print(err) - self.lbl_error.setText(err) - self.lbl_error.setVisible(True) - return False - if self._viewer.layers.selection.active is None: - err = "ERROR : Please select a single layer" - print(err) - self.lbl_error.setText(err) - self.lbl_error.setVisible(True) - return False - self.lbl_error.setVisible(False) - return True - - def save_layer(self, file_name, image): - - path = os.path.join(self.results_path, file_name) - print(self.results_path) - print(path) - - if self.results_path != "": - imwrite( - path, - image, - ) - - def save_folder(self, folder_name, images): - - results_folder = os.path.join( - self.results_path, - folder_name, - ) - - os.makedirs(results_folder, exist_ok=False) - - for file, image in zip(self.labels_filepaths, images): - - path = os.path.join(results_folder, os.path.basename(file)) - - imwrite( - path, - image, - ) diff --git a/napari_cellseg3d/plugin_crop.py b/napari_cellseg3d/plugin_crop.py deleted file mode 100644 index d7ad0172..00000000 --- a/napari_cellseg3d/plugin_crop.py +++ /dev/null @@ -1,492 +0,0 @@ -import os -import warnings - -import napari -import numpy as np -from magicgui import magicgui -from magicgui.widgets import Container -from magicgui.widgets import Slider - -# Qt -from qtpy.QtWidgets import QSizePolicy -from tifffile import imwrite - -# local -from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.plugin_base import BasePluginSingleImage - -DEFAULT_CROP_SIZE = 64 - - -class Cropping(BasePluginSingleImage): - """A utility plugin for cropping 3D volumes.""" - - def __init__(self, viewer: "napari.viewer.Viewer", parent): - """Creates a Cropping plugin with several buttons : - - * Open file prompt to select volumes directory - - * Open file prompt to select labels directory - - * A dropdown menu with a choice of png or tif filetypes - - * Three spinboxes to choose the dimensions of the cropped volume in x, y, z - - * A button to launch the cropping process (see :doc:`plugin_crop`) - - * A button to close the widget - """ - - super().__init__(viewer, parent) - - self.btn_start = ui.Button("Start", self.start, self) - - self.crop_label_choice = ui.make_checkbox( - "Crop labels simultaneously", self.toggle_label_path - ) - self.lbl_label.setVisible(False) - self.btn_label.setVisible(False) - - self.box_widgets = ui.IntIncrementCounter.make_n( - 3, 1, 1000, DEFAULT_CROP_SIZE - ) - self.box_lbl = [ - ui.make_label("Size in " + axis + " of cropped volume :", self) - for axis in "xyz" - ] - - self.aniso_widgets = ui.AnisotropyWidgets(self) - ########### - for box in self.box_widgets: - box.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - self._x = 0 - self._y = 0 - self._z = 0 - self._crop_size_x = DEFAULT_CROP_SIZE - self._crop_size_y = DEFAULT_CROP_SIZE - self._crop_size_z = DEFAULT_CROP_SIZE - - self.aniso_factors = [1, 1, 1] - - self.image = None - self.image_layer = None - self.label = None - self.label_layer = None - - self.highres_crop_layer = None - self.labels_crop_layer = None - - self.crop_labels = False - - self.build() - - def toggle_label_path(self): - if self.crop_label_choice.isChecked(): - self.lbl_label.setVisible(True) - self.btn_label.setVisible(True) - else: - self.lbl_label.setVisible(False) - self.btn_label.setVisible(False) - - def build(self): - """Build buttons in a layout and add them to the napari Viewer""" - - w, layout = ui.make_container(0, 0, 1, 11) - - data_group_w, data_group_l = ui.make_group("Data") - - ui.add_widgets( - data_group_l, - [ - ui.combine_blocks(self.btn_image, self.lbl_image), - self.crop_label_choice, # whether to crop labels or no - ui.combine_blocks(self.btn_label, self.lbl_label), - self.file_handling_box, - self.filetype_choice, - self.aniso_widgets, - ], - ) - - self.crop_label_choice.toggle() - self.toggle_label_path() - - self.filetype_choice.setVisible(False) - - data_group_w.setLayout(data_group_l) - layout.addWidget(data_group_w) - ###################### - ui.add_blank(self, layout) - ###################### - dim_group_w, dim_group_l = ui.make_group("Dimensions") - [ - dim_group_l.addWidget(widget, alignment=ui.LEFT_AL) - for list in zip(self.box_lbl, self.box_widgets) - for widget in list - ] - dim_group_w.setLayout(dim_group_l) - layout.addWidget(dim_group_w) - ##################### - ##################### - ui.add_blank(self, layout) - ##################### - ##################### - ui.add_widgets( - layout, - [ - self.btn_start, - self.btn_close, - ], - ) - - ui.ScrollArea.make_scrollable(layout, self, min_wh=[180, 100]) - - def quicksave(self): - """Quicksaves the cropped volume in the folder from which they originate, with their original file extension. - - * If images are present, saves the cropped version as a single file or image stacks folder depending on what was loaded. - - * If labels are present, saves the cropped version as a single file or 2D stacks folder depending on what was loaded. - """ - - viewer = self._viewer - - time = utils.get_date_time() - if not self.as_folder: - if self.image is not None: - im_filename = os.path.basename(self.image_path).split(".")[0] - # print(im_filename) - im_dir = os.path.split(self.image_path)[0] + "/cropped" - # print(im_dir) - os.makedirs(im_dir, exist_ok=True) - viewer.layers["cropped"].save( - im_dir + "/" + im_filename + "_cropped_" + time + ".tif" - ) - - # print(self.label) - if self.label is not None: - im_filename = os.path.basename(self.label_path).split(".")[0] - # print(im_filename) - im_dir = os.path.split(self.label_path)[0] + "/cropped" - # print(im_dir) - name = ( - im_dir - + "/" - + im_filename - + "_labels_cropped_" - + time - + ".tif" - ) - dat = viewer.layers["cropped_labels"].data - os.makedirs(im_dir, exist_ok=True) - imwrite(name, data=dat) - - else: - if self.image is not None: - - # im_filename = os.path.basename(self.image_path).split(".")[0] - im_dir = os.path.split(self.image_path)[0] - - dat = viewer.layers["cropped"].data - dir_name = im_dir + "/volume_cropped_" + time - utils.save_stack(dat, dir_name, filetype=self.filetype) - - # print(self.label) - if self.label is not None: - - # im_filename = os.path.basename(self.image_path).split(".")[0] - im_dir = os.path.split(self.label_path)[0] - - dir_name = im_dir + "/labels_cropped_" + time - # print(f"dir name {dir_name}") - dat = viewer.layers["cropped_labels"].data - utils.save_stack(dat, dir_name, filetype=self.filetype) - - def check_ready(self): - - if self.image_path == "" or ( - self.crop_labels and self.label_path == "" - ): - warnings.warn("Please set all required paths correctly") - return False - return True - - def reset(self): - """Resets all layers and docked widgets""" - - self._viewer.layers.clear() - - self.remove_docked_widgets() - - def start(self): - """Launches cropping process by loading the files from the chosen folders, - and adds control widgets to the napari Viewer for moving the cropped volume. - """ - - self.as_folder = self.file_handling_box.isChecked() - self.filetype = self.filetype_choice.currentText() - self.crop_labels = self.crop_label_choice.isChecked() - - if self.aniso_widgets.is_enabled(): - self.aniso_factors = ( - self.aniso_widgets.get_anisotropy_resolution_zyx() - ) - - if not self.check_ready(): - return - - self.reset() - - self.image = utils.load_images( - self.image_path, self.filetype, self.as_folder - ) - - if len(self.image.shape) > 3: - self.image = np.squeeze(self.image) - - if self.crop_labels: - self.label = utils.load_images( - self.label_path, self.filetype, self.as_folder - ) - - if len(self.label.shape) > 3: - self.label = np.squeeze( - self.label - ) # if channel/batch remnants from MONAI - - vw = self._viewer - - vw.dims.ndisplay = 3 - vw.scale_bar.visible = True - - # add image and labels - self.image_layer = vw.add_image( - self.image, - colormap="inferno", - contrast_limits=[200, 1000], - opacity=0.7, - scale=self.aniso_factors, - ) - - if self.crop_labels: - self.label_layer = vw.add_labels( - self.label, scale=self.aniso_factors, visible=False - ) - - @magicgui(call_button="Quicksave") - def save_widget(): - return self.quicksave() - - save = self._viewer.window.add_dock_widget( - save_widget, name="", area="left" - ) - self.docked_widgets.append(save) - - self.add_crop_sliders() - - def add_crop_sliders( - self, - ): - # modified version of code posted by Juan Nunez Iglesias here : - # https://forum.image.sc/t/napari-viewing-3d-image-of-large-tif-stack-cropping-image-w-general-shape/55500/2 - vw = self._viewer - - image_stack = np.array(self.image) - - self._crop_size_x, self._crop_size_y, self._crop_size_z = [ - box.value() for box in self.box_widgets - ] - - self._x = 0 - self._y = 0 - self._z = 0 - - # print(f"Crop variables") - # print(image_stack.shape) - - # define crop sizes and boundaries for the image - crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z] - for i in range(len(crop_sizes)): - if crop_sizes[i] > image_stack.shape[i]: - crop_sizes[i] = image_stack.shape[i] - warnings.warn( - f"WARNING : Crop dimension in axis {i} was too large at {crop_sizes[i]}, it was set to {image_stack.shape[i]}" - ) - cropx, cropy, cropz = crop_sizes - # shapez, shapey, shapex = image_stack.shape - ends = np.asarray(image_stack.shape) - np.asarray(crop_sizes) + 1 - - stepsizes = ends // 100 - - # print(crop_sizes) - - # print(ends) - # print(stepsizes) - - self.highres_crop_layer = vw.add_image( - image_stack[:cropx, :cropy, :cropz], - name="cropped", - blending="additive", - colormap="twilight_shifted", - scale=self.image_layer.scale, - ) - - if self.crop_labels: - label_stack = self.label - self.labels_crop_layer = vw.add_labels( - self.label[:cropx, :cropy, :cropz], - name="cropped_labels", - scale=self.label_layer.scale, - ) - - def set_slice( - axis, - value, - highres_crop_layer, - labels_crop_layer=None, - crop_lbls=False, - ): - """ "Update cropped volume position""" - idx = int(value) - scale = np.asarray(highres_crop_layer.scale) - translate = np.asarray(highres_crop_layer.translate) - izyx = translate // scale - izyx[axis] = idx - izyx = [int(var) for var in izyx] - i, j, k = izyx - - cropx = self._crop_size_x - cropy = self._crop_size_y - cropz = self._crop_size_z - - highres_crop_layer.data = image_stack[ - i : i + cropx, j : j + cropy, k : k + cropz - ] - highres_crop_layer.translate = scale * izyx - highres_crop_layer.refresh() - - if crop_lbls and labels_crop_layer is not None: - labels_crop_layer.data = label_stack[ - i : i + cropx, j : j + cropy, k : k + cropz - ] - labels_crop_layer.translate = scale * izyx - labels_crop_layer.refresh() - - self._x = i - self._y = j - self._z = k - - # spinbox = SpinBox(name="crop_dims", min=1, value=self._crop_size, max=max(image_stack.shape), step=1) - # spinbox.changed.connect(lambda event : change_size(event)) - - sliders = [ - Slider(name=axis, min=0, max=end, step=step) - for axis, end, step in zip("zyx", ends, stepsizes) - ] - for axis, slider in enumerate(sliders): - slider.changed.connect( - lambda event, axis=axis: set_slice( - axis, - event, - self.highres_crop_layer, - self.labels_crop_layer, - self.crop_labels, - ) - ) - container_widget = Container(layout="vertical") - container_widget.extend(sliders) - # vw.window.add_dock_widget([spinbox, container_widget], area="right") - wdgts = vw.window.add_dock_widget(container_widget, area="right") - self.docked_widgets.append(wdgts) - # TEST : trying to dynamically change the size of the cropped volume - # BROKEN for now - # @spinbox.changed.connect - # def change_size(value: int): - # - # print(value) - # i = self._x - # j = self._y - # k = self._z - # - # self._crop_size = value - # - # cropx = value - # cropy = value - # cropz = value - # highres_crop_layer.data = image_stack[ - # i : i + cropz, j : j + cropy, k : k + cropx - # ] - # highres_crop_layer.refresh() - # labels_crop_layer.data = label_stack[ - # i : i + cropz, j : j + cropy, k : k + cropx - # ] - # labels_crop_layer.refresh() - # - - -################################# -################################# -################################# -# code for dynamically changing cropped volume with sliders, one for each dim -# WARNING : broken for now - -# def change_size(axis, value) : - -# print(value) -# print(axis) -# index = int(value) -# scale = np.asarray(highres_crop_layer.scale) -# translate = np.asarray(highres_crop_layer.translate) -# izyx = translate // scale -# izyx[axis] = index -# izyx = [int(el) for el in izyx] - -# cropz,cropy,cropx = izyx - -# i = self._x -# j = self._y -# k = self._z - -# self._crop_size_x = cropx -# self._crop_size_y = cropy -# self._crop_size_z = cropz - - -# highres_crop_layer.data = image_stack[ -# i : i + cropz, j : j + cropy, k : k + cropx -# ] -# highres_crop_layer.refresh() -# labels_crop_layer.data = label_stack[ -# i : i + cropz, j : j + cropy, k : k + cropx -# ] -# labels_crop_layer.refresh() - - -# # @spinbox.changed.connect -# # spinbox = SpinBox(name=crop_dims, min=1, max=max(image_stack.shape), step=1) -# # spinbox.changed.connect(lambda event : change_size(event)) - - -# sliders = [ -# Slider(name=axis, min=0, max=end, step=step) -# for axis, end, step in zip("zyx", ends, stepsizes) -# ] -# for axis, slider in enumerate(sliders): -# slider.changed.connect( -# lambda event, axis=axis: set_slice(axis, event) -# ) - -# spinboxes = [ -# SpinBox(name=axes+" crop size", min=1, value=self._crop_size_init, max=end, step=1) -# for axes, end in zip("zyx", image_stack.shape) -# ] -# for axes, box in enumerate(spinboxes): -# box.changed.connect( -# lambda event, axes=axes : change_size(axis, event) -# ) - - -# container_widget = Container(layout="vertical") -# container_widget.extend(sliders) -# container_widget.extend(spinboxes) -# vw.window.add_dock_widget(container_widget, area="right") diff --git a/napari_cellseg3d/plugin_review.py b/napari_cellseg3d/plugin_review.py deleted file mode 100644 index 19f4f3d9..00000000 --- a/napari_cellseg3d/plugin_review.py +++ /dev/null @@ -1,226 +0,0 @@ -import os -import warnings - -import napari -import numpy as np -import pims -import skimage.io as io - -# Qt -from qtpy.QtWidgets import QLineEdit -from qtpy.QtWidgets import QSizePolicy - -# local -from napari_cellseg3d import interface as ui -from napari_cellseg3d import utils -from napari_cellseg3d.launch_review import launch_review -from napari_cellseg3d.plugin_base import BasePluginSingleImage - -warnings.formatwarning = utils.format_Warning - - -class Reviewer(BasePluginSingleImage): - """A plugin for selecting volumes and labels file and launching the review process. - Inherits from : :doc:`plugin_base`""" - - def __init__(self, viewer: "napari.viewer.Viewer"): - """Creates a Reviewer plugin with several buttons : - - * Open file prompt to select volumes directory - - * Open file prompt to select labels directory - - * A dropdown menu with a choice of png or tif filetypes - - * A checkbox if you want to create a new status csv for the dataset - - * A button to launch the review process (see :doc:`launch_review`) - """ - - super().__init__(viewer) - - # self._viewer = viewer - - self.textbox = QLineEdit(self) - self.textbox.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - - self.checkBox = ui.make_checkbox("Create new dataset ?") - - self.btn_start = ui.Button("Start reviewing", self.run_review, self) - - self.lbl_mod = ui.make_label("Name", self) - - self.warn_label = ui.make_label( - "WARNING : You already have a review session running.\n" - "Launching another will close the current one,\n" - " make sure to save your work beforehand", - None, - ) - - self.anisotropy_widgets = ui.AnisotropyWidgets( - self, default_x=1.5, default_y=1.5, default_z=5 - ) - - ########################### - # tooltips - self.textbox.setToolTip("Name of the csv results file") - self.checkBox.setToolTip( - "Ignore any pre-existing csv with the specified name and create a new one" - ) - ########################### - - self.build() - - def build(self): - """Build buttons in a layout and add them to the napari Viewer""" - - self.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.MinimumExpanding) - - tab, layout = ui.make_container(0, 0, 1, 1) - - # ui.add_blank(self, layout) - ########################### - data_group_w, data_group_l = ui.make_group("Data") - - ui.add_widgets( - data_group_l, - [ - ui.combine_blocks( - self.filetype_choice, - self.file_handling_box, - horizontal=False, - ), - ui.combine_blocks(self.btn_image, self.lbl_image), - ui.combine_blocks(self.btn_label, self.lbl_label), - ], - ) - - self.filetype_choice.setVisible(False) - - data_group_w.setLayout(data_group_l) - layout.addWidget(data_group_w) - ########################### - ui.add_blank(self, layout) - ########################### - ui.add_to_group("Image parameters", self.anisotropy_widgets, layout) - ########################### - ui.add_blank(self, layout) - ########################### - csv_param_w, csv_param_l = ui.make_group("CSV parameters") - - ui.add_widgets( - csv_param_l, - [ - ui.combine_blocks( - self.textbox, - self.lbl_mod, - horizontal=False, - l=5, - t=0, - r=5, - b=5, - ), - self.checkBox, - ], - ) - - csv_param_w.setLayout(csv_param_l) - layout.addWidget(csv_param_w) - ########################### - ui.add_blank(self, layout) - ########################### - - ui.add_widgets(layout, [self.btn_start, self.btn_close]) - - ui.ScrollArea.make_scrollable( - contained_layout=layout, parent=tab, min_wh=[190, 300] - ) - - self.addTab(tab, "Review") - - self.setMinimumSize(180, 100) - # self.show() - # self._viewer.window.add_dock_widget(self, name="Reviewer", area="right") - - def run_review(self): - - """Launches review process by loading the files from the chosen folders, - and adds several widgets to the napari Viewer. - If the review process has been launched once before, - closes the window entirely and launches the review process in a fresh window. - - TODO: - - * Save work done before leaving - - See :doc:`launch_review` - - Returns: - napari.viewer.Viewer: self.viewer - """ - - self.reset() - - self.filetype = self.filetype_choice.currentText() - self.as_folder = self.file_handling_box.isChecked() - if self.anisotropy_widgets.is_enabled(): - zoom = self.anisotropy_widgets.get_anisotropy_resolution_zyx( - as_factors=True - ) - else: - zoom = [1, 1, 1] - - images = utils.load_images( - self.image_path, self.filetype, self.as_folder - ) - if ( - self.label_path == "" # TODO check if it works - ): # saves empty images of the same size as original images - if self.as_folder: - labels = np.zeros_like(images.compute()) # dask to numpy - self.label_path = os.path.join( - os.path.dirname(self.image_path), self.textbox.text() - ) - os.makedirs(self.label_path, exist_ok=True) - - for i in range(len(labels)): - io.imsave( - os.path.join( - self.label_path, str(i).zfill(4) + self.filetype - ), - labels[i], - ) - else: - labels = utils.load_saved_masks( - self.label_path, - self.filetype, - self.as_folder, - ) - try: - labels_raw = utils.load_raw_masks( - self.label_path + "_raw", self.filetype - ) - except pims.UnknownFormatError: - labels_raw = None - except FileNotFoundError: - # TODO : might not work, test with predi labels later - labels_raw = None - - print("New review session\n" + "*" * 20) - previous_viewer = self._viewer - self._viewer, self.docked_widgets = launch_review( - images, - labels, - labels_raw, - self.label_path, - self.textbox.text(), - self.checkBox.isChecked(), - self.filetype, - self.as_folder, - zoom, - ) - previous_viewer.close() - - def reset(self): - self._viewer.layers.clear() - self.remove_docked_widgets() diff --git a/napari_cellseg3d/plugin_utilities.py b/napari_cellseg3d/plugin_utilities.py deleted file mode 100644 index de4360a7..00000000 --- a/napari_cellseg3d/plugin_utilities.py +++ /dev/null @@ -1,41 +0,0 @@ -import napari - -# Qt -from qtpy.QtWidgets import QSizePolicy -from qtpy.QtWidgets import QTabWidget - -from napari_cellseg3d.plugin_convert import ConvertUtils - -# local -from napari_cellseg3d.plugin_crop import Cropping -from napari_cellseg3d.plugin_metrics import MetricsUtils - - -class Utilities(QTabWidget): - def __init__(self, viewer: "napari.viewer.Viewer"): - - super().__init__() - - self._viewer = viewer - - self.cropping_tab = Cropping(viewer, parent=self) - self.metrics_tab = MetricsUtils(viewer, parent=self) - self.convert_tab = ConvertUtils(viewer, parent=self) - - self.build() - - def build(self): - - self.addTab(self.convert_tab, "Convert") - self.addTab(self.metrics_tab, "Metrics") - self.addTab(self.cropping_tab, "Crop") - - self.cropping_tab.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - self.metrics_tab.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - self.convert_tab.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - - self.setBaseSize(230, 550) - self.setMinimumSize(230, 100) - - def remove_from_viewer(self): - self._viewer.window.remove_dock_widget(self) diff --git a/napari_cellseg3d/plugins.py b/napari_cellseg3d/plugins.py index 44f472f6..f0d74386 100644 --- a/napari_cellseg3d/plugins.py +++ b/napari_cellseg3d/plugins.py @@ -1,8 +1,8 @@ -from napari_cellseg3d.plugin_helper import Helper -from napari_cellseg3d.plugin_model_inference import Inferer -from napari_cellseg3d.plugin_model_training import Trainer -from napari_cellseg3d.plugin_review import Reviewer -from napari_cellseg3d.plugin_utilities import Utilities +from napari_cellseg3d.code_plugins.plugin_helper import Helper +from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer +from napari_cellseg3d.code_plugins.plugin_model_training import Trainer +from napari_cellseg3d.code_plugins.plugin_review import Reviewer +from napari_cellseg3d.code_plugins.plugin_utilities import Utilities def napari_experimental_provide_dock_widget(): diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 498c4830..4c3ef7d4 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -1,10 +1,8 @@ -import os +import logging import warnings from datetime import datetime from pathlib import Path - -import cv2 import numpy as np from dask_image.imread import imread as dask_imread from pandas import DataFrame @@ -12,15 +10,38 @@ from skimage import io from skimage.filters import gaussian from tifffile import imread as tfl_imread -from tqdm import tqdm + +LOGGER = logging.getLogger(__name__) +############### +# Global logging level setting +LOGGER.setLevel(logging.DEBUG) +# LOGGER.setLevel(logging.INFO) +############### """ utils.py ==================================== -Definitions of utility functions and variables +Definitions of utility functions, classes, and variables """ +class Singleton(type): + """ + Singleton class that can only be instantiated once at a time, + with said unique instance always being accessed on call. + Should be used as a metaclass for classes without inheritance (object type) + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__( + *args, **kwargs + ) + return cls._instances[cls] + + def normalize_x(image): """Normalizes the values of an image array to be between [-1;1] rather than [0;255] @@ -110,6 +131,7 @@ def resize(image, zoom_factors): isotropic_image = Zoom( zoom_factors, keep_size=False, + mode="nearest-exact", padding_mode="empty", )(np.expand_dims(image, axis=0)) return isotropic_image[0] @@ -260,9 +282,7 @@ def annotation_to_input(label_ermito): def check_csv(project_path, ext): - if not os.path.isfile( - os.path.join(project_path, os.path.basename(project_path) + ".csv") - ): + if not Path(Path(project_path) / Path(project_path).name).is_file(): cols = [ "project", "type", @@ -279,14 +299,14 @@ def check_csv(project_path, ext): "notes", ] df = DataFrame(index=[], columns=cols) - filename_pattern_original = os.path.join( - project_path, f"dataset/Original_size/Original/*{ext}" + filename_pattern_original = Path(project_path) / Path( + f"dataset/Original_size/Original/*{ext}" ) images_original = dask_imread(filename_pattern_original) z, y, x = images_original.shape record = Series( [ - os.path.basename(project_path), + Path(project_path).name, "dataset", ".tif", 0, @@ -297,28 +317,26 @@ def check_csv(project_path, ext): x, datetime.datetime.now(), "", - os.path.join(project_path, "dataset/Original_size/Original"), + Path(project_path) / Path("dataset/Original_size/Original"), "", ], index=df.columns, ) df = df.append(record, ignore_index=True) - df.to_csv( - os.path.join(project_path, os.path.basename(project_path) + ".csv") - ) + df.to_csv(Path(project_path) / Path(project_path).name) else: pass -def check_annotations_dir(project_path): - if not os.path.isdir( - os.path.join(project_path, "annotations/Original_size/master") - ): - os.makedirs( - os.path.join(project_path, "annotations/Original_size/master") - ) - else: - pass +# def check_annotations_dir(project_path): +# if not Path( +# Path(project_path) / Path("annotations/Original_size/master") +# ).is_dir(): +# os.makedirs( +# os.path.join(project_path, "annotations/Original_size/master") +# ) +# else: +# pass def fill_list_in_between(lst, n, elem): @@ -347,27 +365,27 @@ def fill_list_in_between(lst, n, elem): return new_list -def check_zarr(project_path, ext): - if not len( - list( - (Path(project_path) / "dataset" / "Original_size").glob("./*.zarr") - ) - ): - filename_pattern_original = os.path.join( - project_path, f"dataset/Original_size/Original/*{ext}" - ) - images_original = dask_imread(filename_pattern_original) - images_original.to_zarr( - os.path.join(project_path, f"dataset/Original_size/Original.zarr") - ) - else: - pass +# def check_zarr(project_path, ext): +# if not len( +# list( +# (Path(project_path) / "dataset" / "Original_size").glob("./*.zarr") +# ) +# ): +# filename_pattern_original = os.path.join( +# project_path, f"dataset/Original_size/Original/*{ext}" +# ) +# images_original = dask_imread(filename_pattern_original) +# images_original.to_zarr( +# os.path.join(project_path, f"dataset/Original_size/Original.zarr") +# ) +# else: +# pass -def check(project_path, ext): - check_csv(project_path, ext) - check_zarr(project_path, ext) - check_annotations_dir(project_path) +# def check(project_path, ext): +# check_csv(project_path, ext) +# check_zarr(project_path, ext) +# check_annotations_dir(project_path) def parse_default_path(possible_paths): @@ -379,19 +397,19 @@ def parse_default_path(possible_paths): Returns: the chosen default path """ - - # print("paths :") - # print(default_paths) - # print(default_path) - - default_paths = [ - p for p in possible_paths if (p != "" and p != [""] and len(p) >= 3) - ] + default_paths = [] + if any(path is not None for path in possible_paths): + default_paths = [ + p for p in possible_paths if p is not None and len(p) > 2 + ] + # default_paths = [ + # path for path in default_paths if path is not None and path != [] + # ] + print(default_paths) if len(default_paths) == 0: - default_path = os.path.expanduser("~") - else: - default_path = max(default_paths) - return default_path + return str(Path.home()) + default_path = max(default_paths, key=len) + return str(default_path) def get_date_time(): @@ -433,10 +451,10 @@ def load_images(dir_or_path, filetype="", as_folder: bool = False): """ if not as_folder: - filename_pattern_original = os.path.join(dir_or_path) + filename_pattern_original = Path(dir_or_path) # print(filename_pattern_original) elif as_folder and filetype != "": - filename_pattern_original = os.path.join(dir_or_path + "/*" + filetype) + filename_pattern_original = Path(dir_or_path + "/*" + filetype) # print(filename_pattern_original) else: raise ValueError("If loading as a folder, filetype must be specified") @@ -471,14 +489,6 @@ def load_saved_masks(mod_mask_dir, filetype, as_folder: bool): return base_label -def load_raw_masks(raw_mask_dir, filetype): - images_raw = load_images(raw_mask_dir, filetype) - # TODO : check that there is no problem with compute when loading as single file - images_raw = images_raw.compute() - base_label = np.where((126 < images_raw) & (images_raw < 171), 255, 0) - return base_label - - def save_stack(images, out_path, filetype=".png", check_warnings=False): """Saves the files in labels at location out_path as a stack of len(labels) .png files @@ -487,82 +497,17 @@ def save_stack(images, out_path, filetype=".png", check_warnings=False): out_path: path to the directory for saving """ num = images.shape[0] - os.makedirs(out_path, exist_ok=True) + p = Path(out_path) + p.mkdir(exist_ok=True) for i in range(num): label = images[i] io.imsave( - os.path.join(out_path, str(i).zfill(4) + filetype), + Path(out_path) / Path(str(i).zfill(4) + filetype), label, check_contrast=check_warnings, ) -def load_X_gray(folder_path): - image_files = [] - for file in os.listdir(folder_path): - base, ext = os.path.splitext(file) - if ext == ".png": - image_files.append(file) - else: - pass - - image_files.sort() - - img = cv2.imread( - folder_path + os.sep + image_files[0], cv2.IMREAD_GRAYSCALE - ) - - images = np.zeros( - (len(image_files), img.shape[0], img.shape[1], 1), np.float32 - ) - for i, image_file in tqdm(enumerate(image_files)): - image = cv2.imread( - folder_path + os.sep + image_file, cv2.IMREAD_GRAYSCALE - ) - image = image[:, :, np.newaxis] - images[i] = normalize_x(image) - - print(images.shape) - - return images, image_files - - -def load_Y_gray(folder_path, thresh=None, normalize=False): - image_files = [] - for file in os.listdir(folder_path): - base, ext = os.path.splitext(file) - if ext == ".png": - image_files.append(file) - else: - pass - - image_files.sort() - - img = cv2.imread( - folder_path + os.sep + image_files[0], cv2.IMREAD_GRAYSCALE - ) - - images = np.zeros( - (len(image_files), img.shape[0], img.shape[1], 1), np.float32 - ) - - for i, image_file in tqdm(enumerate(image_files)): - image = cv2.imread( - folder_path + os.sep + image_file, cv2.IMREAD_GRAYSCALE - ) - if thresh: - ret, image = cv2.threshold(image, thresh, 255, cv2.THRESH_BINARY) - image = image[:, :, np.newaxis] - if normalize: - images[i] = normalize_y(image) - else: - images[i] = image - - print(images.shape) - - return images, image_files - - def select_train_data(dataframe, ori_imgs, label_imgs, ori_filenames): train_img_names = list() for node in dataframe.itertuples(): diff --git a/setup.cfg b/setup.cfg index fff6a9f8..6261d576 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-cellseg3d -version = 0.0.1rc4 +version = 0.0.2rc1 author = Cyril Achard, Maxime Vidal, Jessy Lauer, Mackenzie Mathis author_email = cyril.achard@epfl.ch, maxime.vidal@epfl.ch, mackenzie@post.harvard.edu @@ -50,9 +50,9 @@ install_requires = tifffile>=2022.2.9 imageio-ffmpeg>=0.4.5 torch>=1.11 - monai[nibabel,scikit-image,itk,einops]>=0.9.0 + monai[nibabel,einops]>=0.9.0 + itk tqdm - monai>=0.9.0 nibabel scikit-image pillow diff --git a/tox.ini b/tox.ini index 5f8c37d3..139b04f5 100644 --- a/tox.ini +++ b/tox.ini @@ -9,20 +9,21 @@ python = [gh-actions:env] PLATFORM = -; ubuntu-latest: linux + ubuntu-latest: linux ; macos-latest: macos - windows-latest: windows +; windows-latest: windows [testenv] platform = + linux: linux ; macos: darwin -; linux: linux - windows: win32 +; windows: win32 passenv = CI PYTHONPATH GITHUB_ACTIONS - DISPLAY XAUTHORITY + DISPLAY + XAUTHORITY NUMPY_EXPERIMENTAL_ARRAY_FUNCTION PYVISTA_OFF_SCREEN deps =