Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions modules/decollate_batch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
"MONAI version: 0.4.0+546.gfbd6e65\n",
"MONAI version: 0.4.0+540.g8803885\n",
"Numpy version: 1.21.0\n",
"Pytorch version: 1.9.0+cu102\n",
"MONAI flags: HAS_EXT = False, USE_COMPILED = False\n",
"MONAI rev id: fbd6e6597a66b64821f6ef0b4da2560b103c00b9\n",
"MONAI rev id: 88038854f6cd256989695c2368e3ee9fca213e8d\n",
"\n",
"Optional dependencies:\n",
"Pytorch Ignite version: 0.4.5\n",
Expand Down Expand Up @@ -260,19 +260,22 @@
"metadata": {},
"source": [
"## Setup postprocessing transforms, metrics\n",
"Here we try to invert the preprocessing predictions for `pred` and save into Nifti files."
"Here we try to invert the preprocessing predictions for `pred` and save into Nifti files.\n",
"\n",
"As all the post processing transforms expect `Tensor` input, apply `ToTensord` first to ensure the data type after `decollate_batch`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"postprocessing = Compose(\n",
" [\n",
" ToTensord(keys=[\"pred\", \"seg\"]), # ensure Tensor type after `decollate`\n",
" Activationsd(keys=\"pred\", sigmoid=True),\n",
" Invertd(\n",
" keys=\"pred\", # invert the `pred` data field, also support multiple fields\n",
Expand Down Expand Up @@ -329,9 +332,10 @@
"\n",
" # decollate the batch data into list of dictionaries, every dictionary maps to an input data\n",
" data = [postprocessing(i) for i in decollate_batch(data)]\n",
" # compute metric for current iteration\n",
" # extract a list of `prections` and a list of `labels` with the `from_engine` utility\n",
" dice_metric(y_pred=from_engine(\"pred\")(data), y=from_engine(\"seg\")(data))\n",
" pred, y = from_engine([\"pred\", \"seg\"])(data)\n",
" # compute mean dice for current iteration\n",
" dice_metric(y_pred=pred, y=y)\n",
" # aggregate the final mean dice result\n",
" print(f\"evaluation metric: {dice_metric.aggregate().item()}\")\n",
" # reset the metric status\n",
Expand All @@ -349,7 +353,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"metadata": {
"pycharm": {
"is_executing": true
Expand Down Expand Up @@ -378,7 +382,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down
4 changes: 3 additions & 1 deletion modules/transfer_mmar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@
],
"source": [
"train_ds = LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)\n",
"# initialize cache and print meta information\n",
"print(train_ds.info())\n",
"\n",
"# use batch_size=2 to load images and use RandCropByPosNegLabeld\n",
Expand All @@ -433,6 +434,7 @@
"# the validation data loader will be created on the fly to ensure \n",
"# a deterministic validation set for demo purpose.\n",
"val_ds = LMDBDataset(data=val_files, transform=val_transforms, cache_dir=root_dir)\n",
"# initialize cache and print meta information\n",
"print(val_ds.info())"
]
},
Expand Down Expand Up @@ -825,7 +827,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down