diff --git a/src/simulation/m_ibm.fpp b/src/simulation/m_ibm.fpp index 726cb4cc65..ad856135f6 100644 --- a/src/simulation/m_ibm.fpp +++ b/src/simulation/m_ibm.fpp @@ -191,7 +191,7 @@ contains type(ghost_point) :: gp type(ghost_point) :: innerp - ! set the Moving IBM interior Pressure Values + ! set the Moving IBM interior conservative variables $:GPU_PARALLEL_LOOP(private='[i,j,k,patch_id,rho]', copyin='[E_idx,momxb]', collapse=3) do l = 0, p do k = 0, n diff --git a/toolchain/mfc/case.py b/toolchain/mfc/case.py index 4388ac97c7..9627ef4503 100644 --- a/toolchain/mfc/case.py +++ b/toolchain/mfc/case.py @@ -62,7 +62,7 @@ def get_inp(self, _target) -> str: cons.print(f"Generating [magenta]{target.name}.inp[/magenta]:") cons.indent() - MASTER_KEYS: list = case_dicts.get_input_dict_keys(target.name) + MASTER_KEYS = case_dicts.get_input_dict_keys(target.name) ignored = [] diff --git a/toolchain/mfc/case_validator.py b/toolchain/mfc/case_validator.py index da34e1b984..79e136eb6e 100644 --- a/toolchain/mfc/case_validator.py +++ b/toolchain/mfc/case_validator.py @@ -19,6 +19,7 @@ from .common import MFCException from .params.definitions import CONSTRAINTS +from .params.namelist_parser import get_fortran_constants from .state import CFG # Physics documentation for check methods. 
@@ -559,7 +560,12 @@ def check_ibm(self): ib_state_wrt = self.get("ib_state_wrt", "F") == "T" self.prohibit(ib and n <= 0, "Immersed Boundaries do not work in 1D (requires n > 0)") - self.prohibit(ib and (num_ibs <= 0 or num_ibs > 1000), "num_ibs must be between 1 and num_patches_max (1000)") + self.prohibit(ib and num_ibs <= 0, "num_ibs must be >= 1 when ib is enabled") + num_patches_max = get_fortran_constants().get("num_patches_max", 1000) + self.prohibit( + ib and num_ibs > num_patches_max, + f"num_ibs must be <= {num_patches_max} (num_patches_max in m_constants.fpp)", + ) self.prohibit(not ib and num_ibs > 0, "num_ibs is set, but ib is not enabled") self.prohibit(ib_state_wrt and not ib, "ib_state_wrt requires ib to be enabled") @@ -1177,6 +1183,11 @@ def check_restart(self): self.prohibit(old_grid and t_step_old is None, "old_grid requires t_step_old to be set") self.prohibit(num_patches < 0, "num_patches must be non-negative") self.prohibit(num_patches == 0 and t_step_old is None, "num_patches must be positive for the non-restart case") + num_patches_max = get_fortran_constants().get("num_patches_max", 1000) + self.prohibit( + num_patches > num_patches_max, + f"num_patches must be <= {num_patches_max} (num_patches_max in m_constants.fpp)", + ) def check_qbmm_pre_process(self): """Checks QBMM constraints for pre-process""" @@ -1388,7 +1399,7 @@ def check_patch_physics(self): def check_bc_patches(self): """Checks boundary condition patch geometry (pre-process)""" num_bc_patches = self.get("num_bc_patches", 0) - num_bc_patches_max = self.get("num_bc_patches_max", 10) + num_bc_patches_max = get_fortran_constants().get("num_bc_patches_max", 10) if num_bc_patches <= 0: return diff --git a/toolchain/mfc/params/__init__.py b/toolchain/mfc/params/__init__.py index 65267b34fe..39c7e6ffef 100644 --- a/toolchain/mfc/params/__init__.py +++ b/toolchain/mfc/params/__init__.py @@ -23,11 +23,12 @@ # and freezes it. 
It must come after REGISTRY is imported and must not be removed. from . import definitions # noqa: F401 from .definitions import CONSTRAINTS, DEPENDENCIES, get_value_label -from .registry import REGISTRY, RegistryFrozenError +from .registry import REGISTRY, IndexedFamily, RegistryFrozenError from .schema import ParamDef, ParamType __all__ = [ "REGISTRY", + "IndexedFamily", "RegistryFrozenError", "ParamDef", "ParamType", diff --git a/toolchain/mfc/params/definitions.py b/toolchain/mfc/params/definitions.py index 1e0a92c786..63076f80a6 100644 --- a/toolchain/mfc/params/definitions.py +++ b/toolchain/mfc/params/definitions.py @@ -8,11 +8,34 @@ import re from typing import Any, Dict -from .registry import REGISTRY +from .namelist_parser import get_fortran_constants +from .registry import REGISTRY, IndexedFamily from .schema import ParamDef, ParamType -# Index limits -NP, NF, NI, NA, NPR, NB = 10, 10, 1000, 4, 10, 10 # patches, fluids, ibs, acoustic, probes, bc_patches +# Index limits — sourced from Fortran compile-time constants (m_constants.fpp). +# These must stay in sync with Fortran; we error if the source can't be parsed. +_FC = get_fortran_constants() + + +def _fc(name: str) -> int: + """Get a required Fortran constant, raising if unavailable.""" + if name not in _FC: + raise RuntimeError( + f"Fortran constant '{name}' not found in m_constants.fpp. " + f"Toolchain is out of sync with Fortran source." + ) + return _FC[name] + + +NF = _fc("num_fluids_max") # fluid_pp +NPR = _fc("num_probes_max") # probe, acoustic, integral +NB = _fc("num_bc_patches_max") # patch_bc +NUM_PATCHES_MAX = _fc("num_patches_max") # patch_icpp, patch_ib (Fortran array bound) +# Enumeration limits for families not yet converted to IndexedFamily. +# These are smaller than the Fortran array bounds to keep the registry compact. +# The CONSTRAINTS dict below uses the Fortran constants for validation. 
+NP = 10 # patch_icpp: has per-index variations, can't easily be IndexedFamily +NA = 4 # acoustic sources: enumerated individually # Auto-generated Descriptions @@ -637,9 +660,9 @@ def get_value_label(param_name: str, value: int) -> str: "R0ref": {"min": 0}, "sigma": {"min": 0}, # Counts (must be positive) - "num_fluids": {"min": 1, "max": 10}, - "num_patches": {"min": 0, "max": 10}, - "num_ibs": {"min": 0, "max": 1000}, + "num_fluids": {"min": 1, "max": NF}, + "num_patches": {"min": 0, "max": NUM_PATCHES_MAX}, + "num_ibs": {"min": 0}, "num_source": {"min": 1}, "num_probes": {"min": 1}, "num_integrals": {"min": 1}, @@ -1136,26 +1159,37 @@ def _load(): ]: _r(f"bub_pp%{a}", REAL, {"bubbles"}, math=sym) - # patch_ib (10 immersed boundaries) - for i in range(1, NI + 1): - px = f"patch_ib({i})%" - for a in ["geometry", "moving_ibm"]: - _r(f"{px}{a}", INT, {"ib"}) - for a, pt in [("radius", REAL), ("theta", REAL), ("slip", LOG), ("c", REAL), ("p", REAL), ("t", REAL), ("m", REAL), ("mass", REAL)]: - _r(f"{px}{a}", pt, {"ib"}) - for j in range(1, 4): - _r(f"{px}angles({j})", REAL, {"ib"}) - for d in ["x", "y", "z"]: - _r(f"{px}{d}_centroid", REAL, {"ib"}) - _r(f"{px}length_{d}", REAL, {"ib"}) - for a, pt in [("model_filepath", STR), ("model_spc", INT), ("model_threshold", REAL)]: - _r(f"{px}{a}", pt, {"ib"}) - for t in ["translate", "scale", "rotate"]: - for j in range(1, 4): - _r(f"{px}model_{t}({j})", REAL, {"ib"}) + # patch_ib (immersed boundaries) — registered as indexed family for O(1) lookup. + # Members are pattern-matched, never enumerated, so the registry stays compact. + # max_index is NUM_PATCHES_MAX (num_patches_max parsed from m_constants.fpp), so + # indices above the Fortran array bound are rejected; case_validator checks num_ibs. 
+ _ib_tags = {"ib"} + _ib_attrs: Dict[str, tuple] = {} + for a in ["geometry", "moving_ibm"]: + _ib_attrs[a] = (INT, _ib_tags) + for a, pt in [("radius", REAL), ("theta", REAL), ("slip", LOG), ("c", REAL), ("p", REAL), ("t", REAL), ("m", REAL), ("mass", REAL)]: + _ib_attrs[a] = (pt, _ib_tags) + for j in range(1, 4): + _ib_attrs[f"angles({j})"] = (REAL, _ib_tags) + for d in ["x", "y", "z"]: + _ib_attrs[f"{d}_centroid"] = (REAL, _ib_tags) + _ib_attrs[f"length_{d}"] = (REAL, _ib_tags) + for a, pt in [("model_filepath", STR), ("model_spc", INT), ("model_threshold", REAL)]: + _ib_attrs[a] = (pt, _ib_tags) + for t in ["translate", "scale", "rotate"]: for j in range(1, 4): - _r(f"{px}vel({j})", A_REAL, {"ib"}) - _r(f"{px}angular_vel({j})", A_REAL, {"ib"}) + _ib_attrs[f"model_{t}({j})"] = (REAL, _ib_tags) + for j in range(1, 4): + _ib_attrs[f"vel({j})"] = (A_REAL, _ib_tags) + _ib_attrs[f"angular_vel({j})"] = (A_REAL, _ib_tags) + REGISTRY.register_family( + IndexedFamily( + base_name="patch_ib", + attrs=_ib_attrs, + tags=_ib_tags, + max_index=NUM_PATCHES_MAX, + ) + ) # acoustic sources (4 sources) for i in range(1, NA + 1): diff --git a/toolchain/mfc/params/namelist_parser.py b/toolchain/mfc/params/namelist_parser.py index 93f43e2622..0804c8c000 100644 --- a/toolchain/mfc/params/namelist_parser.py +++ b/toolchain/mfc/params/namelist_parser.py @@ -11,7 +11,7 @@ import re from pathlib import Path -from typing import Dict, Set +from typing import Dict, Optional, Set # Fallback parameters for when Fortran source files are not available. # Generated from the namelist definitions in src/*/m_start_up.fpp. @@ -464,6 +464,44 @@ def parse_all_namelists(mfc_root: Path) -> Dict[str, Set[str]]: return result +def parse_fortran_constants(filepath: Path) -> Dict[str, int]: + """ + Parse integer parameter constants from a Fortran source file. + + Extracts lines like ``integer, parameter :: name = 123`` and returns + a dict mapping constant names to their integer values. 
+ """ + constants: Dict[str, int] = {} + pattern = re.compile( + r"integer\s*,\s*parameter\s*::\s*(\w+)\s*=\s*(\d+)", re.IGNORECASE + ) + try: + text = filepath.read_text() + except FileNotFoundError: + return constants + for m in pattern.finditer(text): + constants[m.group(1)] = int(m.group(2)) + return constants + + +# Module-level cache for Fortran constants (None = not yet loaded) +_FORTRAN_CONSTANTS_CACHE: Optional[Dict[str, int]] = None + + +def get_fortran_constants() -> Dict[str, int]: + """ + Get Fortran compile-time constants from m_constants.fpp. + + Cached after first call. Returns empty dict if source unavailable. + """ + global _FORTRAN_CONSTANTS_CACHE # noqa: PLW0603 + if _FORTRAN_CONSTANTS_CACHE is None: + root = get_mfc_root() + path = root / "src" / "common" / "m_constants.fpp" + _FORTRAN_CONSTANTS_CACHE = parse_fortran_constants(path) + return _FORTRAN_CONSTANTS_CACHE + + def get_mfc_root() -> Path: """Get the MFC root directory from this file's location.""" # This file is at toolchain/mfc/params/namelist_parser.py diff --git a/toolchain/mfc/params/registry.py b/toolchain/mfc/params/registry.py index 08aa11eac8..210f1995e5 100644 --- a/toolchain/mfc/params/registry.py +++ b/toolchain/mfc/params/registry.py @@ -3,7 +3,7 @@ Central storage for MFC parameter definitions. This module provides the ParamRegistry class which serves as the single source of truth for all -~3,300 MFC parameters. +MFC parameters. 
Usage ----- @@ -13,7 +13,10 @@ from mfc.params import REGISTRY # Get a specific parameter - param = REGISTRY.all_params.get('m') + param = REGISTRY.get_param_def('m') + + # Check if a parameter name is valid (including indexed families) + REGISTRY.is_known_param('patch_ib(500)%geometry') # True # Get parameters by feature tag mhd_params = REGISTRY.get_params_by_tag('mhd') @@ -27,53 +30,183 @@ import re from collections import defaultdict +from collections.abc import Mapping +from dataclasses import dataclass, field from functools import lru_cache from types import MappingProxyType -from typing import Any, Dict, Mapping, Set +from typing import Any, Dict, Iterator, Optional, Set, Tuple -from .schema import ParamDef +from .schema import ParamDef, ParamType class RegistryFrozenError(RuntimeError): """Raised when attempting to modify a frozen registry.""" +# Regex for parsing indexed family parameter names: +# patch_ib(123)%vel(1) -> base="patch_ib", index=123, attr="vel(1)" +# patch_ib(1)%geometry -> base="patch_ib", index=1, attr="geometry" +_INDEXED_RE = re.compile(r"^([a-zA-Z_]\w*)\((\d+)\)%(.+)$") + + +def _resolve_family( + name: str, families: Dict[str, "IndexedFamily"] +) -> Optional[Tuple[ParamType, Set[str]]]: + """ + Resolve a parameter name against indexed families. + + Returns (ParamType, tags) if the name matches a registered family + attribute, or None otherwise. This is the single implementation of + family pattern-matching used by both _FamilyAwareMapping and + ParamRegistry. + """ + m = _INDEXED_RE.match(name) + if m is None: + return None + base, idx_str, attr = m.groups() + fam = families.get(base) + if fam is None: + return None + idx = int(idx_str) + if idx < 1: + return None + if fam.max_index is not None and idx > fam.max_index: + return None + entry = fam.attrs.get(attr) + return entry if entry is not None else None + + +@dataclass(frozen=True) +class IndexedFamily: + """ + Template for an indexed parameter family like patch_ib(N)%attr. 
+ + Instead of registering every index individually (e.g., patch_ib(1)%geometry + through patch_ib(N)%geometry for all attributes), we store one template and + validate parameter names via pattern matching. + + Attributes: + base_name: Family prefix (e.g., "patch_ib") + attrs: Mapping of attribute name to (ParamType, tags) — attribute names + may include sub-indices like "vel(1)", "angles(3)". + tags: Metadata-only tags for this family (not used in resolution; + per-attribute tags in ``attrs`` are what get returned). + max_index: Upper bound on the index (1-based). None = unlimited. + """ + + base_name: str + attrs: Dict[str, Tuple[ParamType, Set[str]]] = field(default_factory=dict) + tags: Set[str] = field(default_factory=set) + max_index: Optional[int] = None + + def __post_init__(self): + if not self.base_name or not re.match(r"^[a-zA-Z_]\w*$", self.base_name): + raise ValueError(f"Invalid base_name: {self.base_name!r}") + if self.max_index is not None and self.max_index < 1: + raise ValueError(f"max_index must be >= 1 or None, got {self.max_index}") + + +class _FamilyAwareMapping(Mapping): + """ + Read-only mapping that combines scalar params with indexed family lookups. + + For containment checks and item access, indexed family params like + ``patch_ib(500)%geometry`` are resolved via pattern matching against + registered families — no enumeration needed. + + For iteration (items/keys/values/len), only scalar params and one + representative example per family attribute (index=1) are yielded. + This keeps iteration bounded regardless of max_index. + + keys(), items(), values() are inherited from collections.abc.Mapping + and return proper KeysView/ItemsView/ValuesView objects. 
+ """ + + __slots__ = ("_scalars", "_families", "_examples") + + def __init__( + self, + scalars: Dict[str, ParamDef], + families: Dict[str, IndexedFamily], + ): + self._scalars = scalars + self._families = families + # Pre-build one example per family attr for iteration/docs + self._examples: Dict[str, ParamDef] = {} + for fam in families.values(): + for attr_name, (ptype, tags) in fam.attrs.items(): + key = f"{fam.base_name}(1)%{attr_name}" + self._examples[key] = ParamDef(name=key, param_type=ptype, tags=set(tags)) + + def _make_param_def(self, name: str) -> Optional[ParamDef]: + """Build a ParamDef from a family match, or return None.""" + result = _resolve_family(name, self._families) + if result is None: + return None + ptype, tags = result + return ParamDef(name=name, param_type=ptype, tags=set(tags)) + + def __getitem__(self, key: str) -> ParamDef: + try: + return self._scalars[key] + except KeyError: + pass + result = self._make_param_def(key) + if result is not None: + return result + raise KeyError(key) + + def __contains__(self, key: object) -> bool: + # Note: this intentionally deviates from the standard Mapping contract. + # `key in self` may be True for family params (e.g., patch_ib(500)%geometry) + # that do NOT appear in iter(self) (which only yields index=1 examples). + if key in self._scalars: + return True + if isinstance(key, str): + return _resolve_family(key, self._families) is not None + return False + + def __iter__(self) -> Iterator[str]: + yield from self._scalars + yield from self._examples + + def __len__(self) -> int: + return len(self._scalars) + len(self._examples) + + class ParamRegistry: """ Central registry for MFC parameters. - This class stores parameter definitions and provides lookup methods - for retrieving parameters by name or by feature tag. + Supports two kinds of parameters: + 1. Scalar/small-indexed params — stored individually in _params. + 2. Indexed families — stored as templates in _families, matched by pattern. 
The registry can be frozen after initialization to prevent further modifications, ensuring thread-safety for read operations. - - Attributes: - _params: Dictionary mapping parameter names to ParamDef instances. - _by_tag: Dictionary mapping tags to sets of parameter names. - _frozen: Whether the registry has been frozen (immutable). """ def __init__(self): """Initialize an empty registry.""" self._params: Dict[str, ParamDef] = {} + self._families: Dict[str, IndexedFamily] = {} self._by_tag: Dict[str, Set[str]] = defaultdict(set) self._frozen: bool = False - self._params_proxy: Mapping[str, ParamDef] = None + self._all_params_view: Optional[_FamilyAwareMapping] = None def freeze(self) -> None: """ Freeze the registry, preventing further modifications. After calling this method: - - register() will raise RegistryFrozenError - - all_params returns a read-only view (MappingProxyType) + - register() and register_family() will raise RegistryFrozenError + - all_params returns a family-aware read-only mapping This method is idempotent (safe to call multiple times). """ if not self._frozen: self._frozen = True - self._params_proxy = MappingProxyType(self._params) + self._all_params_view = _FamilyAwareMapping(self._params, self._families) @property def is_frozen(self) -> bool: @@ -112,20 +245,83 @@ def register(self, param: ParamDef) -> None: for tag in param.tags: self._by_tag[tag].add(param.name) + def register_family(self, family: IndexedFamily) -> None: + """ + Register an indexed parameter family. + + Instead of registering N*attrs individual params, this stores a + single template that is matched by pattern. This makes validation + O(1) per parameter regardless of max_index. + + Args: + family: The indexed family definition. + + Raises: + RegistryFrozenError: If the registry has been frozen. 
+ """ + if self._frozen: + raise RegistryFrozenError(f"Cannot register family '{family.base_name}': registry is frozen.") + self._families[family.base_name] = family + # Register tags for the family (using example names) + for attr_name, (_, tags) in family.attrs.items(): + example = f"{family.base_name}(1)%{attr_name}" + for tag in tags: + self._by_tag[tag].add(example) + + @property + def families(self) -> Mapping[str, IndexedFamily]: + """Get all registered indexed families (read-only view).""" + return MappingProxyType(self._families) + @property def all_params(self) -> Mapping[str, ParamDef]: """ - Get all registered parameters. + Get all registered parameters as a mapping. - Returns: - Mapping of parameter names to their definitions. - If the registry is frozen, returns a read-only view. - If not frozen, returns the internal dict (mutable). + Returns a family-aware mapping that supports: + - Containment: ``'patch_ib(500)%geometry' in registry.all_params`` + - Lookup: ``registry.all_params.get('patch_ib(500)%geometry')`` + - Iteration: yields scalar params + one example per family attr + + If the registry is frozen, returns an immutable view. + If not frozen and no families are registered, returns the internal dict. + If not frozen but families exist, raises RuntimeError (the plain dict + cannot resolve family params — call freeze() first). """ - if self._frozen and self._params_proxy is not None: - return self._params_proxy + if self._frozen and self._all_params_view is not None: + return self._all_params_view + if self._families: + raise RuntimeError( + "Cannot access all_params before freeze() when indexed families " + "are registered. Call freeze() first." + ) return self._params + def is_known_param(self, name: str) -> bool: + """ + Check if a parameter name is valid (scalar or indexed family). + + This is the fast path for validating user-provided parameter names. + O(1) for both scalar params and indexed family params. 
+ """ + if name in self._params: + return True + return _resolve_family(name, self._families) is not None + + def get_param_def(self, name: str) -> Optional[ParamDef]: + """ + Get the ParamDef for a parameter name, resolving indexed families. + + Returns None if the name is not recognized. + """ + if name in self._params: + return self._params[name] + result = _resolve_family(name, self._families) + if result is None: + return None + ptype, tags = result + return ParamDef(name=name, param_type=ptype, tags=set(tags)) + def get_params_by_tag(self, tag: str) -> Dict[str, ParamDef]: """ Get parameters with a specific feature tag. @@ -136,7 +332,16 @@ def get_params_by_tag(self, tag: str) -> Dict[str, ParamDef]: Returns: Dictionary mapping parameter names to their definitions. """ - return {name: self._params[name] for name in self._by_tag.get(tag, set())} + result = {} + for name in self._by_tag.get(tag, set()): + if name in self._params: + result[name] = self._params[name] + else: + # Family example — resolve it + param_def = self.get_param_def(name) + if param_def is not None: + result[name] = param_def + return result def get_all_tags(self) -> Set[str]: """ @@ -151,9 +356,8 @@ def get_json_schema(self) -> Dict[str, Any]: """ Generate JSON schema for case file validation. - Indexed parameter families (e.g., patch_ib(1)%radius through - patch_ib(1000)%radius) are collapsed into patternProperties - regexes to keep the schema small (~500 entries vs ~40,000). + Indexed parameter families (e.g., patch_ib(N)%radius) are + represented as patternProperties regexes. Returns: JSON schema dict compatible with fastjsonschema. 
@@ -161,19 +365,29 @@ def get_json_schema(self) -> Dict[str, Any]: properties = {} pattern_props = {} - for name, param in self.all_params.items(): + # Scalar and small-indexed params + for name, param in self._params.items(): if "(" not in name: - # Scalar param — explicit property properties[name] = param.param_type.json_schema else: - # Indexed param — collapse into pattern - # Replace digit sequences inside parens: (1) -> (\d+) + # Small indexed param — collapse into pattern pattern = re.sub(r"\(\d+\)", "__IDX__", name) pattern = re.escape(pattern).replace("__IDX__", r"\(\d+\)") pattern = f"^{pattern}$" if pattern not in pattern_props: pattern_props[pattern] = param.param_type.json_schema + # Indexed families — generate one pattern per attribute + for fam in self._families.values(): + base_esc = re.escape(fam.base_name) + for attr_name, (ptype, _tags) in fam.attrs.items(): + # Escape the attr name but replace sub-indices with \(\d+\) + attr_pattern = re.sub(r"\(\d+\)", "__IDX__", attr_name) + attr_pattern = re.escape(attr_pattern).replace("__IDX__", r"\(\d+\)") + pattern = f"^{base_esc}\\([1-9]\\d*\\)%{attr_pattern}$" + if pattern not in pattern_props: + pattern_props[pattern] = ptype.json_schema + return { "type": "object", "properties": properties, diff --git a/toolchain/mfc/params/suggest.py b/toolchain/mfc/params/suggest.py index 11eb4bd60b..0ca219da9a 100644 --- a/toolchain/mfc/params/suggest.py +++ b/toolchain/mfc/params/suggest.py @@ -98,6 +98,9 @@ def suggest_parameter(unknown_param: str) -> List[str]: """ Suggest similar parameter names from the registry. + For indexed family params, suggests the matching family attribute + with index 1 as a representative example. + Args: unknown_param: Unknown parameter name. 
@@ -107,6 +110,7 @@ def suggest_parameter(unknown_param: str) -> List[str]: # Import here to avoid circular import (registry imports definitions which may use suggest) from .registry import REGISTRY + # all_params.keys() includes scalar params + one example per family attr return suggest_similar(unknown_param, REGISTRY.all_params.keys()) diff --git a/toolchain/mfc/params/validate.py b/toolchain/mfc/params/validate.py index efbc798e38..d6bdc2c057 100644 --- a/toolchain/mfc/params/validate.py +++ b/toolchain/mfc/params/validate.py @@ -44,12 +44,53 @@ from .suggest import suggest_parameter +def _family_attr_error(name: str) -> Optional[str]: + """ + Diagnose why a family-pattern param was rejected by the registry. + + Distinguishes three cases for known family bases: + - Invalid index (0 or exceeding max_index) + - Unknown attribute + - Unknown family base (returns None to let caller handle) + + Returns a targeted error message, or None if name doesn't match + any family pattern. + """ + from .registry import _INDEXED_RE + + m = _INDEXED_RE.match(name) + if m is None: + return None + base, idx_str, attr = m.groups() + fam = REGISTRY.families.get(base) + if fam is None: + return None + + # Check if the problem is the index (attr is valid but index is out of range) + idx = int(idx_str) + if attr in fam.attrs: + if idx < 1: + return f"Invalid index {idx} for {base}: indices are 1-based (must be >= 1)" + if fam.max_index is not None and idx > fam.max_index: + return f"Index {idx} exceeds maximum ({fam.max_index}) for {base}" + return None # Both attr and index look valid; shouldn't reach here + + # Unknown attribute — provide targeted message + valid = sorted(fam.attrs.keys()) + if len(valid) > 8: + shown = ", ".join(valid[:8]) + f", ... ({len(valid)} total)" + else: + shown = ", ".join(valid) + return f"Unknown attribute '{attr}' for {base}. 
Valid attributes: {shown}" + + def check_unknown_params(params: Dict[str, Any]) -> List[str]: """ Check for unknown parameters and suggest corrections. - Uses fuzzy matching via rapidfuzz to provide "Did you mean?" suggestions - for parameter names that don't exist in the registry. + For indexed family params with a known base but invalid attribute, + provides a targeted "valid attributes" message. Otherwise, uses + fuzzy matching to provide "Did you mean?" suggestions. Args: params: Dictionary of parameter name -> value @@ -60,9 +101,13 @@ def check_unknown_params(params: Dict[str, Any]) -> List[str]: errors = [] for name in params.keys(): - if name not in REGISTRY.all_params: - suggestions = suggest_parameter(name) - errors.append(unknown_param_error(name, suggestions)) + if not REGISTRY.is_known_param(name): + family_err = _family_attr_error(name) + if family_err: + errors.append(family_err) + else: + suggestions = suggest_parameter(name) + errors.append(unknown_param_error(name, suggestions)) return errors @@ -80,7 +125,7 @@ def validate_constraints(params: Dict[str, Any]) -> List[str]: errors = [] for name, value in params.items(): - param_def = REGISTRY.all_params.get(name) + param_def = REGISTRY.get_param_def(name) if param_def is None: continue # Unknown params handled by check_unknown_params @@ -150,7 +195,7 @@ def check_dependencies(params: Dict[str, Any]) -> Tuple[List[str], List[str]]: warnings = [] for name, value in params.items(): - param_def = REGISTRY.all_params.get(name) + param_def = REGISTRY.get_param_def(name) if param_def is None or param_def.dependencies is None: continue diff --git a/toolchain/mfc/params_tests/test_definitions.py b/toolchain/mfc/params_tests/test_definitions.py index 4008543991..ca731c1044 100644 --- a/toolchain/mfc/params_tests/test_definitions.py +++ b/toolchain/mfc/params_tests/test_definitions.py @@ -154,10 +154,12 @@ class TestParameterCounts(unittest.TestCase): """Tests for expected parameter counts.""" def 
test_total_param_count(self): - """Total parameter count should be around 40000.""" + """Total parameter count (scalars + family examples) should be reasonable.""" count = len(REGISTRY.all_params) - self.assertGreater(count, 39000, f"Too few parameters. Got {count}.") - self.assertLess(count, 41000, f"Too many parameters. Got {count}.") + # After indexed family refactor: patch_ib contributes ~30 examples + # instead of NI*30 individual entries. + self.assertGreater(count, 3000, f"Too few parameters. Got {count}.") + self.assertLess(count, 5000, f"Too many parameters. Got {count}.") def test_log_params_count(self): """Should have many LOG type parameters.""" diff --git a/toolchain/mfc/params_tests/test_integration.py b/toolchain/mfc/params_tests/test_integration.py index 96e7603d88..3d68482464 100644 --- a/toolchain/mfc/params_tests/test_integration.py +++ b/toolchain/mfc/params_tests/test_integration.py @@ -77,6 +77,33 @@ def test_get_json_schema_has_all_params(self): self.assertEqual(len(properties), scalar_count) self.assertGreater(len(pattern_props), 0) + def test_family_pattern_matches_param_names(self): + """patternProperties regexes should match valid family param names.""" + import re + + schema = REGISTRY.get_json_schema() + patterns = schema.get("patternProperties", {}) + + # These valid names must match at least one pattern + valid_names = [ + "patch_ib(1)%geometry", + "patch_ib(99)%vel(1)", + "patch_ib(5)%model_translate(2)", + "patch_ib(1000)%radius", + ] + for name in valid_names: + matched = any(re.match(p, name) for p in patterns) + self.assertTrue(matched, f"'{name}' did not match any patternProperties regex") + + # These invalid names must NOT match any pattern + invalid_names = [ + "patch_ib(0)%geometry", + "patch_ib(1)%bogus_attr", + ] + for name in invalid_names: + matched = any(re.match(p, name) for p in patterns) + self.assertFalse(matched, f"'{name}' should not match any patternProperties regex") + def test_core_params_in_schema(self): 
"""Core params should be in JSON schema.""" schema = REGISTRY.get_json_schema() @@ -129,12 +156,21 @@ def test_case_dicts_loads_from_registry(self): # ALL should be populated self.assertIsNotNone(case_dicts.ALL) - def test_case_dicts_all_contains_all_params(self): - """case_dicts.ALL should contain all registry params.""" + def test_case_dicts_all_contains_registry_params(self): + """case_dicts.ALL should recognize all registry params.""" from ..run import case_dicts - # ALL should have approximately the same params as registry - self.assertEqual(len(case_dicts.ALL), len(REGISTRY.all_params)) + # ALL should contain scalar params and family params + self.assertIn("m", case_dicts.ALL) + self.assertIn("model_eqns", case_dicts.ALL) + # Family params should also be recognized via pattern matching + self.assertIn("patch_ib(1)%geometry", case_dicts.ALL) + self.assertIn("patch_ib(500)%radius", case_dicts.ALL) + # Bogus attr, unknown family, and zero index should not match + self.assertNotIn("nonexistent_param", case_dicts.ALL) + self.assertNotIn("patch_ib(1)%bogus_attr", case_dicts.ALL) + self.assertNotIn("patch_ib(0)%geometry", case_dicts.ALL) + self.assertNotIn("unknown_family(1)%geometry", case_dicts.ALL) def test_case_optimization_params_from_registry(self): """CASE_OPTIMIZATION should be populated from registry.""" @@ -172,20 +208,14 @@ def test_get_validator_works(self): self.assertTrue(callable(validator)) def test_get_input_dict_keys(self): - """get_input_dict_keys should return target-specific params.""" + """get_input_dict_keys should return target-aware set supporting 'in'.""" from ..run import case_dicts - # Each target gets a filtered subset of params based on Fortran namelists + # Each target gets a set-like object that checks base name against namelists pre_keys = case_dicts.get_input_dict_keys("pre_process") sim_keys = case_dicts.get_input_dict_keys("simulation") post_keys = case_dicts.get_input_dict_keys("post_process") - # pre_process has most params 
(includes patch_icpp, patch_bc) - self.assertGreater(len(pre_keys), 2500) - # simulation and post_process have fewer (no patch_icpp, etc.) - self.assertGreater(len(sim_keys), 500) - self.assertGreater(len(post_keys), 400) - # Verify target-specific filtering based on Fortran namelists self.assertIn("num_patches", pre_keys) self.assertNotIn("num_patches", sim_keys) @@ -196,10 +226,12 @@ def test_get_input_dict_keys(self): self.assertNotIn("run_time_info", post_keys) # Verify indexed params are filtered correctly - patch_icpp_pre = [k for k in pre_keys if k.startswith("patch_icpp")] - patch_icpp_sim = [k for k in sim_keys if k.startswith("patch_icpp")] - self.assertGreater(len(patch_icpp_pre), 1000) # Many patch_icpp params - self.assertEqual(len(patch_icpp_sim), 0) # None in simulation + self.assertIn("patch_icpp(1)%geometry", pre_keys) + self.assertNotIn("patch_icpp(1)%geometry", sim_keys) + + # Verify family params (patch_ib) work for simulation target + self.assertIn("patch_ib(1)%geometry", sim_keys) + self.assertIn("patch_ib(500)%radius", sim_keys) # Verify shared params are in all targets self.assertIn("m", pre_keys) diff --git a/toolchain/mfc/params_tests/test_registry.py b/toolchain/mfc/params_tests/test_registry.py index ee16844915..a16a052406 100644 --- a/toolchain/mfc/params_tests/test_registry.py +++ b/toolchain/mfc/params_tests/test_registry.py @@ -1,12 +1,12 @@ """ Unit tests for params/registry.py module. -Tests registry functionality, freezing, and tag queries. +Tests registry functionality, freezing, tag queries, and indexed families. 
""" import unittest -from ..params.registry import ParamRegistry, RegistryFrozenError +from ..params.registry import IndexedFamily, ParamRegistry, RegistryFrozenError from ..params.schema import ParamDef, ParamType @@ -102,6 +102,170 @@ def test_all_params_readonly_after_freeze(self): params["hacked"] = "value" +class TestIndexedFamily(unittest.TestCase): + """Tests for indexed family registration and resolution.""" + + def _make_registry_with_family(self, max_index=None): + """Helper: registry with one indexed family.""" + reg = ParamRegistry() + reg.register_family( + IndexedFamily( + base_name="thing", + attrs={ + "geom": (ParamType.INT, {"tag1"}), + "vel(1)": (ParamType.REAL, {"tag1"}), + }, + tags={"tag1"}, + max_index=max_index, + ) + ) + reg.freeze() + return reg + + def test_valid_family_param_is_known(self): + """Valid indexed family param should be recognized.""" + reg = self._make_registry_with_family() + self.assertTrue(reg.is_known_param("thing(1)%geom")) + self.assertTrue(reg.is_known_param("thing(999)%vel(1)")) + + def test_bogus_attr_rejected(self): + """Unknown attribute on a valid family should be rejected.""" + reg = self._make_registry_with_family() + self.assertFalse(reg.is_known_param("thing(1)%nonexistent")) + + def test_zero_index_rejected(self): + """Index 0 should be rejected (1-based indexing).""" + reg = self._make_registry_with_family() + self.assertFalse(reg.is_known_param("thing(0)%geom")) + + def test_unknown_family_rejected(self): + """Unknown family base name should be rejected.""" + reg = self._make_registry_with_family() + self.assertFalse(reg.is_known_param("unknown(1)%geom")) + + def test_max_index_enforced(self): + """Indices beyond max_index should be rejected.""" + reg = self._make_registry_with_family(max_index=5) + self.assertTrue(reg.is_known_param("thing(5)%geom")) + self.assertFalse(reg.is_known_param("thing(6)%geom")) + + def test_unlimited_index_when_no_max(self): + """Without max_index, arbitrarily large indices are 
valid.""" + reg = self._make_registry_with_family(max_index=None) + self.assertTrue(reg.is_known_param("thing(999999)%geom")) + + def test_get_param_def_returns_correct_type(self): + """get_param_def should return correct ParamType for family params.""" + reg = self._make_registry_with_family() + pdef = reg.get_param_def("thing(3)%geom") + self.assertIsNotNone(pdef) + self.assertEqual(pdef.param_type, ParamType.INT) + self.assertEqual(pdef.name, "thing(3)%geom") + + def test_get_param_def_none_for_invalid(self): + """get_param_def should return None for invalid family params.""" + reg = self._make_registry_with_family() + self.assertIsNone(reg.get_param_def("thing(1)%bogus")) + self.assertIsNone(reg.get_param_def("unknown(1)%geom")) + + def test_all_params_iteration_bounded(self): + """Iterating all_params should yield scalars + one example per family attr.""" + reg = ParamRegistry() + reg.register(ParamDef(name="scalar1", param_type=ParamType.INT)) + reg.register(ParamDef(name="scalar2", param_type=ParamType.REAL)) + reg.register_family( + IndexedFamily( + base_name="thing", + attrs={ + "geom": (ParamType.INT, {"tag1"}), + "vel(1)": (ParamType.REAL, {"tag1"}), + }, + tags={"tag1"}, + ) + ) + reg.freeze() + + keys = list(reg.all_params) + # 2 scalars + 2 family attrs (one example each at index=1) + self.assertEqual(len(reg.all_params), 4) + self.assertEqual(len(keys), 4) + self.assertIn("scalar1", keys) + self.assertIn("scalar2", keys) + # Family examples use index=1 + self.assertIn("thing(1)%geom", keys) + self.assertIn("thing(1)%vel(1)", keys) + # Arbitrary indices should NOT appear in iteration + self.assertNotIn("thing(42)%geom", keys) + + def test_all_params_contains_family(self): + """all_params mapping should resolve family params via __contains__.""" + reg = self._make_registry_with_family() + self.assertIn("thing(42)%geom", reg.all_params) + self.assertNotIn("thing(1)%bogus", reg.all_params) + + def test_all_params_getitem_family(self): + 
"""all_params[family_param] should return a ParamDef.""" + reg = self._make_registry_with_family() + pdef = reg.all_params["thing(7)%vel(1)"] + self.assertEqual(pdef.param_type, ParamType.REAL) + + def test_all_params_get_family(self): + """all_params.get() should resolve family params and respect defaults.""" + reg = self._make_registry_with_family() + pdef = reg.all_params.get("thing(5)%geom") + self.assertIsNotNone(pdef) + self.assertEqual(pdef.param_type, ParamType.INT) + self.assertEqual(reg.all_params.get("thing(1)%nonexistent", "default"), "default") + + def test_all_params_getitem_raises_for_invalid(self): + """all_params[invalid] should raise KeyError.""" + reg = self._make_registry_with_family() + with self.assertRaises(KeyError): + _ = reg.all_params["thing(1)%bogus"] + + def test_register_family_after_freeze_raises(self): + """Registering a family after freeze should raise.""" + reg = ParamRegistry() + reg.freeze() + with self.assertRaises(RegistryFrozenError): + reg.register_family( + IndexedFamily(base_name="late", attrs={}, tags=set()) + ) + + def test_invalid_base_name_raises(self): + """IndexedFamily with empty or invalid base_name should raise.""" + with self.assertRaises(ValueError): + IndexedFamily(base_name="") + with self.assertRaises(ValueError): + IndexedFamily(base_name="123invalid") + + def test_invalid_max_index_raises(self): + """IndexedFamily with max_index < 1 should raise.""" + with self.assertRaises(ValueError): + IndexedFamily(base_name="test", max_index=0) + with self.assertRaises(ValueError): + IndexedFamily(base_name="test", max_index=-5) + + def test_all_params_before_freeze_with_families_raises(self): + """Accessing all_params before freeze with families should raise.""" + reg = ParamRegistry() + reg.register_family( + IndexedFamily( + base_name="fam", + attrs={"x": (ParamType.INT, {"t"})}, + tags={"t"}, + ) + ) + with self.assertRaises(RuntimeError): + _ = reg.all_params + + def test_family_params_by_tag(self): + 
"""get_params_by_tag should include family example params.""" + reg = self._make_registry_with_family() + tagged = reg.get_params_by_tag("tag1") + self.assertGreater(len(tagged), 0) + + class TestGlobalRegistry(unittest.TestCase): """Tests for the global REGISTRY instance.""" @@ -115,7 +279,7 @@ def test_global_registry_has_params(self): """Global REGISTRY should have parameters loaded.""" from ..params import REGISTRY - self.assertGreater(len(REGISTRY.all_params), 3500) + self.assertGreater(len(REGISTRY.all_params), 3000) def test_global_registry_cannot_be_modified(self): """Global REGISTRY should reject new registrations.""" diff --git a/toolchain/mfc/params_tests/test_validate.py b/toolchain/mfc/params_tests/test_validate.py index 061470bf24..f682aced99 100644 --- a/toolchain/mfc/params_tests/test_validate.py +++ b/toolchain/mfc/params_tests/test_validate.py @@ -79,6 +79,29 @@ def test_unknown_param_returns_error(self): self.assertIn("Unknown parameter", errors[0]) self.assertIn("totally_unknown_xyz_123", errors[0]) + def test_family_attr_typo_gives_targeted_error(self): + """Typo in family attribute should list valid attributes.""" + params = {"patch_ib(1)%geometri": 1} + errors = check_unknown_params(params) + self.assertEqual(len(errors), 1) + self.assertIn("Valid attributes", errors[0]) + self.assertIn("geometry", errors[0]) + self.assertNotIn("Did you mean", errors[0]) + + def test_family_valid_attr_no_error(self): + """Valid family param should not generate an error.""" + params = {"patch_ib(1)%geometry": 1} + errors = check_unknown_params(params) + self.assertEqual(errors, []) + + def test_family_zero_index_gives_index_error(self): + """Zero index on valid attr should report index error, not attr error.""" + params = {"patch_ib(0)%geometry": 1} + errors = check_unknown_params(params) + self.assertEqual(len(errors), 1) + self.assertIn("1-based", errors[0]) + self.assertNotIn("Unknown attribute", errors[0]) + @unittest.skipUnless(RAPIDFUZZ_AVAILABLE, 
"rapidfuzz not installed") def test_similar_param_suggests_correction(self): """Typo near valid param should suggest 'did you mean?'.""" diff --git a/toolchain/mfc/run/case_dicts.py b/toolchain/mfc/run/case_dicts.py index 5e7006128f..7100e9e7dd 100644 --- a/toolchain/mfc/run/case_dicts.py +++ b/toolchain/mfc/run/case_dicts.py @@ -5,24 +5,47 @@ All parameter definitions are sourced from the registry. Exports: - ALL: Dict of all parameters {name: ParamType} + ALL: Family-aware mapping of all parameters {name: ParamType} IGNORE: Parameters to skip during certain operations CASE_OPTIMIZATION: Parameters that can be hard-coded for GPU builds SCHEMA: JSON schema for fastjsonschema validation get_validator(): Returns compiled JSON schema validator - get_input_dict_keys(): Get parameter keys for a target + get_input_dict_keys(): Get set-like object for target parameter checking """ import re +from collections.abc import Mapping from ..state import ARG -def _load_all_params(): - """Load all parameters as {name: ParamType} dict.""" - from ..params import REGISTRY +class _ParamTypeMapping(Mapping): + """ + Read-only mapping wrapping REGISTRY's all_params for {name: ParamType} access. + + Delegates containment checks and lookup to the registry's family-aware + mapping, so indexed families like ``patch_ib(500000)%geometry`` resolve + in O(1) without enumerating all possible indices. + + For iteration, yields scalar params plus one example per family attr. 
+ """ + + def __init__(self): + from ..params import REGISTRY - return {name: param.param_type for name, param in REGISTRY.all_params.items()} + self._view = REGISTRY.all_params + + def __contains__(self, key): + return key in self._view + + def __getitem__(self, key): + return self._view[key].param_type + + def __iter__(self): + return iter(self._view) + + def __len__(self): + return len(self._view) def _load_case_optimization_params(): @@ -56,8 +79,8 @@ def _get_target_params(): # Parameters to ignore during certain operations IGNORE = ["cantera_file", "chemistry"] -# Combined dict of all parameters -ALL = _load_all_params() +# Family-aware mapping of all parameters — supports O(1) lookup for indexed families +ALL = _ParamTypeMapping() # Parameters that can be hard-coded for GPU case optimization CASE_OPTIMIZATION = _load_case_optimization_params() @@ -65,6 +88,9 @@ def _get_target_params(): # JSON schema for validation SCHEMA = _build_schema() +# Regex to extract the base name from indexed params +_BASE_NAME_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)") + def _is_param_valid_for_target(param_name: str, target_name: str) -> bool: """ @@ -85,7 +111,7 @@ def _is_param_valid_for_target(param_name: str, target_name: str) -> bool: # e.g., "patch_icpp(1)%geometry" -> "patch_icpp" # e.g., "fluid_pp(2)%gamma" -> "fluid_pp" # e.g., "acoustic(1)%loc(1)" -> "acoustic" - match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)", param_name) + match = _BASE_NAME_RE.match(param_name) if match: base_name = match.group(1) return base_name in target_params @@ -93,26 +119,42 @@ def _is_param_valid_for_target(param_name: str, target_name: str) -> bool: return param_name in target_params -def get_input_dict_keys(target_name: str) -> list: +class _TargetKeySet: """ - Get parameter keys for a given target. + Set-like object for checking if a param is valid for a specific target. 
+ + Supports ``key in target_key_set`` via base-name matching against the + Fortran namelist, plus optionally filtering out case-optimization params. + Does not enumerate all possible indexed family members. + """ + + def __init__(self, target_name: str, filter_case_opt: bool = False): + self._target_name = target_name + self._filter_case_opt = filter_case_opt + self._case_opt = set(CASE_OPTIMIZATION) if filter_case_opt else set() + + def __contains__(self, key): + if self._filter_case_opt and key in self._case_opt: + return False + return _is_param_valid_for_target(key, self._target_name) - Uses the Fortran namelist definitions as the source of truth. - Only returns params whose base name is in the target's namelist. + +def get_input_dict_keys(target_name: str): + """ + Get a set-like object for checking parameter validity for a target. + + Returns an object that supports ``key in result`` for O(1) checks. + For indexed families, this does NOT enumerate all possible indices — + it checks the base name against the Fortran namelist. Args: target_name: One of 'pre_process', 'simulation', 'post_process' Returns: - List of parameter names valid for that target + Set-like object supporting ``in`` operator """ - keys = [k for k in ALL.keys() if _is_param_valid_for_target(k, target_name)] - - # Case optimization filtering for simulation - if ARG("case_optimization", dflt=False) and target_name == "simulation": - keys = [k for k in keys if k not in CASE_OPTIMIZATION] - - return keys + filter_case_opt = ARG("case_optimization", dflt=False) and target_name == "simulation" + return _TargetKeySet(target_name, filter_case_opt) def get_validator():