Skip to content

Commit fe70c98

Browse files
authored
Added predictions for classification task. Fixed the check for normalizer presence. (#93)
1 parent a2de5b7 commit fe70c98

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

aviary/predict.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import numpy as np
88
import pandas as pd
99
import torch
10+
from torch.nn.functional import softmax
1011
from tqdm import tqdm
1112

12-
from aviary.core import Normalizer
13+
from aviary.core import Normalizer, sampled_softmax
1314
from aviary.utils import get_metrics, print_walltime
1415

1516
if TYPE_CHECKING:
@@ -103,19 +104,30 @@ def make_ensemble_predictions(
103104
pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
104105

105106
if model.robust:
106-
preds, aleat_log_std = preds.T
107-
ale_col = (
108-
f"{target_col}_aleatoric_std_{idx}"
109-
if target_col
110-
else f"aleatoric_std_{idx}"
111-
)
112-
df[pred_col] = preds
113-
df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
107+
if task_type == "regression":
108+
preds, aleat_log_std = preds.T
109+
ale_col = (
110+
f"{target_col}_aleatoric_std_{idx}"
111+
if target_col
112+
else f"aleatoric_std_{idx}"
113+
)
114+
df[pred_col] = preds
115+
df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
116+
elif task_type == "classification":
117+
# need to convert to tensor to use `sampled_softmax`
118+
preds = torch.from_numpy(preds).to(device)
119+
pre_logits, log_std = preds.chunk(2, dim=1)
120+
logits = sampled_softmax(pre_logits, log_std)
121+
df[pred_col] = logits.argmax(dim=1).cpu().numpy()
114122
else:
115-
df[pred_col] = preds
123+
if task_type == "regression":
124+
df[pred_col] = preds
125+
else:
126+
logits = softmax(preds, dim=1)
127+
df[pred_col] = logits.argmax(dim=1).cpu().numpy()
116128

117129
# denormalize predictions if a normalizer was used during training
118-
if "normalizer_dict" in checkpoint:
130+
if checkpoint["normalizer_dict"][target_name] is not None:
119131
assert task_type == "regression", "Normalization only takes place for regression."
120132
normalizer = Normalizer.from_state_dict(
121133
checkpoint["normalizer_dict"][target_name]

0 commit comments

Comments
 (0)