diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6cda0b14..668e10c2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,7 +5,7 @@ ci:

 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.3
+    rev: v0.5.5
     hooks:
       - id: ruff
         args: [--fix]
diff --git a/aviary/cgcnn/data.py b/aviary/cgcnn/data.py
index bb6cea09..993a1f12 100644
--- a/aviary/cgcnn/data.py
+++ b/aviary/cgcnn/data.py
@@ -148,9 +148,7 @@ def __getitem__(self, idx: int):
         site_atoms = [atom.species.as_dict() for atom in struct]
         atom_features = np.vstack(
             [
-                np.sum(
-                    [self.elem_features[el] * amt for el, amt in site.items()], axis=0
-                )
+                np.sum([self.elem_features[el] * amt for el, amt in site.items()], axis=0)
                 for site in site_atoms
             ]
         )
diff --git a/aviary/core.py b/aviary/core.py
index e919e7e2..f8a3ee7c 100644
--- a/aviary/core.py
+++ b/aviary/core.py
@@ -261,9 +261,7 @@ def evaluate(
         # *_ discards identifiers like material_id and formula which we don't need when
         # training tqdm(disable=None) means suppress output in non-tty (e.g. CI/log
         # files) but keep in terminal (i.e. tty mode) https://git.io/JnBOi
-        for inputs, targets_list, *_ in tqdm(
-            data_loader, disable=None if pbar else True
-        ):
+        for inputs, targets_list, *_ in tqdm(data_loader, disable=None if pbar else True):
             inputs = [  # noqa: PLW2901
                 tensor.to(self.device) if hasattr(tensor, "to") else tensor
                 for tensor in inputs
diff --git a/aviary/networks.py b/aviary/networks.py
index b10094d4..32dbcb62 100644
--- a/aviary/networks.py
+++ b/aviary/networks.py
@@ -34,15 +34,15 @@ def __init__(
         dims = [input_dim, *list(hidden_layer_dims)]

         self.fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
+            nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1)
         )

         if batch_norm:
             self.bns = nn.ModuleList(
-                nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)
+                nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1)
             )
         else:
-            self.bns = nn.ModuleList(nn.Identity() for i in range(len(dims) - 1))
+            self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1))

         self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1))

@@ -95,21 +95,21 @@ def __init__(
         dims = [input_dim, *list(hidden_layer_dims)]

         self.fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
+            nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1)
         )

         if batch_norm:
             self.bns = nn.ModuleList(
-                nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)
+                nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1)
             )
         else:
-            self.bns = nn.ModuleList(nn.Identity() for i in range(len(dims) - 1))
+            self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1))

         self.res_fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1], bias=False)
-            if (dims[i] != dims[i + 1])
+            nn.Linear(dims[idx], dims[idx + 1], bias=False)
+            if (dims[idx] != dims[idx + 1])
             else nn.Identity()
-            for i in range(len(dims) - 1)
+            for idx in range(len(dims) - 1)
         )

         self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1))
