Skip to content
Merged
8 changes: 7 additions & 1 deletion src/simtools/data_model/model_data_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,12 @@ def validate_and_transform(
data_dict=product_data_dict,
check_exact_data_type=False,
)
return _validator.validate_and_transform(is_model_parameter)
validated = _validator.validate_and_transform(is_model_parameter)

if is_model_parameter and isinstance(validated, dict):
validated["unit"] = value_conversion.normalize_dimensionless_unit(validated.get("unit"))
Comment thread
orelgueta marked this conversation as resolved.

return validated

def write(self, product_data=None, metadata=None):
"""
Expand Down Expand Up @@ -488,6 +493,7 @@ def prepare_data_dict_for_writing(data_dict):

"""
try:
data_dict["unit"] = value_conversion.normalize_dimensionless_unit(data_dict["unit"])
if isinstance(data_dict["unit"], str):
data_dict["unit"] = data_dict["unit"].replace("None", "null")
elif isinstance(data_dict["unit"], list):
Expand Down
61 changes: 60 additions & 1 deletion src/simtools/data_model/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from simtools.data_model import format_checkers
from simtools.dependencies import get_software_version
from simtools.io import ascii_handler
from simtools.utils import names
from simtools.utils import names, value_conversion
from simtools.version import check_version_constraint

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,6 +67,26 @@ def get_model_parameter_schema_file(parameter):
return schema_file


def get_model_parameter_schema(parameter, schema_version=None):
Comment thread
GernotMaier marked this conversation as resolved.
"""
Return schema dictionary for a given model parameter and schema version.

Parameters
----------
parameter: str
Model parameter name.
schema_version: str, optional
Schema version. If not provided, the latest version is used.

Returns
-------
dict
Schema dictionary.
"""
schema_file = get_model_parameter_schema_file(parameter)
return load_schema(schema_file, schema_version=schema_version or "latest")


def get_model_parameter_schema_version(schema_version=None):
"""
Validate and return schema versions.
Expand Down Expand Up @@ -460,3 +480,42 @@ def _extract_schema_from_file(file_name, observatory="cta"):
return None

return _extract_schema_url_from_metadata_dict(metadata, observatory)


def get_parameter_attribute_from_schema(par_name, schema_version, attribute_name):
"""
Return one attribute from model-parameter schema data entries.

Parameters
----------
par_name: str
Name of the parameter.
schema_version: str
Schema version to look up.
attribute_name: str
Attribute to read from data entries (e.g. "type", "unit").

Returns
-------
str or list or None
Attribute value as scalar for single-entry schemas, list for multi-entry schemas,
or None for unsupported schema data structures.
"""
schema_dict = get_model_parameter_schema(par_name, schema_version)
data = schema_dict.get("data", [])
if isinstance(data, list):
values = [
_normalize_parameter_schema_attribute(attribute_name, item.get(attribute_name))
for item in data
]
return values[0] if len(values) == 1 else values
if isinstance(data, dict):
return _normalize_parameter_schema_attribute(attribute_name, data.get(attribute_name))
return None


def _normalize_parameter_schema_attribute(attribute_name, value):
"""Normalize schema attribute values for public helper functions."""
if attribute_name == "unit":
return value_conversion.normalize_dimensionless_unit(value)
return value
Loading