Merged. Changes from all commits.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ ci:

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.3
+    rev: v0.5.5
     hooks:
       - id: ruff
         args: [--fix]
4 changes: 1 addition & 3 deletions aviary/cgcnn/data.py
@@ -148,9 +148,7 @@ def __getitem__(self, idx: int):
         site_atoms = [atom.species.as_dict() for atom in struct]
         atom_features = np.vstack(
             [
-                np.sum(
-                    [self.elem_features[el] * amt for el, amt in site.items()], axis=0
-                )
+                np.sum([self.elem_features[el] * amt for el, amt in site.items()], axis=0)
                 for site in site_atoms
             ]
         )
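For context, each site's feature vector here is an occupancy-weighted sum of per-element features. A minimal sketch of the idea, assuming elem_features maps element symbols to 1D NumPy arrays (the values below are illustrative, not the repo's real embeddings):

import numpy as np

# Hypothetical element feature lookup (the real code loads these from a file).
elem_features = {"Fe": np.array([1.0, 0.0]), "O": np.array([0.0, 1.0])}

# A partially occupied site: 50% Fe, 50% O.
site = {"Fe": 0.5, "O": 0.5}

# Occupancy-weighted sum over the species on the site.
site_feature = np.sum([elem_features[el] * amt for el, amt in site.items()], axis=0)
print(site_feature)  # [0.5 0.5]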
4 changes: 1 addition & 3 deletions aviary/core.py
@@ -261,9 +261,7 @@ def evaluate(
         # *_ discards identifiers like material_id and formula which we don't need when
         # training. tqdm(disable=None) means suppress output in non-tty (e.g. CI/log
         # files) but keep in terminal (i.e. tty mode) https://git.io/JnBOi
-        for inputs, targets_list, *_ in tqdm(
-            data_loader, disable=None if pbar else True
-        ):
+        for inputs, targets_list, *_ in tqdm(data_loader, disable=None if pbar else True):
             inputs = [  # noqa: PLW2901
                 tensor.to(self.device) if hasattr(tensor, "to") else tensor
                 for tensor in inputs
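As the code comment notes, tqdm's disable=None auto-detects whether stdout is a tty, so progress bars show up interactively but stay out of CI logs. A minimal sketch of the pattern:

from tqdm import tqdm

pbar = True  # user-facing flag

# disable=None lets tqdm auto-suppress in non-tty contexts (CI, log files);
# disable=True turns the bar off unconditionally.
for _ in tqdm(range(3), disable=None if pbar else True):
    pass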
18 changes: 9 additions & 9 deletions aviary/networks.py
@@ -34,15 +34,15 @@ def __init__(
         dims = [input_dim, *list(hidden_layer_dims)]

         self.fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
+            nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1)
         )

         if batch_norm:
             self.bns = nn.ModuleList(
-                nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)
+                nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1)
             )
         else:
-            self.bns = nn.ModuleList(nn.Identity() for i in range(len(dims) - 1))
+            self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1))

         self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1))

@@ -95,21 +95,21 @@ def __init__(
         dims = [input_dim, *list(hidden_layer_dims)]

         self.fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
+            nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1)
         )

         if batch_norm:
             self.bns = nn.ModuleList(
-                nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)
+                nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1)
             )
         else:
-            self.bns = nn.ModuleList(nn.Identity() for i in range(len(dims) - 1))
+            self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1))

         self.res_fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1], bias=False)
-            if (dims[i] != dims[i + 1])
+            nn.Linear(dims[idx], dims[idx + 1], bias=False)
+            if (dims[idx] != dims[idx + 1])
             else nn.Identity()
-            for i in range(len(dims) - 1)
+            for idx in range(len(dims) - 1)
         )
         self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1))
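The second hunk is a residual MLP: each layer applies linear, batch-norm, then activation, and res_fcs supplies a skip connection, using a bias-free projection only when input and output widths differ. A minimal sketch of the forward pass this layer setup implies (my reading of the pattern, not the repo's exact code):

import torch
from torch import nn

class TinyResidualMLP(nn.Module):
    def __init__(self, dims: list[int]) -> None:
        super().__init__()
        n = len(dims) - 1
        self.fcs = nn.ModuleList(nn.Linear(dims[i], dims[i + 1]) for i in range(n))
        self.bns = nn.ModuleList(nn.BatchNorm1d(dims[i + 1]) for i in range(n))
        self.acts = nn.ModuleList(nn.ReLU() for _ in range(n))
        # Project the skip connection only when widths differ.
        self.res_fcs = nn.ModuleList(
            nn.Linear(dims[i], dims[i + 1], bias=False)
            if dims[i] != dims[i + 1]
            else nn.Identity()
            for i in range(n)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for fc, bn, act, res in zip(self.fcs, self.bns, self.acts, self.res_fcs):
            x = act(bn(fc(x))) + res(x)  # residual add after the nonlinearity
        return x

out = TinyResidualMLP([8, 16, 16])(torch.randn(4, 8))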
4 changes: 2 additions & 2 deletions aviary/predict.py
@@ -69,7 +69,7 @@ def make_ensemble_predictions(
     # (i.e. tty mode) https://git.io/JnBOi
     print(f"Pytorch running on {device=}")
     for idx, checkpoint_path in tqdm(
-        enumerate(tqdm(checkpoint_paths), 1), disable=None if pbar else True
+        enumerate(tqdm(checkpoint_paths), start=1), disable=None if pbar else True
     ):
         try:
             checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -189,7 +189,7 @@ def predict_from_wandb_checkpoints(

     checkpoint_paths: list[str] = []

-    for idx, run in enumerate(runs, 1):
+    for idx, run in enumerate(runs, start=1):
         run_path = "/".join(run.path)
         out_dir = f"{cache_dir}/{run_path}"
         os.makedirs(out_dir, exist_ok=True)
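For context, ensemble prediction in this style boils down to loading each checkpoint onto the target device and averaging the models' outputs. A minimal sketch under those assumptions; the model class, checkpoint key, and function name are illustrative, not the repo's API:

import torch

def average_ensemble(model_cls, checkpoint_paths, inputs, device="cpu"):
    """Load each checkpoint and average the models' predictions (illustrative)."""
    preds = []
    for path in checkpoint_paths:
        checkpoint = torch.load(path, map_location=device)
        model = model_cls()
        model.load_state_dict(checkpoint["state_dict"])  # assumed key name
        model.to(device).eval()
        with torch.no_grad():
            preds.append(model(inputs))
    return torch.stack(preds).mean(dim=0)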
4 changes: 2 additions & 2 deletions aviary/roost/data.py
@@ -176,7 +176,7 @@ def collate_batch(
     batch_cry_ids = []

     cry_base_idx = 0
-    for i, (inputs, target, *cry_ids) in enumerate(samples):
+    for idx, (inputs, target, *cry_ids) in enumerate(samples):
         elem_weights, elem_fea, self_idx, nbr_idx = inputs

         n_sites = elem_fea.shape[0]  # number of atoms for this crystal
@@ -190,7 +190,7 @@
         batch_nbr_idx.append(nbr_idx + cry_base_idx)

         # mapping from atoms to crystals
-        crystal_elem_idx.append(torch.tensor([i] * n_sites))
+        crystal_elem_idx.append(torch.tensor([idx] * n_sites))

         # batch the targets and ids
         batch_targets.append(target)
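The surrounding collate function batches variable-size crystal graphs by concatenating node features and shifting each crystal's neighbour indices by a running offset (cry_base_idx), so local indices stay valid in the combined tensor. A toy sketch of that indexing trick, with made-up graphs:

import torch

# Two toy "crystals" with 2 and 3 sites; nbr_idx uses local site numbers.
graphs = [
    {"fea": torch.randn(2, 4), "nbr_idx": torch.tensor([1, 0])},
    {"fea": torch.randn(3, 4), "nbr_idx": torch.tensor([1, 2, 0])},
]

batch_fea, batch_nbr_idx, crystal_ids = [], [], []
base = 0
for idx, g in enumerate(graphs):
    n = g["fea"].shape[0]
    batch_fea.append(g["fea"])
    batch_nbr_idx.append(g["nbr_idx"] + base)  # shift into global numbering
    crystal_ids.append(torch.full((n,), idx))  # map sites back to crystals
    base += n

fea = torch.cat(batch_fea)          # shape (5, 4)
nbr_idx = torch.cat(batch_nbr_idx)  # tensor([1, 0, 3, 4, 2])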
12 changes: 7 additions & 5 deletions aviary/train.py
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import os
+from datetime import datetime
 from typing import TYPE_CHECKING, Any, Literal

 import numpy as np
@@ -20,7 +21,7 @@
 try:
     import wandb
 except ImportError:
-    wandb = None
+    wandb = None  # type: ignore[assignment]

 if TYPE_CHECKING:
     from torch import nn
@@ -319,13 +320,14 @@ def train_model(
     if checkpoint is not None:
         checkpoint_model(
             checkpoint_endpoint=checkpoint,
-            model=inference_model,
+            model_params=model_params,
+            inference_model=inference_model,
             optimizer_instance=optimizer_instance,
             lr_scheduler=lr_scheduler,
             loss_dict=loss_dict,
-            epoch=epochs,
+            epochs=epochs,
             test_metrics=test_metrics,
-            timestamp=timestamp,
+            timestamp=timestamp or datetime.now().astimezone().strftime("%Y%m%d-%H%M%S"),
             run_name=run_name,
             normalizer_dict=normalizer_dict,
             run_params=run_params,
@@ -364,7 +366,7 @@ def train_model(

 def checkpoint_model(
     checkpoint_endpoint: str,
-    model_params: dict,
+    model_params: dict | None,
     inference_model: nn.Module,
     optimizer_instance: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
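For context, a checkpoint in this style typically bundles everything needed to resume or redeploy: weights, optimizer and scheduler state, epoch count, and metrics. A minimal sketch of a checkpoint_model-like helper; the schema and function name are assumptions, not the repo's exact signature:

from datetime import datetime

import torch

def save_checkpoint(path, model, optimizer, scheduler, epochs, metrics,
                    timestamp=None, **extra):
    """Bundle resumable training state into one file (illustrative schema)."""
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epochs": epochs,
        "metrics": metrics,
        # Fall back to a local-time stamp when none is supplied, mirroring
        # the `timestamp or datetime.now()...` pattern in the diff above.
        "timestamp": timestamp or datetime.now().astimezone().strftime("%Y%m%d-%H%M%S"),
        **extra,
    }
    torch.save(checkpoint, path)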
16 changes: 4 additions & 12 deletions aviary/utils.py
@@ -174,13 +174,9 @@ def initialize_optim(
             momentum=momentum,
         )
     elif optim == "Adam":
-        optimizer = Adam(
-            model.parameters(), lr=learning_rate, weight_decay=weight_decay
-        )
+        optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     elif optim == "AdamW":
-        optimizer = AdamW(
-            model.parameters(), lr=learning_rate, weight_decay=weight_decay
-        )
+        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     else:
         raise NameError("Only SGD, Adam or AdamW are allowed as --optim")

@@ -361,9 +357,7 @@ def train_ensemble(
             sample_target = Tensor(train_set.df[target].values)
             if not restart_params["resume"]:
                 normalizer.fit(sample_target)
-            print(
-                f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}"
-            )
+            print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")

         if log:
             writer = SummaryWriter(
@@ -528,9 +522,7 @@ def results_multitask(
         elif task_type == "classification":
             if model.robust:
                 mean, log_std = output.chunk(2, dim=1)
-                logits = (
-                    sampled_softmax(mean, log_std, samples=10).data.cpu().numpy()
-                )
+                logits = sampled_softmax(mean, log_std, samples=10).data.cpu().numpy()
                 pre_logits = mean.data.cpu().numpy()
                 pre_logits_std = torch.exp(log_std).data.cpu().numpy()
                 res_dict["pre-logits_ale"].append(pre_logits_std)  # type: ignore[union-attr]
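The robust-classification branch predicts a mean and a log standard deviation per logit, then marginalizes over that uncertainty. A plausible sketch of what a sampled_softmax like the one called above computes; this is an interpretation of the pattern, not the repo's exact implementation:

import torch
import torch.nn.functional as F

def sampled_softmax(mean: torch.Tensor, log_std: torch.Tensor, samples: int = 10):
    """Average softmax over logits sampled from N(mean, std) per class."""
    std = log_std.exp()
    # Draw `samples` Gaussian perturbations of the logits: (samples, batch, classes).
    eps = torch.randn(samples, *mean.shape)
    probs = F.softmax(mean + eps * std, dim=-1)
    return probs.mean(dim=0)  # marginalize over the sampled logits

mean = torch.zeros(2, 3)
log_std = torch.full((2, 3), -1.0)
print(sampled_softmax(mean, log_std))  # roughly uniform over 3 classes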
14 changes: 6 additions & 8 deletions aviary/wren/data.py
@@ -139,16 +139,16 @@ def __getitem__(self, idx: int):
         n_wyks = len(elements)
         self_idx = []
         nbr_idx = []
-        for i in range(n_wyks):
-            self_idx += [i] * n_wyks
+        for wyk_idx in range(n_wyks):
+            self_idx += [wyk_idx] * n_wyks
             nbr_idx += list(range(n_wyks))

         self_aug_fea_idx = []
         nbr_aug_fea_idx = []
         n_aug = len(augmented_wyks)
-        for i in range(n_aug):
-            self_aug_fea_idx += [x + i * n_wyks for x in self_idx]
-            nbr_aug_fea_idx += [x + i * n_wyks for x in nbr_idx]
+        for aug_idx in range(n_aug):
+            self_aug_fea_idx += [x + aug_idx * n_wyks for x in self_idx]
+            nbr_aug_fea_idx += [x + aug_idx * n_wyks for x in nbr_idx]

         # convert all data to tensors
         wyckoff_weights = Tensor(wyk_site_multiplcities)
@@ -291,9 +291,7 @@ def parse_protostructure_label(
         )

         # Separate out pairs of Wyckoff letters and their number of occurrences
-        sep_n_wyks = [
-            "".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)
-        ]
+        sep_n_wyks = ["".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)]

         # Process Wyckoff letters and multiplicities
         mults = map(int, sep_n_wyks[0::2])
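The first hunk builds edge indices for a fully connected graph over the n_wyks Wyckoff positions (every position paired with every position), then replicates those edges for each augmented copy by offsetting indices in blocks of n_wyks. A toy run showing the pattern:

n_wyks = 2  # two Wyckoff positions -> fully connected 2-node graph
self_idx, nbr_idx = [], []
for wyk_idx in range(n_wyks):
    self_idx += [wyk_idx] * n_wyks  # source repeated for each neighbour
    nbr_idx += list(range(n_wyks))  # all positions as neighbours

print(self_idx, nbr_idx)  # [0, 0, 1, 1] [0, 1, 0, 1]

# Replicate the edges for 2 augmented copies, shifted by n_wyks per copy.
n_aug = 2
self_aug = [x + a * n_wyks for a in range(n_aug) for x in self_idx]
print(self_aug)  # [0, 0, 1, 1, 2, 2, 3, 3]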
6 changes: 2 additions & 4 deletions aviary/wren/model.py
@@ -205,7 +205,7 @@ def __init__(
                 msg_gate_layers=elem_gate,
                 msg_net_layers=elem_msg,
             )
-            for i in range(n_graph)
+            for _ in range(n_graph)
         )

         # define a global pooling function for materials
@@ -259,9 +259,7 @@ def forward(
             for attnhead in self.cry_pool
         ]

-        return scatter_mean(
-            torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0
-        )
+        return scatter_mean(torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0)

     def __repr__(self) -> str:
         return (
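The forward pass first averages the attention heads, then uses scatter_mean to average the augmented (symmetry-equivalent) copies of each crystal back into a single embedding, grouped by aug_cry_idx. A plain-PyTorch equivalent of that scatter-mean step, for illustration:

import torch

# Three augmented embeddings belonging to two crystals (group index 0, 0, 1).
head_fea = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
aug_cry_idx = torch.tensor([0, 0, 1])

# Equivalent of torch_scatter's scatter_mean in plain PyTorch:
n_cry = int(aug_cry_idx.max()) + 1
sums = torch.zeros(n_cry, head_fea.shape[1]).index_add_(0, aug_cry_idx, head_fea)
counts = torch.bincount(aug_cry_idx, minlength=n_cry).unsqueeze(1)
print(sums / counts)  # tensor([[2., 3.], [5., 6.]])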
35 changes: 12 additions & 23 deletions aviary/wren/utils.py
@@ -361,13 +361,11 @@ def sort_and_score_element_wyckoffs(element_wyckoffs: str) -> tuple[str, int]:
         wp_counts = split_alpha_numeric(el_wyks)
         sorted_element_wyckoffs.append(
             "".join(
-                [
-                    f"{count}{wyckoff_letter}" if count != "1" else wyckoff_letter
-                    for count, wyckoff_letter in sorted(
-                        zip(wp_counts["numeric"], wp_counts["alpha"]),
-                        key=lambda x: x[1],
-                    )
-                ]
+                f"{count}{wyckoff_letter}" if count != "1" else wyckoff_letter
+                for count, wyckoff_letter in sorted(
+                    zip(wp_counts["numeric"], wp_counts["alpha"]),
+                    key=lambda x: x[1],
+                )
             )
         )
         score += sum(
@@ -391,19 +389,19 @@ def get_prototype_formula_from_composition(composition: Composition) -> str:
     """
     reduced = composition.element_composition
     if all(x == int(x) for x in composition.values()):
-        reduced /= gcd(*(int(i) for i in composition.values()))
+        reduced /= gcd(*(int(amt) for amt in composition.values()))

     amounts = [reduced[key] for key in sorted(reduced, key=str)]

     anon = ""
-    for e, amt in zip(ascii_uppercase, amounts):
+    for elem, amt in zip(ascii_uppercase, amounts):
         if amt == 1:
             amt_str = ""
         elif abs(amt % 1) < 1e-8:
             amt_str = str(int(amt))
         else:
             amt_str = str(amt)
-        anon += f"{e}{amt_str}"
+        anon += f"{elem}{amt_str}"
     return anon


@@ -415,13 +413,8 @@ def get_anonymous_formula_from_prototype_formula(prototype_formula: str) -> str:
     anom_list = split_alpha_numeric(prototype_formula)

     return "".join(
-        [
-            f"{el}{num}" if num != 1 else el
-            for el, num in zip(
-                anom_list["alpha"],
-                sorted(map(int, anom_list["numeric"])),
-            )
-        ]
+        f"{el}{num}" if num != 1 else el
+        for el, num in zip(anom_list["alpha"], sorted(map(int, anom_list["numeric"])))
    )


@@ -435,12 +428,8 @@ def get_formula_from_protostructure_label(protostructure_label: str) -> str:
     anom_list = split_alpha_numeric(prototype_formula)

     return "".join(
-        [
-            f"{el}{num}" if num != 1 else el
-            for el, num in zip(
-                chemsys.split("-"), map(int, anom_list["numeric"]), strict=True
-            )
-        ]
+        f"{el}{num}" if num != 1 else el
+        for el, num in zip(chemsys.split("-"), map(int, anom_list["numeric"]))
    )
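As a worked example of get_prototype_formula_from_composition: for Fe2O3 the integer amounts are reduced by their gcd (already 1 here), sorted by element symbol (Fe, O give amounts 2, 3), and mapped onto A, B, ..., yielding A2B3. A simplified sketch of that core logic, assuming a plain dict instead of pymatgen's Composition:

from math import gcd
from string import ascii_uppercase

def prototype_formula(composition: dict[str, int]) -> str:
    """Anonymize an integer composition, e.g. {"Fe": 2, "O": 3} -> "A2B3"."""
    divisor = gcd(*composition.values())
    amounts = [composition[el] // divisor for el in sorted(composition)]
    return "".join(
        f"{letter}{amt if amt != 1 else ''}"
        for letter, amt in zip(ascii_uppercase, amounts)
    )

print(prototype_formula({"Fe": 2, "O": 3}))   # A2B3
print(prototype_formula({"Na": 2, "Cl": 2}))  # AB (reduced by gcd=2)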
8 changes: 2 additions & 6 deletions aviary/wrenformer/data.py
@@ -66,9 +66,7 @@ def collate_batch(


 @cache
-def get_wyckoff_features(
-    equivalent_wyckoff_set: list[tuple], spg_num: int
-) -> np.ndarray:
+def get_wyckoff_features(equivalent_wyckoff_set: list[tuple], spg_num: int) -> np.ndarray:
     """Get Wyckoff set features from the precomputed dictionary. The output of this
     function is cached for speed.

@@ -204,6 +202,4 @@ def df_to_in_mem_dataloader(
         inputs[idx] = tensor.to(device)

     ids = df.get(id_col, df.index).to_numpy()
-    return InMemoryDataLoader(
-        [inputs, targets, ids], collate_fn=collate_batch, **kwargs
-    )
+    return InMemoryDataLoader([inputs, targets, ids], collate_fn=collate_batch, **kwargs)
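A side note on the @cache decorator seen here: functools.cache memoizes by hashing the call arguments, so the Wyckoff set must arrive as hashable tuples rather than lists. A minimal sketch of that constraint, with a stand-in function:

from functools import cache

@cache
def featurize(wyckoff_set: tuple, spg_num: int) -> str:
    # Stand-in for an expensive lookup; the body runs once per unique input.
    return f"{spg_num}:{len(wyckoff_set)}"

featurize((("4a",), ("8c",)), 225)   # computed
featurize((("4a",), ("8c",)), 225)   # served from cache
# featurize([["4a"], ["8c"]], 225)   # TypeError: unhashable type: 'list'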
4 changes: 2 additions & 2 deletions examples/wrenformer/mat_bench/make_plots.py
@@ -40,7 +40,7 @@

 # %% --- Load other's scores ---
 # load benchmark data for models with existing Matbench submission
-for idx, dirname in enumerate(glob(f"{bench_dir}/*"), 1):
+for idx, dirname in enumerate(glob(f"{bench_dir}/*"), start=1):
     model_name = dirname.split("/matbench_v0.1_")[-1]
     print(f"{idx}. {model_name}")
     mbbm = MatbenchBenchmark.from_file(f"{dirname}/results.json.gz")
@@ -62,7 +62,7 @@

 # %% --- Load our scores ---
 our_score_files = sorted(glob("model_scores/*.json"), key=lambda s: s.split("@")[0])

-for idx, filename in enumerate(our_score_files, 1):
+for idx, filename in enumerate(our_score_files, start=1):
     date, model_name = re.split(r"@\d\d-\d\d-", filename.split("/")[-1])

     print(f"{idx}. {date} {model_name}")
@@ -17,7 +17,7 @@

 benchmark = MatbenchBenchmark()

-for idx, task in enumerate(benchmark.tasks, 1):
+for idx, task in enumerate(benchmark.tasks, start=1):
     print(f"\n\n{idx}/{len(benchmark.tasks)}")
     task.load()
     df: pd.DataFrame = task.df
4 changes: 1 addition & 3 deletions examples/wrenformer/mat_bench/utils.py
@@ -54,7 +54,5 @@ def non_serializable_handler(obj: object) -> str:
         return f"<not serializable: {type(obj).__qualname__}>"

     with open(file_path, "w") as file:
-        default = (
-            non_serializable_handler if on_non_serializable == "annotate" else None
-        )
+        default = non_serializable_handler if on_non_serializable == "annotate" else None
         json.dump(dct, file, default=default, indent=2)
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -19,7 +19,7 @@ def df_matbench_phonons():
     """Returns the dataframe for the Matbench phonon DOS peak task."""

     df = load_dataset("matbench_phonons")
-    df["material_id"] = [f"mb_phdos_{i + 1}" for i in range(len(df))]
+    df["material_id"] = [f"mb_phdos_{idx + 1}" for idx in range(len(df))]
     df = df.set_index("material_id", drop=False)
     df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]

@@ -33,7 +33,7 @@ def df_matbench_jdft2d():
     """Returns Matbench experimental band gap task dataframe. Currently unused."""

     df = load_dataset("matbench_jdft2d")
-    df["material_id"] = [f"mb_jdft2d_{i + 1}" for i in range(len(df))]
+    df["material_id"] = [f"mb_jdft2d_{idx + 1}" for idx in range(len(df))]
     df = df.set_index("material_id", drop=False)
     df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]