diff --git a/aviary/predict.py b/aviary/predict.py
index afff282b..a7b55efe 100644
--- a/aviary/predict.py
+++ b/aviary/predict.py
@@ -69,7 +69,7 @@ def make_ensemble_predictions(
     # (i.e. tty mode) https://git.io/JnBOi
     print(f"Pytorch running on {device=}")
     for idx, checkpoint_path in tqdm(
-        enumerate(tqdm(checkpoint_paths), 1), disable=None if pbar else True
+        enumerate(tqdm(checkpoint_paths), start=1), disable=None if pbar else True
     ):
         try:
             checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -189,7 +189,7 @@ def predict_from_wandb_checkpoints(

     checkpoint_paths: list[str] = []

-    for idx, run in enumerate(runs, 1):
+    for idx, run in enumerate(runs, start=1):
         run_path = "/".join(run.path)
         out_dir = f"{cache_dir}/{run_path}"
         os.makedirs(out_dir, exist_ok=True)
diff --git a/aviary/roost/data.py b/aviary/roost/data.py
index 07af42c7..00821668 100644
--- a/aviary/roost/data.py
+++ b/aviary/roost/data.py
@@ -176,7 +176,7 @@ def collate_batch(
     batch_cry_ids = []

     cry_base_idx = 0
-    for i, (inputs, target, *cry_ids) in enumerate(samples):
+    for idx, (inputs, target, *cry_ids) in enumerate(samples):
         elem_weights, elem_fea, self_idx, nbr_idx = inputs

         n_sites = elem_fea.shape[0]  # number of atoms for this crystal
@@ -190,7 +190,7 @@ def collate_batch(
         batch_nbr_idx.append(nbr_idx + cry_base_idx)

         # mapping from atoms to crystals
-        crystal_elem_idx.append(torch.tensor([i] * n_sites))
+        crystal_elem_idx.append(torch.tensor([idx] * n_sites))

         # batch the targets and ids
         batch_targets.append(target)
diff --git a/aviary/train.py b/aviary/train.py
index 503f6c7b..e16e6f6d 100644
--- a/aviary/train.py
+++ b/aviary/train.py
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import os
+from datetime import datetime
 from typing import TYPE_CHECKING, Any, Literal

 import numpy as np
@@ -20,7 +21,7 @@
 try:
     import wandb
 except ImportError:
-    wandb = None
+    wandb = None  # type: ignore[assignment]

 if TYPE_CHECKING:
     from torch import nn
@@ -319,13 +320,14 @@ def train_model(
     if checkpoint is not None:
         checkpoint_model(
             checkpoint_endpoint=checkpoint,
-            model=inference_model,
+            model_params=model_params,
+            inference_model=inference_model,
             optimizer_instance=optimizer_instance,
             lr_scheduler=lr_scheduler,
             loss_dict=loss_dict,
-            epoch=epochs,
+            epochs=epochs,
             test_metrics=test_metrics,
-            timestamp=timestamp,
+            timestamp=timestamp or datetime.now().astimezone().strftime("%Y%m%d-%H%M%S"),
             run_name=run_name,
             normalizer_dict=normalizer_dict,
             run_params=run_params,
@@ -364,7 +366,7 @@ def train_model(

 def checkpoint_model(
     checkpoint_endpoint: str,
-    model_params: dict,
+    model_params: dict | None,
     inference_model: nn.Module,
     optimizer_instance: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
diff --git a/aviary/utils.py b/aviary/utils.py
index 4c1e22e6..bbe35667 100644
--- a/aviary/utils.py
+++ b/aviary/utils.py
@@ -174,13 +174,9 @@ def initialize_optim(
             momentum=momentum,
         )
     elif optim == "Adam":
-        optimizer = Adam(
-            model.parameters(), lr=learning_rate, weight_decay=weight_decay
-        )
+        optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     elif optim == "AdamW":
-        optimizer = AdamW(
-            model.parameters(), lr=learning_rate, weight_decay=weight_decay
-        )
+        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     else:
         raise NameError("Only SGD, Adam or AdamW are allowed as --optim")

@@ -361,9 +357,7 @@ def train_ensemble(
             sample_target = Tensor(train_set.df[target].values)
             if not restart_params["resume"]:
                 normalizer.fit(sample_target)
-            print(
-                f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}"
-            )
+            print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")

         if log:
             writer = SummaryWriter(
@@ -528,9 +522,7 @@ def results_multitask(
             elif task_type == "classification":
                 if model.robust:
                     mean, log_std = output.chunk(2, dim=1)
-                    logits = (
-                        sampled_softmax(mean, log_std, samples=10).data.cpu().numpy()
-                    )
+                    logits = sampled_softmax(mean, log_std, samples=10).data.cpu().numpy()
                     pre_logits = mean.data.cpu().numpy()
                     pre_logits_std = torch.exp(log_std).data.cpu().numpy()
                     res_dict["pre-logits_ale"].append(pre_logits_std)  # type: ignore[union-attr]
diff --git a/aviary/wren/data.py b/aviary/wren/data.py
index 5d1b4b5f..fd5d7257 100644
--- a/aviary/wren/data.py
+++ b/aviary/wren/data.py
@@ -139,16 +139,16 @@ def __getitem__(self, idx: int):
         n_wyks = len(elements)
         self_idx = []
         nbr_idx = []
-        for i in range(n_wyks):
-            self_idx += [i] * n_wyks
+        for wyk_idx in range(n_wyks):
+            self_idx += [wyk_idx] * n_wyks
             nbr_idx += list(range(n_wyks))

         self_aug_fea_idx = []
         nbr_aug_fea_idx = []
         n_aug = len(augmented_wyks)
-        for i in range(n_aug):
-            self_aug_fea_idx += [x + i * n_wyks for x in self_idx]
-            nbr_aug_fea_idx += [x + i * n_wyks for x in nbr_idx]
+        for aug_idx in range(n_aug):
+            self_aug_fea_idx += [x + aug_idx * n_wyks for x in self_idx]
+            nbr_aug_fea_idx += [x + aug_idx * n_wyks for x in nbr_idx]

         # convert all data to tensors
         wyckoff_weights = Tensor(wyk_site_multiplcities)
@@ -291,9 +291,7 @@ def parse_protostructure_label(
     )

     # Separate out pairs of Wyckoff letters and their number of occurrences
-    sep_n_wyks = [
-        "".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)
-    ]
+    sep_n_wyks = ["".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)]

     # Process Wyckoff letters and multiplicities
     mults = map(int, sep_n_wyks[0::2])
diff --git a/aviary/wren/model.py b/aviary/wren/model.py
index 25ee00ea..35b0b3f6 100644
--- a/aviary/wren/model.py
+++ b/aviary/wren/model.py
@@ -205,7 +205,7 @@ def __init__(
                 msg_gate_layers=elem_gate,
                 msg_net_layers=elem_msg,
             )
-            for i in range(n_graph)
+            for _ in range(n_graph)
         )

         # define a global pooling function for materials
@@ -259,9 +259,7 @@ def forward(
             for attnhead in self.cry_pool
         ]

-        return scatter_mean(
-            torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0
-        )
+        return scatter_mean(torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0)

     def __repr__(self) -> str:
         return (
diff --git a/aviary/wren/utils.py b/aviary/wren/utils.py
index ed832cb9..0c81d272 100644
--- a/aviary/wren/utils.py
+++ b/aviary/wren/utils.py
@@ -361,13 +361,11 @@ def sort_and_score_element_wyckoffs(element_wyckoffs: str) -> tuple[str, int]:
         wp_counts = split_alpha_numeric(el_wyks)
         sorted_element_wyckoffs.append(
             "".join(
-                [
-                    f"{count}{wyckoff_letter}" if count != "1" else wyckoff_letter
-                    for count, wyckoff_letter in sorted(
-                        zip(wp_counts["numeric"], wp_counts["alpha"]),
-                        key=lambda x: x[1],
-                    )
-                ]
+                f"{count}{wyckoff_letter}" if count != "1" else wyckoff_letter
+                for count, wyckoff_letter in sorted(
+                    zip(wp_counts["numeric"], wp_counts["alpha"]),
+                    key=lambda x: x[1],
+                )
             )
         )
         score += sum(
@@ -391,19 +389,19 @@ def get_prototype_formula_from_composition(composition: Composition) -> str:
     """
     reduced = composition.element_composition
     if all(x == int(x) for x in composition.values()):
-        reduced /= gcd(*(int(i) for i in composition.values()))
+        reduced /= gcd(*(int(amt) for amt in composition.values()))

     amounts = [reduced[key] for key in sorted(reduced, key=str)]

     anon = ""
-    for e, amt in zip(ascii_uppercase, amounts):
+    for elem, amt in zip(ascii_uppercase, amounts):
         if amt == 1:
             amt_str = ""
         elif abs(amt % 1) < 1e-8:
             amt_str = str(int(amt))
         else:
             amt_str = str(amt)
-        anon += f"{e}{amt_str}"
+        anon += f"{elem}{amt_str}"

     return anon

@@ -415,13 +413,8 @@ def get_anonymous_formula_from_prototype_formula(prototype_formula: str) -> str:
     anom_list = split_alpha_numeric(prototype_formula)

     return "".join(
-        [
-            f"{el}{num}" if num != 1 else el
-            for el, num in zip(
-                anom_list["alpha"],
-                sorted(map(int, anom_list["numeric"])),
-            )
-        ]
+        f"{el}{num}" if num != 1 else el
+        for el, num in zip(anom_list["alpha"], sorted(map(int, anom_list["numeric"])))
     )


@@ -435,12 +428,8 @@ def get_formula_from_protostructure_label(protostructure_label: str) -> str:
     anom_list = split_alpha_numeric(prototype_formula)

     return "".join(
-        [
-            f"{el}{num}" if num != 1 else el
-            for el, num in zip(
-                chemsys.split("-"), map(int, anom_list["numeric"]), strict=True
-            )
-        ]
+        f"{el}{num}" if num != 1 else el
+        for el, num in zip(chemsys.split("-"), map(int, anom_list["numeric"]))
     )


diff --git a/aviary/wrenformer/data.py b/aviary/wrenformer/data.py
index 685c1e98..75a694c9 100644
--- a/aviary/wrenformer/data.py
+++ b/aviary/wrenformer/data.py
@@ -66,9 +66,7 @@ def collate_batch(


 @cache
-def get_wyckoff_features(
-    equivalent_wyckoff_set: list[tuple], spg_num: int
-) -> np.ndarray:
+def get_wyckoff_features(equivalent_wyckoff_set: list[tuple], spg_num: int) -> np.ndarray:
     """Get Wyckoff set features from the precomputed dictionary.

     The output of this function is cached for speed.
@@ -204,6 +202,4 @@ def df_to_in_mem_dataloader(
         inputs[idx] = tensor.to(device)

     ids = df.get(id_col, df.index).to_numpy()
-    return InMemoryDataLoader(
-        [inputs, targets, ids], collate_fn=collate_batch, **kwargs
-    )
+    return InMemoryDataLoader([inputs, targets, ids], collate_fn=collate_batch, **kwargs)
diff --git a/examples/wrenformer/mat_bench/make_plots.py b/examples/wrenformer/mat_bench/make_plots.py
index 55012837..5cf9cdd6 100644
--- a/examples/wrenformer/mat_bench/make_plots.py
+++ b/examples/wrenformer/mat_bench/make_plots.py
@@ -40,7 +40,7 @@

 # %% --- Load other's scores ---
 # load benchmark data for models with existing Matbench submission
-for idx, dirname in enumerate(glob(f"{bench_dir}/*"), 1):
+for idx, dirname in enumerate(glob(f"{bench_dir}/*"), start=1):
     model_name = dirname.split("/matbench_v0.1_")[-1]
     print(f"{idx}. {model_name}")
     mbbm = MatbenchBenchmark.from_file(f"{dirname}/results.json.gz")
@@ -62,7 +62,7 @@

 # %% --- Load our scores ---
 our_score_files = sorted(glob("model_scores/*.json"), key=lambda s: s.split("@")[0])
-for idx, filename in enumerate(our_score_files, 1):
+for idx, filename in enumerate(our_score_files, start=1):
     date, model_name = re.split(r"@\d\d-\d\d-", filename.split("/")[-1])
     print(f"{idx}. {date} {model_name}")

diff --git a/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py b/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py
index 41a72126..fc6d894a 100644
--- a/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py
+++ b/examples/wrenformer/mat_bench/save_matbench_aflow_labels.py
@@ -17,7 +17,7 @@

 benchmark = MatbenchBenchmark()

-for idx, task in enumerate(benchmark.tasks, 1):
+for idx, task in enumerate(benchmark.tasks, start=1):
     print(f"\n\n{idx}/{len(benchmark.tasks)}")
     task.load()
     df: pd.DataFrame = task.df
diff --git a/examples/wrenformer/mat_bench/utils.py b/examples/wrenformer/mat_bench/utils.py
index 2284f003..0d130b67 100644
--- a/examples/wrenformer/mat_bench/utils.py
+++ b/examples/wrenformer/mat_bench/utils.py
@@ -54,7 +54,5 @@ def non_serializable_handler(obj: object) -> str:
         return f""

     with open(file_path, "w") as file:
-        default = (
-            non_serializable_handler if on_non_serializable == "annotate" else None
-        )
+        default = non_serializable_handler if on_non_serializable == "annotate" else None
         json.dump(dct, file, default=default, indent=2)
diff --git a/tests/conftest.py b/tests/conftest.py
index 23923102..15cb7cb5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,7 +19,7 @@ def df_matbench_phonons():
     """Returns the dataframe for the Matbench phonon DOS peak task."""
     df = load_dataset("matbench_phonons")

-    df["material_id"] = [f"mb_phdos_{i + 1}" for i in range(len(df))]
+    df["material_id"] = [f"mb_phdos_{idx + 1}" for idx in range(len(df))]
     df = df.set_index("material_id", drop=False)
     df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]

@@ -33,7 +33,7 @@ def df_matbench_jdft2d():
     """Returns Matbench experimental band gap task dataframe. Currently unused."""
     df = load_dataset("matbench_jdft2d")

-    df["material_id"] = [f"mb_jdft2d_{i + 1}" for i in range(len(df))]
+    df["material_id"] = [f"mb_jdft2d_{idx + 1}" for idx in range(len(df))]
     df = df.set_index("material_id", drop=False)
     df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]

diff --git a/tests/test_wyckoff_ops.py b/tests/test_wyckoff_ops.py
index 459240b4..5475199d 100644
--- a/tests/test_wyckoff_ops.py
+++ b/tests/test_wyckoff_ops.py
@@ -229,8 +229,8 @@ def test_find_translations(dict1, dict2, expected):

 # Additional test for performance with larger input
 def test_find_translations_performance():
-    dict1 = {f"key{i}": i for i in range(8)}
-    dict2 = {f"val{i}": i for i in range(8)}
+    dict1 = {f"key{idx}": idx for idx in range(8)}
+    dict2 = {f"val{idx}": idx for idx in range(8)}
     result = _find_translations(dict1, dict2)
     assert len(result) == 1  # There should be only one valid translation
