|
7 | 7 | import numpy as np |
8 | 8 | import pandas as pd |
9 | 9 | import torch |
| 10 | +from torch.nn.functional import softmax |
10 | 11 | from tqdm import tqdm |
11 | 12 |
|
12 | | -from aviary.core import Normalizer |
| 13 | +from aviary.core import Normalizer, sampled_softmax |
13 | 14 | from aviary.utils import get_metrics, print_walltime |
14 | 15 |
|
15 | 16 | if TYPE_CHECKING: |
@@ -103,19 +104,30 @@ def make_ensemble_predictions( |
103 | 104 | pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}" |
104 | 105 |
|
105 | 106 | if model.robust: |
106 | | - preds, aleat_log_std = preds.T |
107 | | - ale_col = ( |
108 | | - f"{target_col}_aleatoric_std_{idx}" |
109 | | - if target_col |
110 | | - else f"aleatoric_std_{idx}" |
111 | | - ) |
112 | | - df[pred_col] = preds |
113 | | - df[ale_col] = aleatoric_std = np.exp(aleat_log_std) |
| 107 | + if task_type == "regression": |
| 108 | + preds, aleat_log_std = preds.T |
| 109 | + ale_col = ( |
| 110 | + f"{target_col}_aleatoric_std_{idx}" |
| 111 | + if target_col |
| 112 | + else f"aleatoric_std_{idx}" |
| 113 | + ) |
| 114 | + df[pred_col] = preds |
| 115 | + df[ale_col] = aleatoric_std = np.exp(aleat_log_std) |
| 116 | + elif task_type == "classification": |
| 117 | + # need to convert to tensor to use `sampled_softmax` |
| 118 | + preds = torch.from_numpy(preds).to(device) |
| 119 | + pre_logits, log_std = preds.chunk(2, dim=1) |
| 120 | + logits = sampled_softmax(pre_logits, log_std) |
| 121 | + df[pred_col] = logits.argmax(dim=1).cpu().numpy() |
114 | 122 | else: |
115 | | - df[pred_col] = preds |
| 123 | + if task_type == "regression": |
| 124 | + df[pred_col] = preds |
| 125 | + else: |
| 126 | + logits = softmax(preds, dim=1) |
| 127 | + df[pred_col] = logits.argmax(dim=1).cpu().numpy() |
116 | 128 |
|
117 | 129 | # denormalize predictions if a normalizer was used during training |
118 | | - if "normalizer_dict" in checkpoint: |
| 130 | + if checkpoint["normalizer_dict"][target_name] is not None: |
119 | 131 | assert task_type == "regression", "Normalization only takes place for regression." |
120 | 132 | normalizer = Normalizer.from_state_dict( |
121 | 133 | checkpoint["normalizer_dict"][target_name] |
|
0 commit comments