diff --git a/2d_classification/mednist_tutorial.ipynb b/2d_classification/mednist_tutorial.ipynb index cd7a1c9a67..34183277ca 100644 --- a/2d_classification/mednist_tutorial.ipynb +++ b/2d_classification/mednist_tutorial.ipynb @@ -319,24 +319,22 @@ "source": [ "val_frac = 0.1\n", "test_frac = 0.1\n", - "train_x = []\n", - "train_y = []\n", - "val_x = []\n", - "val_y = []\n", - "test_x = []\n", - "test_y = []\n", + "length = len(image_files_list)\n", + "indices = np.arange(length)\n", + "np.random.shuffle(indices)\n", "\n", - "for i in range(num_total):\n", - " rann = np.random.random()\n", - " if rann < val_frac:\n", - " val_x.append(image_files_list[i])\n", - " val_y.append(image_class[i])\n", - " elif rann < test_frac + val_frac:\n", - " test_x.append(image_files_list[i])\n", - " test_y.append(image_class[i])\n", - " else:\n", - " train_x.append(image_files_list[i])\n", - " train_y.append(image_class[i])\n", + "test_split = int(test_frac * length)\n", + "val_split = int(val_frac * length) + test_split\n", + "test_indices = indices[:test_split]\n", + "val_indices = indices[test_split:val_split]\n", + "train_indices = indices[val_split:]\n", + "\n", + "train_x = [image_files_list[i] for i in train_indices]\n", + "train_y = [image_class[i] for i in train_indices]\n", + "val_x = [image_files_list[i] for i in val_indices]\n", + "val_y = [image_class[i] for i in val_indices]\n", + "test_x = [image_files_list[i] for i in test_indices]\n", + "test_y = [image_class[i] for i in test_indices]\n", "\n", "print(\n", " f\"Training count: {len(train_x)}, Validation count: \"\n",