diff --git a/Examples/base/README.rst b/Examples/base/README.rst index 24202501..46da573a 100644 --- a/Examples/base/README.rst +++ b/Examples/base/README.rst @@ -3,4 +3,4 @@ Subclassing Examples ------------------------ -This section gathers examples which correspond to subclassing the :class:`easyscience.Objects.Base.BaseObj` class. +This section gathers examples which correspond to subclassing the :class:`easyscience.base_classes.ObjBase` class. diff --git a/Examples/base/plot_baseclass1.py b/Examples/base/plot_baseclass1.py index f0ca9c13..b87c559e 100644 --- a/Examples/base/plot_baseclass1.py +++ b/Examples/base/plot_baseclass1.py @@ -1,8 +1,8 @@ """ -Subclassing BaseObj - Simple Pendulum +Subclassing ObjBase - Simple Pendulum ===================================== -This example shows how to subclass :class:`easyscience.Objects.Base.BaseObj` with parameters from -:class:`EasyScience.Objects.Base.Parameter`. For this example a simple pendulum will be modeled. +This example shows how to subclass :class:`easyscience.base_classes.ObjBase` with parameters from +:class:`EasyScience.variable.Parameter`. For this example a simple pendulum will be modeled. .. math:: y = A \sin (2 \pi f t + \phi ) @@ -17,8 +17,8 @@ import matplotlib.pyplot as plt import numpy as np -from easyscience.Objects.ObjectClasses import BaseObj -from easyscience.Objects.ObjectClasses import Parameter +from easyscience.base_classes import ObjBase +from easyscience.variable import Parameter # %% # Subclassing @@ -29,7 +29,7 @@ # embedded rST text block: -class Pendulum(BaseObj): +class Pendulum(ObjBase): def __init__(self, A: Parameter, f: Parameter, p: Parameter): super(Pendulum, self).__init__('SimplePendulum', A=A, f=f, p=p) diff --git a/LICENSE b/LICENSE index c1ee0cf3..f21bf746 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2024, Easyscience contributors (https://github.com/EasyScience) +Copyright (c) 2025, Easyscience contributors (https://github.com/EasyScience) All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/docs/src/conf.py b/docs/src/conf.py index 2d95445a..9bf23650 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -51,9 +51,7 @@ intersphinx_mapping = { 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'pint': ('https://pint.readthedocs.io/en/stable/', None), - 'xarray': ('https://xarray.pydata.org/en/stable/', None) + 'numpy': ('https://numpy.org/doc/stable/', None) } # -- General configuration --------------------------------------------------- diff --git a/docs/src/reference/base.rst b/docs/src/reference/base.rst index 59e9de32..ed3d05de 100644 --- a/docs/src/reference/base.rst +++ b/docs/src/reference/base.rst @@ -5,13 +5,13 @@ Parameters and Objects Descriptors =========== -.. autoclass:: easyscience.Objects.Variable.Descriptor +.. autoclass:: easyscience.variable.Descriptor :members: Parameters ========== -.. autoclass:: easyscience.Objects.Variable.Parameter +.. autoclass:: easyscience.variable.Parameter :members: :inherited-members: @@ -22,30 +22,17 @@ Super Classes and Collections Super Classes ============= -.. autoclass:: easyscience.Objects.ObjectClasses.BasedBase +.. autoclass:: easyscience.base_classes.BasedBase :members: :inherited-members: -.. autoclass:: easyscience.Objects.ObjectClasses.BaseObj +.. autoclass:: easyscience.base_classes.ObjBase :members: +_add_component :inherited-members: Collections =========== -.. autoclass:: easyscience.Objects.Groups.BaseCollection +.. autoclass:: easyscience.CollectionBase :members: :inherited-members: - -=============== -Data Containers -=============== - -.. autoclass:: easyscience.Datasets.xarray.EasyScienceDataarrayAccessor - :members: - :inherited-members: - -.. autoclass:: easyscience.Datasets.xarray.EasyScienceDatasetAccessor - :members: - :inherited-members: - diff --git a/examples_old/example4.py b/examples_old/example4.py index 3b01387a..720ff961 100644 --- a/examples_old/example4.py +++ b/examples_old/example4.py @@ -11,7 +11,7 @@ from easyscience import global_object from easyscience.fitting import Fitter -from easyscience.Objects.core import ComponentSerializer +from easyscience.Objects.component_serializer import ComponentSerializer from easyscience.Objects.ObjectClasses import BaseObj from easyscience.Objects.ObjectClasses import Parameter diff --git a/examples_old/example5_broken.py b/examples_old/example5_broken.py index de12e67a..a7967ee8 100644 --- a/examples_old/example5_broken.py +++ b/examples_old/example5_broken.py @@ -12,7 +12,7 @@ from easyscience.fitting import Fitter from easyscience.Objects.Base import BaseObj from easyscience.Objects.Base import Parameter -from easyscience.Objects.core import ComponentSerializer +from easyscience.Objects.component_serializer import ComponentSerializer # from easyscience.Objects.Base import LoggedProperty from easyscience.Objects.Inferface import InterfaceFactoryTemplate diff --git a/examples_old/example6_broken.py b/examples_old/example6_broken.py index c93f7577..fe2dbbf2 100644 --- a/examples_old/example6_broken.py +++ b/examples_old/example6_broken.py @@ -12,7 +12,7 @@ from easyscience.fitting import Fitter from easyscience.Objects.ObjectClasses import BaseObj from easyscience.Objects.Variable import Parameter -from easyscience.Objects.core import ComponentSerializer +from easyscience.Objects.component_serializer import ComponentSerializer from easyscience.Objects.Inferface import InterfaceFactoryTemplate # This is a much more complex case where we have calculators, interfaces, interface factory and an diff --git a/pyproject.toml b/pyproject.toml index 734444b8..cfa5f735 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,6 @@ dependencies = [ "lmfit", "numpy", "uncertainties", - "xarray", - "pint", # Only to ensure that unit is reported as dimensionless rather than empty string "scipp" ] @@ -75,6 +73,7 @@ packages = ["src"] [tool.hatch.build.targets.wheel] packages = ["src/easyscience"] +exclude = ["src/easyscience/legacy"] [tool.coverage.run] source = ["src/easyscience"] diff --git a/resources/scripts/generate_html.py b/resources/scripts/generate_html.py index 92a90981..0a399639 100644 --- a/resources/scripts/generate_html.py +++ b/resources/scripts/generate_html.py @@ -1,5 +1,3 @@ -__author__ = 'github.com/wardsimon' -__version__ = '0.0.1' import sys diff --git a/src/easyscience/Datasets/__init__.py b/src/easyscience/Datasets/__init__.py deleted file mode 100644 index 22e236a6..00000000 --- a/src/easyscience/Datasets/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project str: - """ - Get the common name of the DataSet. - - :return: Common name of the DataSet - :rtype: str - """ - return self._obj.attrs['name'] - - @name.setter - def name(self, new_name: str): - """ - Set the common name of the DataSet i.e could be experiment name... - - :param new_name: Common name of the DataSet - :type new_name: str - :return: None - :rtype: None - """ - self._obj.attrs['name'] = new_name - - @property - def description(self) -> str: - """ - Get a description of the DataSet - - :return: Description of the DataSet - :rtype: str - """ - return self._obj.attrs['description'] - - @description.setter - def description(self, new_description: str): - """ - Set the description of the DataSet - - :param new_description: Description of the DataSet - :type new_description: str - :return: None - :rtype: None - """ - self._obj.attrs['description'] = new_description - - @property - def url(self) -> str: - """ - Get the url of the DataSet - - :return: URL of the DataSet (empty if no URL) - :rtype: str - """ - return self._obj.attrs['url'] - - @url.setter - def url(self, new_url: str): - """ - Set the URL of the DataSet. This may be a DOI. - - :param new_url: New URL/DOI of the DataSet - :type new_url: str - :return:None - :rtype: None - """ - self._obj.attrs['url'] = new_url - - @property - def core_object(self): - """ - Get the core object associated to a DataSet. Note that this is called from a weakref. If the EasyScience obj is - garbage collected, None will be returned. - - :return: EasyScience object associated with the DataSet - :rtype: Any - """ - if self._core_object is None: - return None - return self._core_object() - - @core_object.setter - def core_object(self, new_core_object: Any): - """ - Associate an EasyScience object to a DataSet. - - :param new_core_object: EasyScience object to be associated to the DataSet - :type new_core_object: Any - :return: None - :rtype: None - """ - self._core_object = weakref.ref(new_core_object) - - def add_coordinate( - self, - coordinate_name: str, - coordinate_values: Union[List[T_], np.ndarray], - unit: str = '', - ): - """ - Add a coordinate to the DataSet. This can be then be assigned to one or more DataArrays. - - :param coordinate_name: Name of the coordinate e.g. `x` - :type coordinate_name: str - :param coordinate_values: Points for the coordinates - :type coordinate_values: Union[List[T_], numpy.ndarray] - :param unit: Unit associated with the coordinate - :type unit: str - :return: None - :rtype: None - """ - self._obj.coords[coordinate_name] = coordinate_values - self._obj.attrs['units'][coordinate_name] = ureg.Unit(unit) - - def remove_coordinate(self, coordinate_name: str): - """ - Remove a coordinate from the DataSet. Note that this will not remove the coordinate from DataArrays which have - already used the it! - - :param coordinate_name: Name of the coordinate to be removed - :type coordinate_name: str - :return: None - :rtype: None - """ - del self._obj.coords[coordinate_name] - del self._obj.attrs['units'][coordinate_name] - - def add_variable( - self, - variable_name, - variable_coordinates: Union[str, List[str]], - variable_values: Union[List[T_], np.ndarray], - variable_sigma: Union[List[T_], np.ndarray] = None, - unit: str = '', - auto_sigma: bool = False, - ): - """ - Create a DataArray from known coordinates and data, assign it to the dataset under a given name. Variances can - be calculated assuming gaussian distribution to 1 sigma. - - :param variable_name: Name of the DataArray which will be created and added to the dataset - :type variable_name: str - :param variable_coordinates: List of coordinates used in the supplied data array. - :type variable_coordinates: str, List[str] - :param variable_values: Numpy or list of data which will be assigned to the DataArray - :type variable_values: Union[numpy.ndarray, list] - :param variable_sigma: If the sigmas of the dataset are known, they can be supplied here. - :type variable_sigma: Union[numpy.ndarray, list] - :param unit: Unit associated with the DataArray - :type unit: str - :param auto_sigma: Should the sigma DataArray be automatically calculated assuming gaussian probability? - :type auto_sigma: bool - :return: None - :rtype: None - """ - - # Check if a user has supplied a coordinate as a string. Make it a list of strings - if isinstance(variable_coordinates, str): - variable_coordinates = [variable_coordinates] - - # The variable_coordinates can be any iterable object. Though we would assume list/tuple - if not isinstance(variable_coordinates, Iterable): - raise ValueError('The variable coordinates must be a list of strings') - - # Check to see if the user want to assign a coordinate which does not exist yet. - known_keys = self._obj.coords.keys() - for dimension in variable_coordinates: - if dimension not in known_keys: - raise ValueError(f'The supplied coordinate `{dimension}` must first be defined.') - - # Create the dataset. - self._obj[variable_name] = (variable_coordinates, variable_values) - - # Deal with sigmas - if variable_sigma is not None: - # CASE 1, user has supplied sigmas - if isinstance(variable_sigma, Callable): - # CASE 1-1, The sigmas are created by some kind of generator - self.sigma_generator(variable_name, variable_sigma) - elif isinstance(variable_sigma, np.ndarray): - # CASE 1-2, The sigmas are a numpy arrays - self.sigma_attach(variable_name, variable_sigma) - elif isinstance(variable_sigma, list): - # CASE 1-3, We have been given a list. Make it a numpy array - self.sigma_attach(variable_name, np.array(variable_sigma)) - else: - raise ValueError('User supplied sigmas must be of the form; Callable fn, numpy array, list') - else: - # CASE 2, No sigmas have been supplied. - if auto_sigma: - # CASE 2-1, Automatically generate the sigmas using gaussian probability - self.sigma_generator(variable_name) - - # Set units for the newly created DataArray - self._obj.attrs['units'][variable_name] = ureg.Unit(unit) - # If a sigma has been attached, attempt to work out the units. - if unit and variable_sigma is None and auto_sigma: - self._obj.attrs['units'][self.sigma_label_prefix + variable_name] = ureg.Unit(unit + ' ** 0.5') - else: - if auto_sigma: - self._obj.attrs['units'][self.sigma_label_prefix + variable_name] = ureg.Unit('') - - def remove_variable(self, variable_name: str): - """ - Remove a DataArray from the DataSet by supplied name. - - :param variable_name: Name of DataArray to be removed - :type variable_name: str - :return: None - :rtype: None - """ - del self._obj[variable_name] - - def sigma_generator( - self, - variable_label: str, - sigma_func: Callable = lambda x: np.sqrt(np.abs(x)), - label_prefix: str = None, - ): - """ - Generate sigmas off of a DataArray based on a function. - - :param variable_label: Name of the DataArray to perform the calculation on - :type variable_label: str - :param sigma_func: Function to generate the sigmas. Must be of the form f(x) and return an array of the same shape as the input. Default sqrt(\\|x\\|) - :type sigma_func: Callable - :param label_prefix: What prefix should be used to designate a sigma DataArray from a data DataArray - :type label_prefix: str - :return: None - :rtype: None - """ # noqa: E501 - sigma_values = sigma_func(self._obj[variable_label]) - self.sigma_attach(variable_label, sigma_values, label_prefix) - - def sigma_attach( - self, - variable_label: str, - sigma_values: Union[List[T_], np.ndarray, xr.DataArray], - label_prefix: str = None, - ): - """ - Attach an array of sigmas to the DataSet. - - :param variable_label: Name of the DataArray to perform the calculation on - :type variable_label: str - :param sigma_values: Array of sigmas in list, numpy or DataArray form - :type sigma_values: Union[List[T_], numpy.ndarray, xarray.DataArray] - :param label_prefix: What prefix should be used to designate a sigma DataArray from a data DataArray - :type label_prefix: str - :return: None - :rtype: None - """ - # Use the default sigma prefix if not defined. - if label_prefix is None: - label_prefix = self.sigma_label_prefix - - # Form the label for the new DataArray - sigma_label = label_prefix + variable_label - - # Map the original DataArray to the new sigma DataArray - self.__error_mapper[variable_label] = sigma_label - # Assign the sigma DataArray to the DataSet - if not isinstance(sigma_values, xr.DataArray): - self._obj[sigma_label] = ( - list(self._obj[variable_label].coords.keys()), - sigma_values, - ) - else: - self._obj[sigma_label] = sigma_values - - def generate_points(self, coordinates: List[str]) -> xr.DataArray: - """ - Generate an expanded DataArray of points which corresponds to broadcasted dimensions (`all_x`) which have been - concatenated along the second axis (`fit_dim`). - - :param coordinates: List of coordinate names to broadcast and concatenate along - :type coordinates: List[str] - :return: Broadcasted and concatenated coordinates - :rtype: xarray.DataArray - - .. code-block:: python - - x = [1, 2], y = [3, 4] - d = xr.DataArray() - d.EasyScience.add_coordinate('x', x) - d.EasyScience.add_coordinate('y', y) - points = d.EasyScience.generate_points(['x', 'y']) - print(points) - """ - - coords = [self._obj.coords[da] for da in coordinates] - c_array = [] - n_array = [] - for da in xr.broadcast(*coords): - c_array.append(da) - n_array.append(da.name) - - f = xr.concat(c_array, dim='fit_dim') - f = f.stack(all_x=n_array) - return f - - def fit( - self, - fitter, - data_arrays: list, - *args, - dask: str = 'forbidden', - fit_kwargs: dict = None, - fn_kwargs: dict = None, - vectorized: bool = False, - **kwargs, - ) -> List[FitResults]: - """ - Perform a fit on one or more DataArrays. This fit utilises a given fitter from `EasyScience.fitting.Fitter`, though - there are a few differences to a standard EasyScience fit. In particular, key-word arguments to control the - optimisation algorithm go in the `fit_kwargs` dictionary, fit function key-word arguments go in the `fn_kwargs` - and given key-word arguments control the `xarray.apply_ufunc` function. - - :param fitter: Fitting object which controls the fitting - :type fitter: EasyScience.fitting.Fitter - :param args: Arguments to go to the fit function - :type args: Any - :param dask: Dask control string. See `xarray.apply_ufunc` documentation - :type dask: str - :param fit_kwargs: Dictionary of key-word arguments to be supplied to the Fitting control - :type fit_kwargs: dict - :param fn_kwargs: Dictionary of key-words to be supplied to the fit function - :type fn_kwargs: dict - :param vectorized: Should the fit function be given dependents in a single object or split - :type vectorized: bool - :param kwargs: Key-word arguments for `xarray.apply_ufunc`. See `xarray.apply_ufunc` documentation - :type kwargs: Any - :return: Results of the fit - :rtype: List[FitResults] - """ - - if fn_kwargs is None: - fn_kwargs = {} - if fit_kwargs is None: - fit_kwargs = {} - if not isinstance(data_arrays, (list, tuple)): - data_arrays = [data_arrays] - - # In this case we are only fitting 1 dataset - if len(data_arrays) == 1: - variable_label = data_arrays[0] - dataset = self._obj[variable_label] - if self.__error_mapper.get(variable_label, False): - # Pull out any sigmas and send them to the fitter. - temp = self._obj[self.__error_mapper[variable_label]] - temp[xr.ufuncs.isnan(temp)] = 1e5 - fit_kwargs['weights'] = temp - # Perform a standard DataArray fit. - return dataset.EasyScience.fit( - fitter, - *args, - fit_kwargs=fit_kwargs, - fn_kwargs=fn_kwargs, - dask=dask, - vectorize=vectorized, - **kwargs, - ) - else: - # In this case we are fitting multiple datasets to the same fn! - bdim_f = [self._obj[p].EasyScience.fit_prep(fitter.fit_function) for p in data_arrays] - dim_names = [ - list(self._obj[p].dims.keys()) if isinstance(self._obj[p].dims, dict) else self._obj[p].dims - for p in data_arrays - ] - bdims = [bdim[0] for bdim in bdim_f] - fs = [bdim[1] for bdim in bdim_f] - old_fit_func = fitter.fit_function - - fn_array = [] - y_list = [] - for _idx, d in enumerate(bdims): - dims = self._obj[data_arrays[_idx]].dims - if isinstance(dims, dict): - dims = list(dims.keys()) - - def local_fit_func(x, *args, idx=None, **kwargs): - kwargs['vectorize'] = vectorized - res = xr.apply_ufunc( - fs[idx], - *bdims[idx], - *args, - dask=dask, - kwargs=fn_kwargs, - **kwargs, - ) - if dask != 'forbidden': - res.compute() - return res.stack(all_x=dim_names[idx]) - - y_list.append(self._obj[data_arrays[_idx]].stack(all_x=dims)) - fn_array.append(local_fit_func) - - def fit_func(x, *args, **kwargs): - res = [] - for idx in range(len(fn_array)): - res.append(fn_array[idx](x, *args, idx=idx, **kwargs)) - return xr.DataArray(np.concatenate(res, axis=0), coords={'all_x': x}, dims='all_x') - - fitter.initialize(fitter.fit_object, fit_func) - try: - if fit_kwargs.get('weights', None) is not None: - del fit_kwargs['weights'] - x = xr.DataArray(np.arange(np.sum([y.size for y in y_list])), dims='all_x') - y = xr.DataArray(np.concatenate(y_list, axis=0), coords={'all_x': x}, dims='all_x') - f_res = fitter.fit(x, y, **fit_kwargs) - f_res = check_sanity_multiple(f_res, [self._obj[p] for p in data_arrays]) - finally: - fitter.fit_function = old_fit_func - return f_res - - -@xr.register_dataarray_accessor('EasyScience') -class EasyScienceDataarrayAccessor: - """ - Accessor to extend an xarray DataArray to EasyScience. These functions can be accessed by `obj.EasyScience.func`. - - """ - - def __init__(self, xarray_obj: xr.DataArray): - self._obj = xarray_obj - self._core_object = None - self.sigma_label_prefix = 's_' - if self._obj.attrs.get('computation', None) is None: - self._obj.attrs['computation'] = { - 'precompute_func': None, - 'compute_func': None, - 'postcompute_func': None, - } - - def __empty_functional(self) -> Callable: - def outer(): - def empty_fn(input, *args, **kwargs): - return input - - return empty_fn - - class wrapper: - def __init__(obj): - obj.obj = self - obj.data = {} - obj.fn = outer() - - def __call__(self, *args, **kwargs): - return self.fn(*args, **kwargs) - - return wrapper() - - @property - def core_object(self): - """ - Get the core object associated to a DataArray. Note that this is called from a weakref. If the EasyScience obj is - garbage collected, None will be returned. - - :return: EasyScience object associated with the DataArray - :rtype: Any - """ - if self._core_object is None: - return None - return self._core_object() - - @core_object.setter - def core_object(self, new_core_object: Any): - """ - Set the core object associated to a dataset - - :param new_core_object: EasyScience object to be associated with the DataArray - :type new_core_object: Any - :return: None - :rtype: None - """ - self._core_object = weakref.ref(new_core_object) - - @property - def compute_func(self) -> Callable: - """ - Get the computational function which will be executed during a fit - - :return: Computational function applied to the DataArray - :rtype: Callable - """ - result = self._obj.attrs['computation']['compute_func'] - if result is None: - result = self.__empty_functional() - return result - - @compute_func.setter - def compute_func(self, new_computational_fn: Callable): - """ - Set the computational function which is called during a fit - - :param new_computational_fn: Computational function applied to the DataArray - :type new_computational_fn: Callable - :return: None - :rtype: None - """ - self._obj.attrs['computation']['compute_func'] = new_computational_fn - - @property - def precompute_func(self) -> Callable: - """ - Get the pre-computational function which will be executed before a fit - - :return: Computational function applied to the DataArray before fitting - :rtype: Callable - """ - result = self._obj.attrs['computation']['precompute_func'] - if result is None: - result = self.__empty_functional() - return result - - @precompute_func.setter - def precompute_func(self, new_computational_fn: Callable): - """ - Set the computational function which is called before a fit - - :param new_computational_fn: Computational function applied to the DataArray before fitting - :type new_computational_fn: Callable - :return: None - :rtype: None - """ - self._obj.attrs['computation']['precompute_func'] = new_computational_fn - - @property - def postcompute_func(self) -> Callable: - """ - Get the post-computational function which will be executed after a fit - - :return: Computational function applied to the DataArray after fitting - :rtype: Callable - """ - result = self._obj.attrs['computation']['postcompute_func'] - if result is None: - result = self.__empty_functional() - return result - - @postcompute_func.setter - def postcompute_func(self, new_computational_fn: Callable): - """ - Set the computational function which is called after a fit - - :param new_computational_fn: Computational function applied to the DataArray after fitting - :type new_computational_fn: Callable - :return: None - :rtype: None - """ - self._obj.attrs['computation']['postcompute_func'] = new_computational_fn - - def fit_prep(self, func_in: Callable, bdims=None, dask_chunks=None) -> Tuple[xr.DataArray, Callable]: - """ - Generate broadcasted coordinates for fitting and reform the fitting function into one which can handle xarrays. - - :param func_in: Function to be wrapped and made xarray fitting compatible. - :type func_in: Callable - :param bdims: Optional precomputed broadcasted dimensions. - :type bdims: xarray.DataArray - :param dask_chunks: How to split the broadcasted dimensions for dask. - :type dask_chunks: Tuple[int..] - :return: Tuple of broadcasted fit arrays and wrapped fit function. - :rtype: xarray.DataArray, Callable - """ - - if bdims is None: - coords = [self._obj.coords[da].transpose() for da in self._obj.dims] - bdims = xr.broadcast(*coords) - self._obj.attrs['computation']['compute_func'] = func_in - - def func(x, *args, vectorize: bool = False, **kwargs): - old_shape = x.shape - if not vectorize: - xs = [x_new.flatten() for x_new in [x, *args] if isinstance(x_new, np.ndarray)] - x_new = np.column_stack(xs) - if len(x_new.shape) > 1 and x_new.shape[1] == 1: - x_new = x_new.reshape((-1)) - result = self.compute_func(x_new, **kwargs) - else: - result = self.compute_func( - *[d for d in [x, args] if isinstance(d, np.ndarray)], - *[d for d in args if not isinstance(d, np.ndarray)], - **kwargs, - ) - if isinstance(result, np.ndarray): - result = result.reshape(old_shape) - result = self.postcompute_func(result) - return result - - return bdims, func - - def generate_points(self) -> xr.DataArray: - """ - Generate an expanded DataArray of points which corresponds to broadcasted dimensions (`all_x`) which have been - concatenated along the second axis (`fit_dim`). - - :return: Broadcasted and concatenated coordinates - :rtype: xarray.DataArray - """ - - coords = [self._obj.coords[da] for da in self._obj.dims] - c_array = [] - n_array = [] - for da in xr.broadcast(*coords): - c_array.append(da) - n_array.append(da.name) - - f = xr.concat(c_array, dim='fit_dim') - f = f.stack(all_x=n_array) - return f - - def fit( - self, - fitter, - *args, - fit_kwargs: dict = None, - fn_kwargs: dict = None, - vectorize: bool = False, - dask: str = 'forbidden', - **kwargs, - ) -> FitResults: - """ - Perform a fit on the given DataArray. This fit utilises a given fitter from `EasyScience.fitting.Fitter`, though - there are a few differences to a standard EasyScience fit. In particular, key-word arguments to control the - optimisation algorithm go in the `fit_kwargs` dictionary, fit function key-word arguments go in the `fn_kwargs` - and given key-word arguments control the `xarray.apply_ufunc` function. - - :param fitter: Fitting object which controls the fitting - :type fitter: EasyScience.fitting.Fitter - :param args: Arguments to go to the fit function - :type args: Any - :param dask: Dask control string. See `xarray.apply_ufunc` documentation - :type dask: str - :param fit_kwargs: Dictionary of key-word arguments to be supplied to the Fitting control - :type fit_kwargs: dict - :param fn_kwargs: Dictionary of key-words to be supplied to the fit function - :type fn_kwargs: dict - :param vectorize: Should the fit function be given dependents in a single object or split - :type vectorize: bool - :param kwargs: Key-word arguments for `xarray.apply_ufunc`. See `xarray.apply_ufunc` documentation - :type kwargs: Any - :return: Results of the fit - :rtype: FitResults - """ - - # Deal with any kwargs which has been given - if fn_kwargs is None: - fn_kwargs = {} - if fit_kwargs is None: - fit_kwargs = {} - old_fit_func = fitter.fit_function - - # Wrap and broadcast - bdims, f = self.fit_prep(fitter.fit_function) - dims = self._obj.dims - - # Find which coords we need - if isinstance(dims, dict): - dims = list(dims.keys()) - - # Wrap the wrap in a callable - def local_fit_func(x, *args, **kwargs): - """ - Function which will be called by the fitter. This will deal with sending the function the correct data. - """ - kwargs['vectorize'] = vectorize - res = xr.apply_ufunc(f, *bdims, *args, dask=dask, kwargs=fn_kwargs, **kwargs) - if dask != 'forbidden': - res.compute() - return res.stack(all_x=dims) - - # Set the new callable to the fitter and initialize - fitter.initialize(fitter.fit_object, local_fit_func) - # Make EasyScience.fitting.Fitter compatible `x` - x_for_fit = xr.concat(bdims, dim='fit_dim') - x_for_fit = x_for_fit.stack(all_x=[d.name for d in bdims]) - try: - # Deal with any sigmas if supplied - if fit_kwargs.get('weights', None) is not None: - fit_kwargs['weights'] = xr.DataArray( - np.array(fit_kwargs['weights']), - dims=['all_x'], - coords={'all_x': x_for_fit.all_x}, - ) - # Try to perform a fit - f_res = fitter.fit(x_for_fit, self._obj.stack(all_x=dims), **fit_kwargs) - f_res = check_sanity_single(f_res) - finally: - # Reset the fit function on the fitter to the old fit function. - fitter.fit_function = old_fit_func - return f_res - - -def check_sanity_single(fit_results: FitResults) -> FitResults: - """ - Convert the FitResults from a fitter compatible state to a recognizable DataArray state. - - :param fit_results: Results of a fit to be modified - :type fit_results: FitResults - :return: Modified fit results - :rtype: FitResults - """ - items = ['y_obs', 'y_calc', 'residual'] - - for item in items: - array = getattr(fit_results, item) - if isinstance(array, xr.DataArray): - array = array.unstack() - array.name = item - setattr(fit_results, item, array) - - x_array = fit_results.x - if isinstance(x_array, xr.DataArray): - fit_results.x.name = 'axes_broadcast' - x_array = x_array.unstack() - x_dataset = xr.Dataset() - dims = [dims for dims in x_array.dims if dims != 'fit_dim'] - for idx, dim in enumerate(dims): - x_dataset[dim + '_broadcast'] = x_array[idx] - x_dataset[dim + '_broadcast'].name = dim + '_broadcast' - fit_results.x_matrices = x_dataset - else: - fit_results.x_matrices = x_array - return fit_results - - -def check_sanity_multiple(fit_results: FitResults, originals: List[xr.DataArray]) -> List[FitResults]: - """ - Convert the multifit FitResults from a fitter compatible state to a list of recognizable DataArray states. - - :param fit_results: Results of a fit to be modified - :type fit_results: FitResults - :param originals: List of DataArrays which were fitted against, so we can resize and re-chunk the results - :type originals: List[xr.DataArray] - :return: Modified fit results - :rtype: List[FitResults] - """ - - return_results = [] - offset = 0 - for item in originals: - current_results = fit_results.__class__() - # Fill out the basic stuff.... - current_results.engine_result = fit_results.engine_result - current_results.minimizer_engine = fit_results.minimizer_engine - current_results.success = fit_results.success - current_results.p = fit_results.p - current_results.p0 = fit_results.p0 - # now the tricky stuff - current_results.x = item.EasyScience.generate_points() - current_results.y_obs = item.copy() - current_results.y_obs.name = f'{item.name}_obs' - current_results.y_calc = xr.DataArray( - fit_results.y_calc[offset : offset + item.size].data, - dims=item.dims, - coords=item.coords, - name=f'{item.name}_calc', - ) - offset += item.size - current_results.residual = current_results.y_calc - current_results.y_obs - current_results.residual.name = f'{item.name}_residual' - return_results.append(current_results) - return return_results diff --git a/src/easyscience/Objects/ObjectClasses.py b/src/easyscience/Objects/ObjectClasses.py deleted file mode 100644 index 376162c5..00000000 --- a/src/easyscience/Objects/ObjectClasses.py +++ /dev/null @@ -1,351 +0,0 @@ -from __future__ import annotations - -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project Set[str]: - base_cls = getattr(self, '__old_class__', self.__class__) - spec = getfullargspec(base_cls.__init__) - names = set(spec.args[1:]) - return names - - def __reduce__(self): - """ - Make the class picklable. - Due to the nature of the dynamic class definitions special measures need to be taken. - - :return: Tuple consisting of how to make the object - :rtype: tuple - """ - state = self.encode() - cls = getattr(self, '__old_class__', self.__class__) - return cls.from_dict, (state,) - - @property - def unique_name(self) -> str: - """Get the unique name of the object.""" - return self._unique_name - - @unique_name.setter - def unique_name(self, new_unique_name: str): - """Set a new unique name for the object. The old name is still kept in the map. - - :param new_unique_name: New unique name for the object""" - if not isinstance(new_unique_name, str): - raise TypeError('Unique name has to be a string.') - self._unique_name = new_unique_name - self._global_object.map.add_vertex(self) - - @property - def name(self) -> str: - """ - Get the common name of the object. - - :return: Common name of the object - """ - return self._name - - @name.setter - def name(self, new_name: str): - """ - Set a new common name for the object. - - :param new_name: New name for the object - :return: None - """ - self._name = new_name - - @property - def interface(self) -> iF: - """ - Get the current interface of the object - """ - return self._interface - - @interface.setter - def interface(self, new_interface: iF): - """ - Set the current interface to the object and generate bindings if possible. iF.e. - ``` - def __init__(self, bar, interface=None, **kwargs): - super().__init__(self, **kwargs) - self.foo = bar - self.interface = interface # As final step after initialization to set correct bindings. - ``` - """ - self._interface = new_interface - if new_interface is not None: - self.generate_bindings() - - def generate_bindings(self): - """ - Generate or re-generate bindings to an interface (if exists) - - :raises: AttributeError - """ - if self.interface is None: - raise AttributeError('Interface error for generating bindings. `interface` has to be set.') - interfaceable_children = [ - key - for key in self._global_object.map.get_edges(self) - if issubclass(type(self._global_object.map.get_item_by_key(key)), BasedBase) - ] - for child_key in interfaceable_children: - child = self._global_object.map.get_item_by_key(child_key) - child.interface = self.interface - self.interface.generate_bindings(self) - - def switch_interface(self, new_interface_name: str): - """ - Switch or create a new interface. - """ - if self.interface is None: - raise AttributeError('Interface error for generating bindings. `interface` has to be set.') - self.interface.switch(new_interface_name) - self.generate_bindings() - - def get_parameters(self) -> List[Parameter]: - """ - Get all parameter objects as a list. - - :return: List of `Parameter` objects. - """ - par_list = [] - for key, item in self._kwargs.items(): - if hasattr(item, 'get_parameters'): - par_list = [*par_list, *item.get_parameters()] - elif isinstance(item, Parameter): - par_list.append(item) - return par_list - - def _get_linkable_attributes(self) -> List[V]: - """ - Get all objects which can be linked against as a list. - - :return: List of `Descriptor`/`Parameter` objects. - """ - item_list = [] - for key, item in self._kwargs.items(): - if hasattr(item, '_get_linkable_attributes'): - item_list = [*item_list, *item._get_linkable_attributes()] - elif issubclass(type(item), (DescriptorBase)): - item_list.append(item) - return item_list - - def get_fit_parameters(self) -> List[Parameter]: - """ - Get all objects which can be fitted (and are not fixed) as a list. - - :return: List of `Parameter` objects which can be used in fitting. - """ - fit_list = [] - for key, item in self._kwargs.items(): - if hasattr(item, 'get_fit_parameters'): - fit_list = [*fit_list, *item.get_fit_parameters()] - elif isinstance(item, Parameter): - if item.independent and not item.fixed: - fit_list.append(item) - return fit_list - - def __dir__(self) -> Iterable[str]: - """ - This creates auto-completion and helps out in iPython notebooks. - - :return: list of function and parameter names for auto-completion - """ - new_class_objs = list(k for k in dir(self.__class__) if not k.startswith('_')) - return sorted(new_class_objs) - - def __copy__(self) -> BasedBase: - """Return a copy of the object.""" - temp = self.as_dict(skip=['unique_name']) - new_obj = self.__class__.from_dict(temp) - return new_obj - - -if TYPE_CHECKING: - B = TypeVar('B', bound=BasedBase) - BV = TypeVar('BV', bound=ComponentSerializer) - - -class BaseObj(BasedBase): - """ - This is the base class for which all higher level classes are built off of. - NOTE: This object is serializable only if parameters are supplied as: - `BaseObj(a=value, b=value)`. For `Parameter` or `Descriptor` objects we can - cheat with `BaseObj(*[Descriptor(...), Parameter(...), ...])`. - """ - - def __init__( - self, - name: str, - unique_name: Optional[str] = None, - *args: Optional[BV], - **kwargs: Optional[BV], - ): - """ - Set up the base class. - - :param name: Name of this object - :param args: Any arguments? - :param kwargs: Fields which this class should contain - """ - super(BaseObj, self).__init__(name=name, unique_name=unique_name) - # If Parameter or Descriptor is given as arguments... - for arg in args: - if issubclass(type(arg), (BaseObj, DescriptorBase)): - kwargs[getattr(arg, 'name')] = arg - # Set kwargs, also useful for serialization - known_keys = self.__dict__.keys() - self._kwargs = kwargs - for key in kwargs.keys(): - if key in known_keys: - raise AttributeError('Kwargs cannot overwrite class attributes in BaseObj.') - if issubclass(type(kwargs[key]), (BasedBase, DescriptorBase)) or 'BaseCollection' in [ - c.__name__ for c in type(kwargs[key]).__bases__ - ]: - self._global_object.map.add_edge(self, kwargs[key]) - self._global_object.map.reset_type(kwargs[key], 'created_internal') - addLoggedProp( - self, - key, - self.__getter(key), - self.__setter(key), - get_id=key, - my_self=self, - test_class=BaseObj, - ) - - def _add_component(self, key: str, component: BV) -> None: - """ - Dynamically add a component to the class. This is an internal method, though can be called remotely. - The recommended alternative is to use typing, i.e. - - class Foo(Bar): - def __init__(self, foo: Parameter, bar: Parameter): - super(Foo, self).__init__(bar=bar) - self._add_component("foo", foo) - - Goes to: - class Foo(Bar): - foo: ClassVar[Parameter] - def __init__(self, foo: Parameter, bar: Parameter): - super(Foo, self).__init__(bar=bar) - self.foo = foo - - :param key: Name of component to be added - :param component: Component to be added - :return: None - """ - self._kwargs[key] = component - self._global_object.map.add_edge(self, component) - self._global_object.map.reset_type(component, 'created_internal') - addLoggedProp( - self, - key, - self.__getter(key), - self.__setter(key), - get_id=key, - my_self=self, - test_class=BaseObj, - ) - - def __setattr__(self, key: str, value: BV) -> None: - # Assume that the annotation is a ClassVar - old_obj = None - if ( - hasattr(self.__class__, '__annotations__') - and key in self.__class__.__annotations__ - and hasattr(self.__class__.__annotations__[key], '__args__') - and issubclass( - getattr(value, '__old_class__', value.__class__), - self.__class__.__annotations__[key].__args__, - ) - ): - if issubclass(type(getattr(self, key, None)), (BasedBase, DescriptorBase)): - old_obj = self.__getattribute__(key) - self._global_object.map.prune_vertex_from_edge(self, old_obj) - self._add_component(key, value) - else: - if hasattr(self, key) and issubclass(type(value), (BasedBase, DescriptorBase)): - old_obj = self.__getattribute__(key) - self._global_object.map.prune_vertex_from_edge(self, old_obj) - self._global_object.map.add_edge(self, value) - super(BaseObj, self).__setattr__(key, value) - # Update the interface bindings if something changed (BasedBase and Descriptor) - if old_obj is not None: - old_interface = getattr(self, 'interface', None) - if old_interface is not None: - self.generate_bindings() - - def __repr__(self) -> str: - return f"{self.__class__.__name__} `{getattr(self, 'name')}`" - - @staticmethod - def __getter(key: str) -> Callable[[BV], BV]: - def getter(obj: BV) -> BV: - return obj._kwargs[key] - - return getter - - @staticmethod - def __setter(key: str) -> Callable[[BV], None]: - def setter(obj: BV, value: float) -> None: - if issubclass(obj._kwargs[key].__class__, (DescriptorBase)) and not issubclass( - value.__class__, (DescriptorBase) - ): - obj._kwargs[key].value = value - else: - obj._kwargs[key] = value - - return setter - - # @staticmethod - # def __setter(key: str) -> Callable[[Union[B, V]], None]: - # def setter(obj: Union[V, B], value: float) -> None: - # if issubclass(obj._kwargs[key].__class__, Descriptor): - # if issubclass(obj._kwargs[key].__class__, Descriptor): - # obj._kwargs[key] = value - # else: - # obj._kwargs[key].value = value - # else: - # obj._kwargs[key] = value - # - # return setter diff --git a/src/easyscience/Objects/__init__.py b/src/easyscience/Objects/__init__.py deleted file mode 100644 index 22e236a6..00000000 --- a/src/easyscience/Objects/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project Set[str]: + base_cls = getattr(self, '__old_class__', self.__class__) + spec = getfullargspec(base_cls.__init__) + names = set(spec.args[1:]) + return names + + def __reduce__(self): + """ + Make the class picklable. + Due to the nature of the dynamic class definitions special measures need to be taken. + + :return: Tuple consisting of how to make the object + :rtype: tuple + """ + state = self.encode() + cls = getattr(self, '__old_class__', self.__class__) + return cls.from_dict, (state,) + + @property + def unique_name(self) -> str: + """Get the unique name of the object.""" + return self._unique_name + + @unique_name.setter + def unique_name(self, new_unique_name: str): + """Set a new unique name for the object. The old name is still kept in the map. + + :param new_unique_name: New unique name for the object""" + if not isinstance(new_unique_name, str): + raise TypeError('Unique name has to be a string.') + self._unique_name = new_unique_name + self._global_object.map.add_vertex(self) + + @property + def name(self) -> str: + """ + Get the common name of the object. + + :return: Common name of the object + """ + return self._name + + @name.setter + def name(self, new_name: str): + """ + Set a new common name for the object. + + :param new_name: New name for the object + :return: None + """ + self._name = new_name + + @property + def interface(self) -> InterfaceFactoryTemplate: + """ + Get the current interface of the object + """ + return self._interface + + @interface.setter + def interface(self, new_interface: InterfaceFactoryTemplate): + """ + Set the current interface to the object and generate bindings if possible. iF.e. + ``` + def __init__(self, bar, interface=None, **kwargs): + super().__init__(self, **kwargs) + self.foo = bar + self.interface = interface # As final step after initialization to set correct bindings. + ``` + """ + self._interface = new_interface + if new_interface is not None: + self.generate_bindings() + + def generate_bindings(self): + """ + Generate or re-generate bindings to an interface (if exists) + + :raises: AttributeError + """ + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + interfaceable_children = [ + key + for key in self._global_object.map.get_edges(self) + if issubclass(type(self._global_object.map.get_item_by_key(key)), BasedBase) + ] + for child_key in interfaceable_children: + child = self._global_object.map.get_item_by_key(child_key) + child.interface = self.interface + self.interface.generate_bindings(self) + + def switch_interface(self, new_interface_name: str): + """ + Switch or create a new interface. + """ + if self.interface is None: + raise AttributeError('Interface error for generating bindings. `interface` has to be set.') + self.interface.switch(new_interface_name) + self.generate_bindings() + + def get_parameters(self) -> List[Parameter]: + """ + Get all parameter objects as a list. + + :return: List of `Parameter` objects. + """ + par_list = [] + for key, item in self._kwargs.items(): + if hasattr(item, 'get_parameters'): + par_list = [*par_list, *item.get_parameters()] + elif isinstance(item, Parameter): + par_list.append(item) + return par_list + + def _get_linkable_attributes(self) -> List[DescriptorBase]: + """ + Get all objects which can be linked against as a list. + + :return: List of `Descriptor`/`Parameter` objects. + """ + item_list = [] + for key, item in self._kwargs.items(): + if hasattr(item, '_get_linkable_attributes'): + item_list = [*item_list, *item._get_linkable_attributes()] + elif issubclass(type(item), (DescriptorBase)): + item_list.append(item) + return item_list + + def get_fit_parameters(self) -> List[Parameter]: + """ + Get all objects which can be fitted (and are not fixed) as a list. + + :return: List of `Parameter` objects which can be used in fitting. + """ + fit_list = [] + for key, item in self._kwargs.items(): + if hasattr(item, 'get_fit_parameters'): + fit_list = [*fit_list, *item.get_fit_parameters()] + elif isinstance(item, Parameter): + if item.independent and not item.fixed: + fit_list.append(item) + return fit_list + + def __dir__(self) -> Iterable[str]: + """ + This creates auto-completion and helps out in iPython notebooks. + + :return: list of function and parameter names for auto-completion + """ + new_class_objs = list(k for k in dir(self.__class__) if not k.startswith('_')) + return sorted(new_class_objs) + + def __copy__(self) -> BasedBase: + """Return a copy of the object.""" + temp = self.as_dict(skip=['unique_name']) + new_obj = self.__class__.from_dict(temp) + return new_obj + + diff --git a/src/easyscience/Objects/Groups.py b/src/easyscience/base_classes/collection_base.py similarity index 88% rename from src/easyscience/Objects/Groups.py rename to src/easyscience/base_classes/collection_base.py index 90d4f0c6..45d6f39c 100644 --- a/src/easyscience/Objects/Groups.py +++ b/src/easyscience/base_classes/collection_base.py @@ -1,12 +1,9 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project None: + def insert(self, index: int, value: Union[DescriptorBase, BasedBase]) -> None: """ Insert an object into the collection at an index. @@ -122,14 +118,14 @@ def insert(self, index: int, value: Union[V, B]) -> None: else: raise AttributeError('Only EasyScience objects can be put into an EasyScience group') - def __getitem__(self, idx: Union[int, slice]) -> Union[V, B]: + def __getitem__(self, idx: Union[int, slice]) -> Union[DescriptorBase, BasedBase]: """ Get an item in the collection based on its index. :param idx: index or slice of the collection. :type idx: Union[int, slice] :return: Object at index `idx` - :rtype: Union[Parameter, Descriptor, BaseObj, 'BaseCollection'] + :rtype: Union[Parameter, Descriptor, ObjBase, 'CollectionBase'] """ if isinstance(idx, slice): start, stop, step = idx.indices(len(self)) @@ -156,7 +152,7 @@ def __getitem__(self, idx: Union[int, slice]) -> Union[V, B]: keys = list(self._kwargs.keys()) return self._kwargs[keys[idx]] - def __setitem__(self, key: int, value: Union[B, V]) -> None: + def __setitem__(self, key: int, value: Union[BasedBase, DescriptorBase]) -> None: """ Set an item via it's index. @@ -238,7 +234,7 @@ def data(self) -> Tuple: def __repr__(self) -> str: return f"{self.__class__.__name__} `{getattr(self, 'name')}` of length {len(self)}" - def sort(self, mapping: Callable[[Union[B, V]], Any], reverse: bool = False) -> None: + def sort(self, mapping: Callable[[Union[BasedBase, DescriptorBase]], Any], reverse: bool = False) -> None: """ Sort the collection according to the given mapping. diff --git a/src/easyscience/base_classes/obj_base.py b/src/easyscience/base_classes/obj_base.py new file mode 100644 index 00000000..33316259 --- /dev/null +++ b/src/easyscience/base_classes/obj_base.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """ + Dynamically add a component to the class. This is an internal method, though can be called remotely. + The recommended alternative is to use typing, i.e. + + class Foo(Bar): + def __init__(self, foo: Parameter, bar: Parameter): + super(Foo, self).__init__(bar=bar) + self._add_component("foo", foo) + + Goes to: + class Foo(Bar): + foo: ClassVar[Parameter] + def __init__(self, foo: Parameter, bar: Parameter): + super(Foo, self).__init__(bar=bar) + self.foo = foo + + :param key: Name of component to be added + :param component: Component to be added + :return: None + """ + self._kwargs[key] = component + self._global_object.map.add_edge(self, component) + self._global_object.map.reset_type(component, 'created_internal') + addLoggedProp( + self, + key, + self.__getter(key), + self.__setter(key), + get_id=key, + my_self=self, + test_class=ObjBase, + ) + + def __setattr__(self, key: str, value: SerializerComponent) -> None: + # Assume that the annotation is a ClassVar + old_obj = None + if ( + hasattr(self.__class__, '__annotations__') + and key in self.__class__.__annotations__ + and hasattr(self.__class__.__annotations__[key], '__args__') + and issubclass( + getattr(value, '__old_class__', value.__class__), + self.__class__.__annotations__[key].__args__, + ) + ): + if issubclass(type(getattr(self, key, None)), (BasedBase, DescriptorBase)): + old_obj = self.__getattribute__(key) + self._global_object.map.prune_vertex_from_edge(self, old_obj) + self._add_component(key, value) + else: + if hasattr(self, key) and issubclass(type(value), (BasedBase, DescriptorBase)): + old_obj = self.__getattribute__(key) + self._global_object.map.prune_vertex_from_edge(self, old_obj) + self._global_object.map.add_edge(self, value) + super(ObjBase, self).__setattr__(key, value) + # Update the interface bindings if something changed (BasedBase and Descriptor) + if old_obj is not None: + old_interface = getattr(self, 'interface', None) + if old_interface is not None: + self.generate_bindings() + + def __repr__(self) -> str: + return f"{self.__class__.__name__} `{getattr(self, 'name')}`" + + @staticmethod + def __getter(key: str) -> Callable[[SerializerComponent], SerializerComponent]: + def getter(obj: SerializerComponent) -> SerializerComponent: + return obj._kwargs[key] + + return getter + + @staticmethod + def __setter(key: str) -> Callable[[SerializerComponent], None]: + def setter(obj: SerializerComponent, value: float) -> None: + if issubclass(obj._kwargs[key].__class__, (DescriptorBase)) and not issubclass( + value.__class__, (DescriptorBase) + ): + obj._kwargs[key].value = value + else: + obj._kwargs[key] = value + + return setter + + # @staticmethod + # def __setter(key: str) -> Callable[[Union[B, V]], None]: + # def setter(obj: Union[V, B], value: float) -> None: + # if issubclass(obj._kwargs[key].__class__, Descriptor): + # if issubclass(obj._kwargs[key].__class__, Descriptor): + # obj._kwargs[key] = value + # else: + # obj._kwargs[key].value = value + # else: + # obj._kwargs[key] = value + # + # return setter diff --git a/src/easyscience/fitting/calculators/__init__.py b/src/easyscience/fitting/calculators/__init__.py new file mode 100644 index 00000000..a3ca5d43 --- /dev/null +++ b/src/easyscience/fitting/calculators/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project List[str]: return [self.return_name(this_interface) for this_interface in self._interfaces] @property - def current_interface(self) -> _C: + def current_interface(self) -> ABCMeta: """ Returns the constructor for the currently selected interface. @@ -174,7 +168,7 @@ def generate_bindings(self, model, *args, ifun=None, **kwargs): prop._callback = item.make_prop(item_key) prop._callback.fset(prop_value) - def __call__(self, *args, **kwargs) -> _M: + def __call__(self, *args, **kwargs) -> None: return self.__interface_obj def __reduce__(self): @@ -233,6 +227,3 @@ def set_value(value): self.setter_fn(self.link_name, **{inner_key: value}) return set_value - - -iF = TypeVar('iF', bound=InterfaceFactoryTemplate) diff --git a/src/easyscience/fitting/fitter.py b/src/easyscience/fitting/fitter.py index 53007879..0cb67016 100644 --- a/src/easyscience/fitting/fitter.py +++ b/src/easyscience/fitting/fitter.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project List[str]: @staticmethod @abstractmethod - def convert_to_par_object(obj): # todo after constraint changes, add type hint: obj: BaseObj + def convert_to_par_object(obj): # todo after constraint changes, add type hint: obj: ObjBase """ - Convert an `EasyScience.Objects.Base.Parameter` object to an engine Parameter object. + Convert an `EasyScience.variable.Parameter` object to an engine Parameter object. """ def _prepare_parameters(self, parameters: dict[str, float]) -> dict[str, float]: diff --git a/src/easyscience/fitting/minimizers/minimizer_bumps.py b/src/easyscience/fitting/minimizers/minimizer_bumps.py index 14df1d0f..ed2b140b 100644 --- a/src/easyscience/fitting/minimizers/minimizer_bumps.py +++ b/src/easyscience/fitting/minimizers/minimizer_bumps.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project List[BumpsPara :rtype: List[BumpsParameter] """ if par_list is None: - # Assume that we have a BaseObj for which we can obtain a list + # Assume that we have a ObjBase for which we can obtain a list par_list = self._object.get_fit_parameters() pars_obj = [self.__class__.convert_to_par_object(obj) for obj in par_list] return pars_obj @@ -160,7 +160,7 @@ def convert_to_pars_obj(self, par_list: Optional[List] = None) -> List[BumpsPara @staticmethod def convert_to_par_object(obj) -> BumpsParameter: """ - Convert an `EasyScience.Objects.Base.Parameter` object to a bumps Parameter object + Convert an `EasyScience.variable.Parameter` object to a bumps Parameter object :return: bumps Parameter compatible object. :rtype: BumpsParameter diff --git a/src/easyscience/fitting/minimizers/minimizer_dfo.py b/src/easyscience/fitting/minimizers/minimizer_dfo.py index 27f7eba4..bcc5afea 100644 --- a/src/easyscience/fitting/minimizers/minimizer_dfo.py +++ b/src/easyscience/fitting/minimizers/minimizer_dfo.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project L :return: lmfit Parameters compatible object """ if parameters is None: - # Assume that we have a BaseObj for which we can obtain a list + # Assume that we have a ObjBase for which we can obtain a list parameters = self._object.get_fit_parameters() lm_parameters = LMParameters().add_many([self.convert_to_par_object(parameter) for parameter in parameters]) return lm_parameters @@ -175,7 +175,7 @@ def convert_to_pars_obj(self, parameters: Optional[List[Parameter]] = None) -> L @staticmethod def convert_to_par_object(parameter: Parameter) -> LMParameter: """ - Convert an `EasyScience.Objects.Base.Parameter` object to a lmfit Parameter object. + Convert an EasyScience Parameter object to a lmfit Parameter object. :return: lmfit Parameter compatible object. :rtype: LMParameter diff --git a/src/easyscience/fitting/multi_fitter.py b/src/easyscience/fitting/multi_fitter.py index c812ff0e..a30bdcec 100644 --- a/src/easyscience/fitting/multi_fitter.py +++ b/src/easyscience/fitting/multi_fitter.py @@ -1,14 +1,13 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project any: + def encode(self, obj: SerializerComponent, skip: Optional[List[str]] = None, **kwargs) -> any: """ Abstract implementation of an encoder. @@ -56,7 +51,7 @@ def encode(self, obj: BV, skip: Optional[List[str]] = None, **kwargs) -> any: @abstractmethod def decode(cls, obj: Any) -> Any: """ - Re-create an EasyScience object from the output of an encoder. The default decoder is `DictSerializer`. + Re-create an EasyScience object from the output of an encoder. The default decoder is `SerializerDict`. :param obj: encoded EasyScience object :return: Reformed EasyScience object @@ -83,7 +78,7 @@ def _encode_objs(obj: Any) -> Dict[str, Any]: :param obj: any object to be encoded :param skip: List of field names as strings to skip when forming the encoded object - :param kwargs: Key-words to pass to `BaseEncoderDecoder` + :param kwargs: Key-words to pass to `SerializerBase` :return: JSON encoded dictionary """ @@ -117,7 +112,7 @@ def _encode_objs(obj: Any) -> Dict[str, Any]: def _convert_to_dict( self, - obj: BV, + obj: SerializerComponent, skip: Optional[List[str]] = None, full_encode: bool = False, **kwargs, @@ -129,20 +124,20 @@ def _convert_to_dict( skip = [] if full_encode: - new_obj = BaseEncoderDecoder._encode_objs(obj) + new_obj = SerializerBase._encode_objs(obj) if new_obj is not obj: return new_obj - d = {'@module': get_class_module(obj), '@class': obj.__class__.__name__} + d = {'@module': obj.__module__, '@class': obj.__class__.__name__} try: - parent_module = get_class_module(obj).split('.')[0] + parent_module = obj.__module__.split('.')[0] module_version = import_module(parent_module).__version__ # type: ignore d['@version'] = '{}'.format(module_version) except (AttributeError, ImportError): d['@version'] = None # type: ignore - spec, args = BaseEncoderDecoder.get_arg_spec(obj.__class__.__init__) + spec, args = SerializerBase.get_arg_spec(obj.__class__.__init__) if hasattr(obj, '_arg_spec'): args = obj._arg_spec @@ -150,7 +145,7 @@ def _convert_to_dict( def runner(o): if full_encode: - return BaseEncoderDecoder._encode_objs(o) + return SerializerBase._encode_objs(o) else: return o @@ -194,7 +189,7 @@ def runner(o): 'determine the dict format. Alternatively, ' 'you can implement both as_dict and from_dict.' ) - d[c] = recursive_encoder(a, skip=skip, encoder=self, full_encode=full_encode, **kwargs) + d[c] = self._recursive_encoder(a, skip=skip, encoder=self, full_encode=full_encode, **kwargs) if spec.varargs is not None and getattr(obj, spec.varargs, None) is not None: d.update({spec.varargs: getattr(obj, spec.varargs)}) if hasattr(obj, '_kwargs'): @@ -211,7 +206,7 @@ def runner(o): continue vv = redirect[k](obj) v_ = runner(vv) - d[k] = recursive_encoder( + d[k] = self._recursive_encoder( v_, skip=skip, encoder=self, @@ -240,9 +235,6 @@ def _convert_from_dict(d): if '@module' in d and '@class' in d: modname = d['@module'] classname = d['@class'] - # if classname in DictSerializer.REDIRECT.get(modname, {}): - # modname = DictSerializer.REDIRECT[modname][classname]["@module"] - # classname = DictSerializer.REDIRECT[modname][classname]["@class"] else: modname = None classname = None @@ -257,7 +249,7 @@ def _convert_from_dict(d): mod = __import__(modname, globals(), locals(), [classname], 0) if hasattr(mod, classname): cls_ = getattr(mod, classname) - data = {k: BaseEncoderDecoder._convert_from_dict(v) for k, v in d.items() if not k.startswith('@')} + data = {k: SerializerBase._convert_from_dict(v) for k, v in d.items() if not k.startswith('@')} return cls_(**data) elif np is not None and modname == 'numpy' and classname == 'array': if d['dtype'].startswith('complex'): @@ -265,38 +257,25 @@ def _convert_from_dict(d): return np.array(d['data'], dtype=d['dtype']) if issubclass(T_, (list, MutableSequence)): - return [BaseEncoderDecoder._convert_from_dict(x) for x in d] + return [SerializerBase._convert_from_dict(x) for x in d] return d - -if TYPE_CHECKING: - _ = TypeVar('EC', bound=BaseEncoderDecoder) - EC = Type[_] - - -def recursive_encoder(obj, skip: List[str] = [], encoder=None, full_encode=False, **kwargs): - """ - Walk through an object encoding it - """ - if encoder is None: - encoder = BaseEncoderDecoder() - T_ = type(obj) - if issubclass(T_, (list, tuple, MutableSequence)): - # Is it a core MutableSequence? + def _recursive_encoder(self, obj, skip: List[str] = [], encoder=None, full_encode=False, **kwargs): + """ + Walk through an object encoding it + """ + if encoder is None: + encoder = SerializerBase() + T_ = type(obj) + if issubclass(T_, (list, tuple, MutableSequence)): + # Is it a core MutableSequence? + if hasattr(obj, 'encode') and obj.__class__.__module__ != 'builtins': # strings have encode + return encoder._convert_to_dict(obj, skip, full_encode, **kwargs) + else: + return [self._recursive_encoder(it, skip, encoder, full_encode, **kwargs) for it in obj] + if isinstance(obj, dict): + return {kk: self._recursive_encoder(vv, skip, encoder, full_encode, **kwargs) for kk, vv in obj.items()} if hasattr(obj, 'encode') and obj.__class__.__module__ != 'builtins': # strings have encode return encoder._convert_to_dict(obj, skip, full_encode, **kwargs) - else: - return [recursive_encoder(it, skip, encoder, full_encode, **kwargs) for it in obj] - if isinstance(obj, dict): - return {kk: recursive_encoder(vv, skip, encoder, full_encode, **kwargs) for kk, vv in obj.items()} - if hasattr(obj, 'encode') and obj.__class__.__module__ != 'builtins': # strings have encode - return encoder._convert_to_dict(obj, skip, full_encode, **kwargs) - return obj + return obj - -def get_class_module(obj): - """ - Returns the REAL module of the class of the object. - """ - c = getattr(obj, '__old_class__', obj.__class__) - return c.__module__ diff --git a/src/easyscience/io/serializer_component.py b/src/easyscience/io/serializer_component.py new file mode 100644 index 00000000..15995412 --- /dev/null +++ b/src/easyscience/io/serializer_component.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project Any: + """ + Use an encoder to covert an EasyScience object into another format. Default is to a dictionary using `SerializerDict`. + + :param skip: List of field names as strings to skip when forming the encoded object + :param encoder: The encoder to be used for encoding the data. Default is `SerializerDict` + :param kwargs: Any additional key word arguments to be passed to the encoder + :return: encoded object containing all information to reform an EasyScience object. + """ + if encoder is None: + encoder = SerializerDict + encoder_obj = encoder() + return encoder_obj.encode(self, skip=skip, **kwargs) + + @classmethod + def decode(cls, obj: Any, decoder: Optional[SerializerBase] = None) -> Any: + """ + Re-create an EasyScience object from the output of an encoder. The default decoder is `SerializerDict`. + + :param obj: encoded EasyScience object + :param decoder: decoder to be used to reform the EasyScience object + :return: Reformed EasyScience object + """ + + if decoder is None: + decoder = SerializerDict + return decoder.decode(obj) + + def as_dict(self, skip: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Convert an EasyScience object into a full dictionary using `SerializerDict`. + This is a shortcut for ```obj.encode(encoder=SerializerDict)``` + + :param skip: List of field names as strings to skip when forming the dictionary + :return: encoded object containing all information to reform an EasyScience object. + """ + + return self.encode(skip=skip, encoder=SerializerDict) + + @classmethod + def from_dict(cls, obj_dict: Dict[str, Any]) -> None: + """ + Re-create an EasyScience object from a full encoded dictionary. + + :param obj_dict: dictionary containing the serialized contents (from `SerializerDict`) of an EasyScience object + :return: Reformed EasyScience object + """ + + return cls.decode(obj_dict, decoder=SerializerDict) diff --git a/src/easyscience/io/serializer_dict.py b/src/easyscience/io/serializer_dict.py new file mode 100644 index 00000000..95b28f09 --- /dev/null +++ b/src/easyscience/io/serializer_dict.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +__author__ = "https://github.com/materialsvirtuallab/monty/blob/master/monty/json.py" +__version__ = "3.0.0" +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project SerializerComponent: + """ + Re-create an EasyScience object from the dictionary representation. + + :param d: Dict representation of an EasyScience object. + :return: EasyScience object. + """ + + return SerializerBase._convert_from_dict(d) \ No newline at end of file diff --git a/src/easyscience/Objects/job/__init__.py b/src/easyscience/job/__init__.py similarity index 100% rename from src/easyscience/Objects/job/__init__.py rename to src/easyscience/job/__init__.py diff --git a/src/easyscience/Objects/job/analysis.py b/src/easyscience/job/analysis.py similarity index 72% rename from src/easyscience/Objects/job/analysis.py rename to src/easyscience/job/analysis.py index 1d99ece1..512ae556 100644 --- a/src/easyscience/Objects/job/analysis.py +++ b/src/easyscience/job/analysis.py @@ -1,18 +1,16 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project np.ndarray: raise NotImplementedError("calculate_theory not implemented") @abstractmethod def fit(self, - x: Union[xr.DataArray, np.ndarray], - y: Union[xr.DataArray, np.ndarray], - e: Union[xr.DataArray, np.ndarray], + x: np.ndarray, + y: np.ndarray, + e: np.ndarray, **kwargs) -> None: raise NotImplementedError("fit not implemented") diff --git a/src/easyscience/Objects/job/experiment.py b/src/easyscience/job/experiment.py similarity index 68% rename from src/easyscience/Objects/job/experiment.py rename to src/easyscience/job/experiment.py index 1f2a63aa..807e0572 100644 --- a/src/easyscience/Objects/job/experiment.py +++ b/src/easyscience/job/experiment.py @@ -1,12 +1,12 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project BV: + def decode(cls, d: Dict) -> ComponentSerializer: """ :param d: Dict representation. :return: ComponentSerializer class. @@ -55,7 +55,7 @@ def decode(cls, d: Dict) -> BV: return BaseEncoderDecoder._convert_from_dict(d) @classmethod - def from_dict(cls, d: Dict[str, Any]) -> BV: + def from_dict(cls, d: Dict[str, Any]) -> ComponentSerializer: """ :param d: Dict representation. :return: ComponentSerializer class. @@ -70,7 +70,7 @@ class DataDictSerializer(DictSerializer): def encode( self, - obj: BV, + obj: ComponentSerializer, skip: Optional[List[str]] = None, full_encode: bool = False, **kwargs, @@ -95,7 +95,7 @@ def encode( return self._parse_dict(encoded) @classmethod - def decode(cls, d: Dict[str, Any]) -> BV: + def decode(cls, d: Dict[str, Any]) -> ComponentSerializer: """ This function is not implemented as a data dictionary does not contain the necessary information to re-form an EasyScience object. diff --git a/src/easyscience/Utils/io/json.py b/src/easyscience/legacy/json.py similarity index 57% rename from src/easyscience/Utils/io/json.py rename to src/easyscience/legacy/json.py index 6307c69b..6464e52a 100644 --- a/src/easyscience/Utils/io/json.py +++ b/src/easyscience/legacy/json.py @@ -3,27 +3,23 @@ __author__ = 'https://github.com/materialsvirtuallab/monty/blob/master/monty/json.py' __version__ = '3.0.0' -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project str: + def encode(self, obj: ComponentSerializer, skip: List[str] = []) -> str: """ Returns a json string representation of the ComponentSerializer object. """ @@ -35,16 +31,16 @@ def encode(self, obj: BV, skip: List[str] = []) -> str: return json.dumps(obj, cls=ENCODER) @classmethod - def decode(cls, data: str) -> BV: + def decode(cls, data: str) -> ComponentSerializer: return json.loads(data, cls=JsonDecoderTemplate) class JsonDataSerializer(BaseEncoderDecoder): - def encode(self, obj: BV, skip: List[str] = []) -> str: + def encode(self, obj: ComponentSerializer, skip: List[str] = []) -> str: """ Returns a json string representation of the ComponentSerializer object. """ - from easyscience.Utils.io.dict import DataDictSerializer + from .dict import DataDictSerializer ENCODER = type( JsonEncoderTemplate.__name__, @@ -60,7 +56,7 @@ def encode(self, obj: BV, skip: List[str] = []) -> str: return json.dumps(obj, cls=ENCODER) @classmethod - def decode(cls, data: str) -> BV: + def decode(cls, data: str) -> ComponentSerializer: raise NotImplementedError('It is not possible to reconstitute objects from data only objects.') @@ -121,51 +117,3 @@ def decode(self, s): """ d = json.JSONDecoder.decode(self, s) return self.__class__._converter(d) - - -def jsanitize(obj, strict=False, allow_bson=False): - """ - This method cleans an input json-like object, either a list or a dict or - some sequence, nested or otherwise, by converting all non-string - dictionary keys (such as int and float) to strings, and also recursively - encodes all objects using Monty's as_dict() protocol. - - Args: - obj: input json-like object. - strict (bool): This parameters sets the behavior when jsanitize - encounters an object it does not understand. If strict is True, - jsanitize will try to get the as_dict() attribute of the object. If - no such attribute is found, an attribute error will be thrown. If - strict is False, jsanitize will simply call str(object) to convert - the object to a string representation. - allow_bson (bool): This parameters sets the behavior when jsanitize - encounters an bson supported type such as objectid and datetime. If - True, such bson types will be ignored, allowing for proper - insertion into MongoDb databases. - - Returns: - Sanitized dict that can be json serialized. - """ - # if allow_bson and ( - # isinstance(obj, (datetime.datetime, bytes)) - # or (bson is not None and isinstance(obj, bson.objectid.ObjectId)) - # ): - # return obj - if isinstance(obj, (list, tuple)): - return [jsanitize(i, strict=strict, allow_bson=allow_bson) for i in obj] - if np is not None and isinstance(obj, np.ndarray): - return [jsanitize(i, strict=strict, allow_bson=allow_bson) for i in obj.tolist()] - if isinstance(obj, dict): - return {k.__str__(): jsanitize(v, strict=strict, allow_bson=allow_bson) for k, v in obj.items()} - if isinstance(obj, (int, float)): - return obj - if obj is None: - return None - - if not strict: - return obj.__str__() - - if isinstance(obj, str): - return obj.__str__() - - return jsanitize(obj.as_dict(), strict=strict, allow_bson=allow_bson) diff --git a/src/easyscience/Objects/core.py b/src/easyscience/legacy/legacy_core.py similarity index 90% rename from src/easyscience/Objects/core.py rename to src/easyscience/legacy/legacy_core.py index 0754040f..daddabee 100644 --- a/src/easyscience/Objects/core.py +++ b/src/easyscience/legacy/legacy_core.py @@ -1,12 +1,9 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project Any: + def encode(self, skip: Optional[List[str]] = None, encoder: Optional[BaseEncoderDecoder] = None, **kwargs) -> Any: """ Use an encoder to covert an EasyScience object into another format. Default is to a dictionary using `DictSerializer`. @@ -50,7 +47,7 @@ def encode(self, skip: Optional[List[str]] = None, encoder: Optional[EC] = None, return encoder_obj.encode(self, skip=skip, **kwargs) @classmethod - def decode(cls, obj: Any, decoder: Optional[EC] = None) -> Any: + def decode(cls, obj: Any, decoder: Optional[BaseEncoderDecoder] = None) -> Any: """ Re-create an EasyScience object from the output of an encoder. The default decoder is `DictSerializer`. @@ -85,7 +82,7 @@ def from_dict(cls, obj_dict: Dict[str, Any]) -> None: return cls.decode(obj_dict, decoder=DictSerializer) - def encode_data(self, skip: Optional[List[str]] = None, encoder: Optional[EC] = None, **kwargs) -> Any: + def encode_data(self, skip: Optional[List[str]] = None, encoder: Optional[BaseEncoderDecoder] = None, **kwargs) -> Any: """ Returns just the data in an EasyScience object win the format specified by an encoder. diff --git a/src/easyscience/Utils/io/xml.py b/src/easyscience/legacy/xml.py similarity index 92% rename from src/easyscience/Utils/io/xml.py rename to src/easyscience/legacy/xml.py index 7179b259..e0898409 100644 --- a/src/easyscience/Utils/io/xml.py +++ b/src/easyscience/legacy/xml.py @@ -1,12 +1,9 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project 2) & (sys.version_info.minor > 8) @@ -35,7 +32,7 @@ class XMLSerializer(BaseEncoderDecoder): def encode( self, - obj: BV, + obj: ComponentSerializer, skip: Optional[List[str]] = None, data_only: bool = False, fast: bool = False, @@ -76,7 +73,7 @@ def encode( return header + ET.tostring(block, encoding='unicode') @classmethod - def decode(cls, data: str) -> BV: + def decode(cls, data: str) -> ComponentSerializer: """ Decode an EasyScience object which has been encoded in XML format. diff --git a/src/easyscience/models/__init__.py b/src/easyscience/models/__init__.py index 47316878..de175a27 100644 --- a/src/easyscience/models/__init__.py +++ b/src/easyscience/models/__init__.py @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project np.ndarray: return np.polyval([c.value for c in self.coefficients], x) @@ -78,25 +64,3 @@ def __repr__(self): s = ' + '.join(s) return 'Polynomial({}, {})'.format(self.name, s) - -class Line(BaseObj): - m: ClassVar[Parameter] - c: ClassVar[Parameter] - - def __init__( - self, - m: Optional[Union[Parameter, float]] = None, - c: Optional[Union[Parameter, float]] = None, - ): - super(Line, self).__init__('line', m=Parameter('m', 1.0), c=Parameter('c', 0.0)) - if m is not None: - self.m = m - if c is not None: - self.c = c - - # @designate_calc_fn can be used to inject parameters into the calculation function. i.e. _m = m.value - def __call__(self, x: np.ndarray, *args, **kwargs) -> np.ndarray: - return self.m.value * x + self.c.value - - def __repr__(self): - return '{}({}, {})'.format(self.__class__.__name__, self.m, self.c) diff --git a/src/easyscience/utils/__init__.py b/src/easyscience/utils/__init__.py new file mode 100644 index 00000000..de175a27 --- /dev/null +++ b/src/easyscience/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project None: +def addLoggedProp(inst: SerializerComponent, name: str, *args, **kwargs) -> None: cls = type(inst) annotations = getattr(cls, '__annotations__', False) if not hasattr(cls, '__perinstance'): @@ -32,7 +29,7 @@ def addLoggedProp(inst: BV, name: str, *args, **kwargs) -> None: setattr(cls, name, LoggedProperty(*args, **kwargs)) -def addProp(inst: BV, name: str, *args, **kwargs) -> None: +def addProp(inst: SerializerComponent, name: str, *args, **kwargs) -> None: cls = type(inst) annotations = getattr(cls, '__annotations__', False) if not hasattr(cls, '__perinstance'): @@ -46,7 +43,7 @@ def addProp(inst: BV, name: str, *args, **kwargs) -> None: setattr(cls, name, property(*args, **kwargs)) -def removeProp(inst: BV, name: str) -> None: +def removeProp(inst: SerializerComponent, name: str) -> None: cls = type(inst) if not hasattr(cls, '__perinstance'): cls = type(cls.__name__, (cls,), {'__module__': __name__}) @@ -56,7 +53,7 @@ def removeProp(inst: BV, name: str) -> None: delattr(cls, name) -def generatePath(model_obj: B, skip_first: bool = False) -> Tuple[List[int], List[str]]: +def generatePath(model_obj: BasedBase, skip_first: bool = False) -> Tuple[List[int], List[str]]: pars = model_obj.get_parameters() start_idx = 0 + int(skip_first) unique_names = [] diff --git a/src/easyscience/Utils/classUtils.py b/src/easyscience/utils/classUtils.py similarity index 94% rename from src/easyscience/Utils/classUtils.py rename to src/easyscience/utils/classUtils.py index 0b405fa5..e0c01e97 100644 --- a/src/easyscience/Utils/classUtils.py +++ b/src/easyscience/utils/classUtils.py @@ -1,9 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project None: """Detach an observer from the descriptor.""" self._observers.remove(observer) - def _notify_observers(self, update_id=None) -> None: - """Notify all observers of a change. - - :param update_id: Optional update ID to pass to observers. Used to avoid cyclic depenencies. + def _notify_observers(self) -> None: + """Notify all observers of a change.""" + for observer in self._observers: + observer._update() + + def _validate_dependencies(self, origin=None) -> None: + """Ping all observers to check if any cyclic dependencies have been introduced. + :param origin: Unique_name of the origin of this validation check. Used to avoid cyclic depenencies. """ - if update_id is None: - self._global_object.update_id_iterator += 1 - update_id = self._global_object.update_id_iterator + if origin == self.unique_name: + raise RuntimeError('\n Cyclic dependency detected!\n' + + f'An update of {self.unique_name} leads to it updating itself.\n' + + 'Please check your dependencies.') + if origin is None: + origin = self.unique_name for observer in self._observers: - observer._update(update_id=update_id, updating_object=self.unique_name) + observer._validate_dependencies(origin=origin) @property def full_value(self) -> Variable: diff --git a/src/easyscience/Objects/variable/descriptor_str.py b/src/easyscience/variable/descriptor_str.py similarity index 100% rename from src/easyscience/Objects/variable/descriptor_str.py rename to src/easyscience/variable/descriptor_str.py diff --git a/src/easyscience/Objects/variable/parameter.py b/src/easyscience/variable/parameter.py similarity index 92% rename from src/easyscience/Objects/variable/parameter.py rename to src/easyscience/variable/parameter.py index 63746993..c459946c 100644 --- a/src/easyscience/Objects/variable/parameter.py +++ b/src/easyscience/variable/parameter.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project None: + def _update(self) -> None: """ Update the parameter. This is called by the DescriptorNumbers/Parameters who have this Parameter as a dependency. - - :param update_id: The id of the update. This is used to avoid cyclic dependencies. - :param updating_object: The unique_name of the object which is updating this parameter. - """ if not self._independent: - # Check if this parameter has already been updated by the updating object with this update id - if updating_object not in self._dependency_updates: - self._dependency_updates[updating_object] = 0 - if self._dependency_updates[updating_object] == update_id: - raise RuntimeError('\n Potential cyclic dependency detected!\n' + - f'This parameter, {self.unique_name}, has already been updated by {updating_object} during this update.\n' + # noqa: E501 - 'Please check your dependencies.') - else: - # Update the value of the parameter using the dependency interpreter - temporary_parameter = self._dependency_interpreter(self._clean_dependency_string) - self._scalar.value = temporary_parameter.value - self._scalar.unit = temporary_parameter.unit - self._scalar.variance = temporary_parameter.variance - self._min.value = temporary_parameter.min if isinstance(temporary_parameter, Parameter) else temporary_parameter.value # noqa: E501 - self._max.value = temporary_parameter.max if isinstance(temporary_parameter, Parameter) else temporary_parameter.value # noqa: E501 - self._min.unit = temporary_parameter.unit - self._max.unit = temporary_parameter.unit - self._dependency_updates[updating_object] = update_id - self._notify_observers(update_id=update_id) + # Update the value of the parameter using the dependency interpreter + temporary_parameter = self._dependency_interpreter(self._clean_dependency_string) + self._scalar.value = temporary_parameter.value + self._scalar.unit = temporary_parameter.unit + self._scalar.variance = temporary_parameter.variance + self._min.value = temporary_parameter.min if isinstance(temporary_parameter, Parameter) else temporary_parameter.value # noqa: E501 + self._max.value = temporary_parameter.max if isinstance(temporary_parameter, Parameter) else temporary_parameter.value # noqa: E501 + self._min.unit = temporary_parameter.unit + self._max.unit = temporary_parameter.unit + self._notify_observers() else: warnings.warn('This parameter is not dependent. It cannot be updated.') @@ -185,11 +169,20 @@ def make_dependent_on(self, dependency_expression: str, dependency_map: Optional if not isinstance(value, DescriptorNumber): raise TypeError(f'`dependency_map` values must be DescriptorNumbers or Parameters. Got {type(value)} for {key}.') # noqa: E501 - # If we're overwriting the dependency + # If we're overwriting the dependency, store the old attributes + # in case we need to revert back to the old dependency + self._previous_independent = self._independent if not self._independent: - for old_dependency in self._dependency_map.values(): - old_dependency._detach_observer(self) + self._previous_dependency = { + '_dependency_string': self._dependency_string, + '_dependency_map': self._dependency_map, + '_dependency_interpreter': self._dependency_interpreter, + '_clean_dependency_string': self._clean_dependency_string, + } + for dependency in self._dependency_map.values(): + dependency._detach_observer(self) + self._independent = False self._dependency_string = dependency_expression self._dependency_map = dependency_map if dependency_map is not None else {} # List of allowed python constructs for the asteval interpreter @@ -200,35 +193,44 @@ def make_dependent_on(self, dependency_expression: str, dependency_map: Optional 'listcomp': False, 'dictcomp': False, 'setcomp': False, 'try': False, 'while': False, 'with': False} self._dependency_interpreter = Interpreter(config=asteval_config) - self._dependency_updates = {} # Used to track update ids to avoid cyclic dependencies - - self._process_dependency_unique_names(self._dependency_string) + + # Process the dependency expression for unique names + try: + self._process_dependency_unique_names(self._dependency_string) + except ValueError as error: + self._revert_dependency(skip_detach=True) + raise error + for key, value in self._dependency_map.items(): self._dependency_interpreter.symtable[key] = value self._dependency_interpreter.readonly_symbols.add(key) # Dont allow overwriting of the dependencies in the dependency expression # noqa: E501 value._attach_observer(self) + # Check the dependency expression for errors try: dependency_result = self._dependency_interpreter.eval(self._clean_dependency_string, raise_errors=True) except NameError as message: + self._revert_dependency() raise NameError('\nUnknown name encountered in dependecy expression:'+ '\n'+'\n'.join(str(message).split("\n")[1:])+ '\nPlease check your expression or add the name to the `dependency_map`') from None except Exception as message: + self._revert_dependency() raise SyntaxError('\nError encountered in dependecy expression:'+ '\n'+'\n'.join(str(message).split("\n")[1:])+ '\nPlease check your expression') from None if not isinstance(dependency_result, DescriptorNumber): - raise TypeError(f'The dependency expression: "{self._dependency_string}" returned a {type(dependency_result)}, it should return a Parameter or DescriptorNumber.') # noqa: E501 - self._scalar.value = dependency_result.value - self._scalar.unit = dependency_result.unit - self._scalar.variance = dependency_result.variance - self._min.value = dependency_result.min if isinstance(dependency_result, Parameter) else dependency_result.value - self._max.value = dependency_result.max if isinstance(dependency_result, Parameter) else dependency_result.value - self._min.unit = dependency_result.unit - self._max.unit = dependency_result.unit - self._independent = False + error_string = self._dependency_string + self._revert_dependency() + raise TypeError(f'The dependency expression: "{error_string}" returned a {type(dependency_result)}, it should return a Parameter or DescriptorNumber.') # noqa: E501 + # Check for cyclic dependencies + try: + self._validate_dependencies() + except RuntimeError as error: + self._revert_dependency() + raise error + # Update the parameter with the dependency result self._fixed = False - self._notify_observers() + self._update() def make_independent(self) -> None: """ @@ -242,7 +244,6 @@ def make_independent(self) -> None: dependency._detach_observer(self) self._independent = True del self._dependency_map - del self._dependency_updates del self._dependency_interpreter del self._dependency_string del self._clean_dependency_string @@ -511,6 +512,23 @@ def free(self) -> bool: def free(self, value: bool) -> None: self.fixed = not value + def _revert_dependency(self, skip_detach=False) -> None: + """ + Revert the dependency to the old dependency. This is used when an error is raised during setting the dependency. + """ + if self._previous_independent is True: + self.make_independent() + else: + if not skip_detach: + for dependency in self._dependency_map.values(): + dependency._detach_observer(self) + for key, value in self._previous_dependency.items(): + setattr(self, key, value) + for dependency in self._dependency_map.values(): + dependency._attach_observer(self) + del self._previous_dependency + del self._previous_independent + def _process_dependency_unique_names(self, dependency_expression: str): """ Add the unique names of the parameters to the ASTEval interpreter. This is used to evaluate the dependency expression. diff --git a/tests/__init__.py b/tests/__init__.py index a4ab5234..462d95af 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,3 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project " - d = DescriptorNumber("test", 1, unit="cm") - assert repr(d) == f"<{d.__class__.__name__} 'test': 1.0000 cm>" - - -def test_descriptor_number_as_dict(): - d = DescriptorNumber("test", 1) - result = d.as_dict() - expected = { - "@module": DescriptorNumber.__module__, - "@class": DescriptorNumber.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": 1, - "unit": "dimensionless", - "description": "", - "url": "", - "display_name": "test", - "callback": None, - } - for key in expected.keys(): - if key == "callback": - continue - assert result[key] == expected[key] - - -@pytest.mark.parametrize( - "reference, constructor", - ( - [ - { - "@module": DescriptorBool.__module__, - "@class": DescriptorBool.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": False, - "description": "", - "url": "", - "display_name": "test", - }, - DescriptorBool, - ], - [ - { - "@module": DescriptorNumber.__module__, - "@class": DescriptorNumber.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": 1, - "unit": "dimensionless", - "variance": 0.0, - "description": "", - "url": "", - "display_name": "test", - }, - DescriptorNumber, - ], - [ - { - "@module": DescriptorStr.__module__, - "@class": DescriptorStr.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": "string", - "description": "", - "url": "", - "display_name": "test", - }, - DescriptorStr, - ], - ), - ids=["DescriptorBool", "DescriptorNumber", "DescriptorStr"], -) -def test_item_from_dict(reference, constructor): - d = constructor.from_dict(reference) - for key, item in reference.items(): - if key.startswith("@"): - continue - obtained = getattr(d, key) - assert obtained == item - - -@pytest.mark.parametrize("value", ("This is ", "a fun ", "test")) -def test_parameter_display_name(value): - p = DescriptorNumber("test", 1, display_name=value) - assert p.display_name == value - - -def test_item_boolean_value(): - item = DescriptorBool("test", True) - assert item.value is True - item.value = False - assert item.value is False - - item = DescriptorBool("test", False) - assert item.value is False - item.value = True - assert item.value is True diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index 22e236a6..462d95af 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -1,6 +1,3 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2025 Contributors to the EasyScience project + diff --git a/tests/unit_tests/utils/io_tests/test_core.py b/tests/unit_tests/io/test_serializer_component.py similarity index 75% rename from tests/unit_tests/utils/io_tests/test_core.py rename to tests/unit_tests/io/test_serializer_component.py index 2083ac3c..80994b4d 100644 --- a/tests/unit_tests/utils/io_tests/test_core.py +++ b/tests/unit_tests/io/test_serializer_component.py @@ -1,5 +1,3 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" import numpy as np from copy import deepcopy @@ -8,8 +6,8 @@ import pytest import easyscience -from easyscience.Objects.variable import DescriptorNumber -from easyscience.Objects.variable import Parameter +from easyscience import DescriptorNumber +from easyscience import Parameter dp_param_dict = { "argnames": "dp_kwargs, dp_cls", @@ -97,27 +95,3 @@ def test_variable_as_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorNumber check_dict(dp_kwargs, enc) - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_as_data_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc_d = obj.as_data_dict(skip=skip) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) diff --git a/tests/unit_tests/io/test_serializer_dict.py b/tests/unit_tests/io/test_serializer_dict.py new file mode 100644 index 00000000..7d75f02d --- /dev/null +++ b/tests/unit_tests/io/test_serializer_dict.py @@ -0,0 +1,117 @@ + +from copy import deepcopy +from typing import Type + +import pytest + +from easyscience.io.serializer_dict import SerializerDict +from easyscience import DescriptorNumber +from easyscience import ObjBase + +from .test_serializer_component import check_dict +from .test_serializer_component import dp_param_dict +from .test_serializer_component import skip_dict +from easyscience import global_object + + +def recursive_remove(d, remove_keys: list) -> dict: + """ + Remove keys from a dictionary. + """ + if not isinstance(remove_keys, list): + remove_keys = [remove_keys] + if isinstance(d, dict): + dd = {} + for k in d.keys(): + if k not in remove_keys: + dd[k] = recursive_remove(d[k], remove_keys) + return dd + else: + return d + + +######################################################################################################################## +# TESTING ENCODING +######################################################################################################################## +@pytest.mark.parametrize(**skip_dict) +@pytest.mark.parametrize(**dp_param_dict) +def test_variable_SerializerDict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): + data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + + obj = dp_cls(**data_dict) + + dp_kwargs = deepcopy(dp_kwargs) + + if isinstance(skip, str): + del dp_kwargs[skip] + + if not isinstance(skip, list): + skip = [skip] + + enc = obj.encode(skip=skip, encoder=SerializerDict) + + expected_keys = set(dp_kwargs.keys()) + obtained_keys = set(enc.keys()) + + dif = expected_keys.difference(obtained_keys) + + assert len(dif) == 0 + + check_dict(dp_kwargs, enc) + +######################################################################################################################## +# TESTING DECODING +######################################################################################################################## +@pytest.mark.parametrize(**dp_param_dict) +def test_variable_SerializerDict_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): + data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + + obj = dp_cls(**data_dict) + + enc = obj.encode(encoder=SerializerDict) + global_object.map._clear() + dec = dp_cls.decode(enc, decoder=SerializerDict) + + for k in data_dict.keys(): + if hasattr(obj, k) and hasattr(dec, k): + assert getattr(obj, k) == getattr(dec, k) + else: + raise AttributeError(f"{k} not found in decoded object") + + +@pytest.mark.parametrize(**dp_param_dict) +def test_variable_SerializerDict_from_dict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): + data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + + obj = dp_cls(**data_dict) + + enc = obj.encode(encoder=SerializerDict) + global_object.map._clear() + dec = dp_cls.from_dict(enc) + + for k in data_dict.keys(): + if hasattr(obj, k) and hasattr(dec, k): + assert getattr(obj, k) == getattr(dec, k) + else: + raise AttributeError(f"{k} not found in decoded object") + +def test_group_encode(): + d0 = DescriptorNumber("a", 0) + d1 = DescriptorNumber("b", 1) + + from easyscience.base_classes import CollectionBase + + b = CollectionBase("test", d0, d1) + d = b.as_dict() + assert isinstance(d["data"], list) + + +def test_group_encode2(): + d0 = DescriptorNumber("a", 0) + d1 = DescriptorNumber("b", 1) + + from easyscience.base_classes import CollectionBase + + b = ObjBase("outer", b=CollectionBase("test", d0, d1)) + d = b.as_dict() + assert isinstance(d["b"], dict) \ No newline at end of file diff --git a/tests/unit_tests/legacy/test_dict.py b/tests/unit_tests/legacy/test_dict.py new file mode 100644 index 00000000..70af9d78 --- /dev/null +++ b/tests/unit_tests/legacy/test_dict.py @@ -0,0 +1,156 @@ + +# from copy import deepcopy +# from typing import Type + +# import pytest + +# from easyscience.io.dict import DataDictSerializer +# from easyscience.io.dict import DictSerializer +# from easyscience.variable import DescriptorNumber +# from easyscience.base_classes import BaseObj + +# from .test_core import check_dict +# from .test_core import dp_param_dict +# from .test_core import skip_dict +# from easyscience import global_object + + +# def recursive_remove(d, remove_keys: list) -> dict: +# """ +# Remove keys from a dictionary. +# """ +# if not isinstance(remove_keys, list): +# remove_keys = [remove_keys] +# if isinstance(d, dict): +# dd = {} +# for k in d.keys(): +# if k not in remove_keys: +# dd[k] = recursive_remove(d[k], remove_keys) +# return dd +# else: +# return d + + +# ######################################################################################################################## +# # TESTING ENCODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# dp_kwargs = deepcopy(dp_kwargs) + +# if isinstance(skip, str): +# del dp_kwargs[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=DictSerializer) + +# expected_keys = set(dp_kwargs.keys()) +# obtained_keys = set(enc.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# check_dict(dp_kwargs, enc) + + +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# if isinstance(skip, str): +# del data_dict[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc_d = obj.encode(skip=skip, encoder=DataDictSerializer) + +# expected_keys = set(data_dict.keys()) +# obtained_keys = set(enc_d.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# check_dict(data_dict, enc_d) + + +# ######################################################################################################################## +# # TESTING DECODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=DictSerializer) +# global_object.map._clear() +# dec = dp_cls.decode(enc, decoder=DictSerializer) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer_from_dict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=DictSerializer) +# global_object.map._clear() +# dec = dp_cls.from_dict(enc) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=DataDictSerializer) +# with pytest.raises(NotImplementedError): +# dec = obj.decode(enc, decoder=DataDictSerializer) + + +# def test_group_encode(): +# d0 = DescriptorNumber("a", 0) +# d1 = DescriptorNumber("b", 1) + +# from easyscience.base_classes import BaseCollection + +# b = BaseCollection("test", d0, d1) +# d = b.as_dict() +# assert isinstance(d["data"], list) + + +# def test_group_encode2(): +# d0 = DescriptorNumber("a", 0) +# d1 = DescriptorNumber("b", 1) + +# from easyscience.base_classes import BaseCollection + +# b = BaseObj("outer", b=BaseCollection("test", d0, d1)) +# d = b.as_dict() +# assert isinstance(d["b"], dict) \ No newline at end of file diff --git a/tests/unit_tests/legacy/test_json.py b/tests/unit_tests/legacy/test_json.py new file mode 100644 index 00000000..651bd950 --- /dev/null +++ b/tests/unit_tests/legacy/test_json.py @@ -0,0 +1,123 @@ + +# import json +# from copy import deepcopy +# from typing import Type + +# import pytest + +# from easyscience.io.json import JsonDataSerializer +# from easyscience.io.json import JsonSerializer +# from easyscience.variable import DescriptorNumber + +# from .test_core import check_dict +# from .test_core import dp_param_dict +# from .test_core import skip_dict +# from easyscience import global_object + + +# def recursive_remove(d, remove_keys: list) -> dict: +# """ +# Remove keys from a dictionary. +# """ +# if not isinstance(remove_keys, list): +# remove_keys = [remove_keys] +# if isinstance(d, dict): +# dd = {} +# for k in d.keys(): +# if k not in remove_keys: +# dd[k] = recursive_remove(d[k], remove_keys) +# return dd +# else: +# return d + + +# ######################################################################################################################## +# # TESTING ENCODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# dp_kwargs = deepcopy(dp_kwargs) + +# if isinstance(skip, str): +# del dp_kwargs[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=JsonSerializer) +# assert isinstance(enc, str) + +# # We can test like this as we don't have "complex" objects yet +# dec = json.loads(enc) +# expected_keys = set(dp_kwargs.keys()) +# obtained_keys = set(dec.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# check_dict(dp_kwargs, dec) + + +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# if isinstance(skip, str): +# del data_dict[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=JsonDataSerializer) +# assert isinstance(enc, str) +# enc_d = json.loads(enc) + +# expected_keys = set(data_dict.keys()) +# obtained_keys = set(enc_d.keys()) + +# dif = expected_keys.difference(obtained_keys) + +# assert len(dif) == 0 + +# check_dict(data_dict, enc_d) + +# # ######################################################################################################################## +# # # TESTING DECODING +# # ######################################################################################################################## +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=JsonSerializer) +# global_object.map._clear() +# assert isinstance(enc, str) +# dec = obj.decode(enc, decoder=JsonSerializer) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=JsonDataSerializer) +# global_object.map._clear() +# with pytest.raises(NotImplementedError): +# dec = obj.decode(enc, decoder=JsonDataSerializer) diff --git a/tests/unit_tests/legacy/test_xml.py b/tests/unit_tests/legacy/test_xml.py new file mode 100644 index 00000000..7094de85 --- /dev/null +++ b/tests/unit_tests/legacy/test_xml.py @@ -0,0 +1,111 @@ + +# import sys +# import xml.etree.ElementTree as ET +# from copy import deepcopy +# from typing import Type + +# import pytest + +# from easyscience.legacy.xml import XMLSerializer +# from easyscience.variable import DescriptorNumber + +# from ..io.test_core import dp_param_dict +# from ..io.test_core import skip_dict +# from easyscience import global_object + +# def recursive_remove(d, remove_keys: list) -> dict: +# """ +# Remove keys from a dictionary. +# """ +# if not isinstance(remove_keys, list): +# remove_keys = [remove_keys] +# if isinstance(d, dict): +# dd = {} +# for k in d.keys(): +# if k not in remove_keys: +# dd[k] = recursive_remove(d[k], remove_keys) +# return dd +# else: +# return d + + +# def recursive_test(testing_obj, reference_obj): +# for i, (k, v) in enumerate(testing_obj.items()): +# if isinstance(v, dict): +# recursive_test(v, reference_obj[i]) +# else: +# assert v == XMLSerializer.string_to_variable(reference_obj[i].text) + + +# ######################################################################################################################## +# # TESTING ENCODING +# ######################################################################################################################## +# @pytest.mark.parametrize(**skip_dict) +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_XMLDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# dp_kwargs = deepcopy(dp_kwargs) + +# if isinstance(skip, str): +# del dp_kwargs[skip] + +# if not isinstance(skip, list): +# skip = [skip] + +# enc = obj.encode(skip=skip, encoder=XMLSerializer) +# ref_encode = obj.encode(skip=skip) +# assert isinstance(enc, str) +# data_xml = ET.XML(enc) +# assert data_xml.tag == "data" +# recursive_test(data_xml, ref_encode) + +# # ######################################################################################################################## +# # # TESTING DECODING +# # ######################################################################################################################## +# @pytest.mark.parametrize(**dp_param_dict) +# def test_variable_XMLDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): +# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} + +# obj = dp_cls(**data_dict) + +# enc = obj.encode(encoder=XMLSerializer) +# assert isinstance(enc, str) +# data_xml = ET.XML(enc) +# assert data_xml.tag == "data" +# global_object.map._clear() +# dec = dp_cls.decode(enc, decoder=XMLSerializer) + +# for k in data_dict.keys(): +# if hasattr(obj, k) and hasattr(dec, k): +# assert getattr(obj, k) == getattr(dec, k) +# else: +# raise AttributeError(f"{k} not found in decoded object") + + +# def test_slow_encode(): + +# if sys.version_info < (3, 9): +# pytest.skip("This test is only for python 3.9+") + +# a = {"a": [1, 2, 3]} +# slow_xml = XMLSerializer().encode(a, fast=False) +# reference = """ +# 1 +# 2 +# 3 +# """ +# assert slow_xml == reference + + +# def test_include_header(): + +# if sys.version_info < (3, 9): +# pytest.skip("This test is only for python 3.9+") + +# a = {"a": [1, 2, 3]} +# header_xml = XMLSerializer().encode(a, use_header=True) +# reference = '?xml version="1.0" encoding="UTF-8"?\n\n 1\n 2\n 3\n' +# assert header_xml == reference diff --git a/tests/unit_tests/models/__init__.py b/tests/unit_tests/models/__init__.py index 3d57f66c..bb769856 100644 --- a/tests/unit_tests/models/__init__.py +++ b/tests/unit_tests/models/__init__.py @@ -1,6 +1,4 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project +# © 2025 Contributors to the EasyScience project -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" diff --git a/tests/unit_tests/models/test_polynomial.py b/tests/unit_tests/models/test_polynomial.py index adccb9e1..799a917a 100644 --- a/tests/unit_tests/models/test_polynomial.py +++ b/tests/unit_tests/models/test_polynomial.py @@ -1,18 +1,13 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors +# SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project +# © 2025 Contributors to the EasyScience project -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" import numpy as np import pytest -from easyscience.models.polynomial import Line from easyscience.models.polynomial import Polynomial -from easyscience.Objects.variable.parameter import Parameter -line_test_cases = ((1, 2), (-1, -2), (0.72, 6.48)) poly_test_cases = ( (1.,), ( @@ -24,33 +19,6 @@ (0.72, 6.48, -0.48), ) - -@pytest.mark.parametrize("m, c", line_test_cases) -def test_Line_pars(m, c): - line = Line(m, c) - - assert line.m.value == m - assert line.c.value == c - - x = np.linspace(0, 10, 100) - y = line.m.value * x + line.c.value - assert np.allclose(line(x), y) - - -@pytest.mark.parametrize("m, c", line_test_cases) -def test_Line_constructor(m, c): - m_ = Parameter("m", m) - c_ = Parameter("c", c) - line = Line(m_, c_) - - assert line.m.value == m - assert line.c.value == c - - x = np.linspace(0, 10, 100) - y = line.m.value * x + line.c.value - assert np.allclose(line(x), y) - - @pytest.mark.parametrize("coo", poly_test_cases) def test_Polynomial_pars(coo): poly = Polynomial(coefficients=coo) diff --git a/tests/unit_tests/utils/__init__.py b/tests/unit_tests/utils/__init__.py deleted file mode 100644 index 3d57f66c..00000000 --- a/tests/unit_tests/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project - -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" diff --git a/tests/unit_tests/utils/io_tests/__init__.py b/tests/unit_tests/utils/io_tests/__init__.py deleted file mode 100644 index 3d57f66c..00000000 --- a/tests/unit_tests/utils/io_tests/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2022 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2022 Contributors to the EasyScience project - -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" diff --git a/tests/unit_tests/utils/io_tests/test_dict.py b/tests/unit_tests/utils/io_tests/test_dict.py deleted file mode 100644 index a9b8ccd4..00000000 --- a/tests/unit_tests/utils/io_tests/test_dict.py +++ /dev/null @@ -1,186 +0,0 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" - -from copy import deepcopy -from typing import Type - -import pytest - -from easyscience.Utils.io.dict import DataDictSerializer -from easyscience.Utils.io.dict import DictSerializer -from easyscience.Objects.variable import DescriptorNumber -from easyscience.Objects.ObjectClasses import BaseObj - -from .test_core import check_dict -from .test_core import dp_param_dict -from .test_core import skip_dict -from easyscience import global_object - - -def recursive_remove(d, remove_keys: list) -> dict: - """ - Remove keys from a dictionary. - """ - if not isinstance(remove_keys, list): - remove_keys = [remove_keys] - if isinstance(d, dict): - dd = {} - for k in d.keys(): - if k not in remove_keys: - dd[k] = recursive_remove(d[k], remove_keys) - return dd - else: - return d - - -######################################################################################################################## -# TESTING ENCODING -######################################################################################################################## -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - dp_kwargs = deepcopy(dp_kwargs) - - if isinstance(skip, str): - del dp_kwargs[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=DictSerializer) - - expected_keys = set(dp_kwargs.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(dp_kwargs, enc) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc_d = obj.encode(skip=skip, encoder=DataDictSerializer) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) - - -@pytest.mark.parametrize( - "encoder", [None, DataDictSerializer], ids=["Default", "DataDictSerializer"] -) -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_encode_data(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip, encoder): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc_d = obj.encode_data(skip=skip, encoder=encoder) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) - - -######################################################################################################################## -# TESTING DECODING -######################################################################################################################## -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=DictSerializer) - global_object.map._clear() - dec = dp_cls.decode(enc, decoder=DictSerializer) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer_from_dict(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=DictSerializer) - global_object.map._clear() - dec = dp_cls.from_dict(enc) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=DataDictSerializer) - with pytest.raises(NotImplementedError): - dec = obj.decode(enc, decoder=DataDictSerializer) - - -def test_group_encode(): - d0 = DescriptorNumber("a", 0) - d1 = DescriptorNumber("b", 1) - - from easyscience.Objects.Groups import BaseCollection - - b = BaseCollection("test", d0, d1) - d = b.as_dict() - assert isinstance(d["data"], list) - - -def test_group_encode2(): - d0 = DescriptorNumber("a", 0) - d1 = DescriptorNumber("b", 1) - - from easyscience.Objects.Groups import BaseCollection - - b = BaseObj("outer", b=BaseCollection("test", d0, d1)) - d = b.as_dict() - assert isinstance(d["b"], dict) \ No newline at end of file diff --git a/tests/unit_tests/utils/io_tests/test_json.py b/tests/unit_tests/utils/io_tests/test_json.py deleted file mode 100644 index 54f9ccb9..00000000 --- a/tests/unit_tests/utils/io_tests/test_json.py +++ /dev/null @@ -1,125 +0,0 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" - -import json -from copy import deepcopy -from typing import Type - -import pytest - -from easyscience.Utils.io.json import JsonDataSerializer -from easyscience.Utils.io.json import JsonSerializer -from easyscience.Objects.variable import DescriptorNumber - -from .test_core import check_dict -from .test_core import dp_param_dict -from .test_core import skip_dict -from easyscience import global_object - - -def recursive_remove(d, remove_keys: list) -> dict: - """ - Remove keys from a dictionary. - """ - if not isinstance(remove_keys, list): - remove_keys = [remove_keys] - if isinstance(d, dict): - dd = {} - for k in d.keys(): - if k not in remove_keys: - dd[k] = recursive_remove(d[k], remove_keys) - return dd - else: - return d - - -######################################################################################################################## -# TESTING ENCODING -######################################################################################################################## -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - dp_kwargs = deepcopy(dp_kwargs) - - if isinstance(skip, str): - del dp_kwargs[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=JsonSerializer) - assert isinstance(enc, str) - - # We can test like this as we don't have "complex" objects yet - dec = json.loads(enc) - expected_keys = set(dp_kwargs.keys()) - obtained_keys = set(dec.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(dp_kwargs, dec) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - if isinstance(skip, str): - del data_dict[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=JsonDataSerializer) - assert isinstance(enc, str) - enc_d = json.loads(enc) - - expected_keys = set(data_dict.keys()) - obtained_keys = set(enc_d.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(data_dict, enc_d) - -# ######################################################################################################################## -# # TESTING DECODING -# ######################################################################################################################## -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=JsonSerializer) - global_object.map._clear() - assert isinstance(enc, str) - dec = obj.decode(enc, decoder=JsonSerializer) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_DataDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=JsonDataSerializer) - global_object.map._clear() - with pytest.raises(NotImplementedError): - dec = obj.decode(enc, decoder=JsonDataSerializer) diff --git a/tests/unit_tests/utils/io_tests/test_xml.py b/tests/unit_tests/utils/io_tests/test_xml.py deleted file mode 100644 index b382bf89..00000000 --- a/tests/unit_tests/utils/io_tests/test_xml.py +++ /dev/null @@ -1,113 +0,0 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.0.1" - -import sys -import xml.etree.ElementTree as ET -from copy import deepcopy -from typing import Type - -import pytest - -from easyscience.Utils.io.xml import XMLSerializer -from easyscience.Objects.variable import DescriptorNumber - -from .test_core import dp_param_dict -from .test_core import skip_dict -from easyscience import global_object - -def recursive_remove(d, remove_keys: list) -> dict: - """ - Remove keys from a dictionary. - """ - if not isinstance(remove_keys, list): - remove_keys = [remove_keys] - if isinstance(d, dict): - dd = {} - for k in d.keys(): - if k not in remove_keys: - dd[k] = recursive_remove(d[k], remove_keys) - return dd - else: - return d - - -def recursive_test(testing_obj, reference_obj): - for i, (k, v) in enumerate(testing_obj.items()): - if isinstance(v, dict): - recursive_test(v, reference_obj[i]) - else: - assert v == XMLSerializer.string_to_variable(reference_obj[i].text) - - -######################################################################################################################## -# TESTING ENCODING -######################################################################################################################## -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_XMLDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - dp_kwargs = deepcopy(dp_kwargs) - - if isinstance(skip, str): - del dp_kwargs[skip] - - if not isinstance(skip, list): - skip = [skip] - - enc = obj.encode(skip=skip, encoder=XMLSerializer) - ref_encode = obj.encode(skip=skip) - assert isinstance(enc, str) - data_xml = ET.XML(enc) - assert data_xml.tag == "data" - recursive_test(data_xml, ref_encode) - -# ######################################################################################################################## -# # TESTING DECODING -# ######################################################################################################################## -@pytest.mark.parametrize(**dp_param_dict) -def test_variable_XMLDictSerializer_decode(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - obj = dp_cls(**data_dict) - - enc = obj.encode(encoder=XMLSerializer) - assert isinstance(enc, str) - data_xml = ET.XML(enc) - assert data_xml.tag == "data" - global_object.map._clear() - dec = dp_cls.decode(enc, decoder=XMLSerializer) - - for k in data_dict.keys(): - if hasattr(obj, k) and hasattr(dec, k): - assert getattr(obj, k) == getattr(dec, k) - else: - raise AttributeError(f"{k} not found in decoded object") - - -def test_slow_encode(): - - if sys.version_info < (3, 9): - pytest.skip("This test is only for python 3.9+") - - a = {"a": [1, 2, 3]} - slow_xml = XMLSerializer().encode(a, fast=False) - reference = """ - 1 - 2 - 3 -""" - assert slow_xml == reference - - -def test_include_header(): - - if sys.version_info < (3, 9): - pytest.skip("This test is only for python 3.9+") - - a = {"a": [1, 2, 3]} - header_xml = XMLSerializer().encode(a, use_header=True) - reference = '?xml version="1.0" encoding="UTF-8"?\n\n 1\n 2\n 3\n' - assert header_xml == reference diff --git a/tests/unit_tests/variable/__init__.py b/tests/unit_tests/variable/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/Objects/variable/test_descriptor_any_type.py b/tests/unit_tests/variable/test_descriptor_any_type.py similarity index 79% rename from tests/unit_tests/Objects/variable/test_descriptor_any_type.py rename to tests/unit_tests/variable/test_descriptor_any_type.py index 5b8a131b..70c4cc65 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_any_type.py +++ b/tests/unit_tests/variable/test_descriptor_any_type.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from easyscience.Objects.variable.descriptor_any_type import DescriptorAnyType +from easyscience.variable import DescriptorAnyType from easyscience import global_object class TestDescriptorAnyType: @@ -75,18 +75,4 @@ def test_copy(self, descriptor: DescriptorAnyType): # Expect assert type(descriptor_copy) == DescriptorAnyType - assert descriptor_copy._value == descriptor._value - - def test_as_data_dict(self, clear, descriptor: DescriptorAnyType): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": "string", - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorAnyType_0" - } \ No newline at end of file + assert descriptor_copy._value == descriptor._value \ No newline at end of file diff --git a/tests/unit_tests/Objects/variable/test_descriptor_array.py b/tests/unit_tests/variable/test_descriptor_array.py similarity index 97% rename from tests/unit_tests/Objects/variable/test_descriptor_array.py rename to tests/unit_tests/variable/test_descriptor_array.py index 2708f4ed..695f5fe2 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_array.py +++ b/tests/unit_tests/variable/test_descriptor_array.py @@ -6,8 +6,8 @@ import numpy as np -from easyscience.Objects.variable.descriptor_array import DescriptorArray -from easyscience.Objects.variable.descriptor_number import DescriptorNumber +from easyscience.variable import DescriptorArray +from easyscience import DescriptorNumber from easyscience import global_object class TestDescriptorArray: @@ -218,32 +218,6 @@ def test_copy(self, descriptor: DescriptorArray): assert type(descriptor_copy) == DescriptorArray assert np.array_equal(descriptor_copy._array.values, descriptor._array.values) assert descriptor_copy._array.unit == descriptor._array.unit - - def test_as_data_dict(self, clear, descriptor: DescriptorArray): - # When - descriptor_dict = descriptor.as_data_dict() - - # Expected dictionary - expected_dict = { - "name": "name", - "value": np.array([[1.0, 2.0], [3.0, 4.0]]), # Use numpy array for comparison - "unit": "m", - "variance": np.array([[0.1, 0.2], [0.3, 0.4]]), # Use numpy array for comparison - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorArray_0", - "dimensions": np.array(['dim0', 'dim1']), # Use numpy array for comparison - } - - # Then: Compare dictionaries key by key - for key, expected_value in expected_dict.items(): - if isinstance(expected_value, np.ndarray): - # Compare numpy arrays - assert np.array_equal(descriptor_dict[key], expected_value), f"Mismatch for key: {key}" - else: - # Compare other values directly - assert descriptor_dict[key] == expected_value, f"Mismatch for key: {key}" @pytest.mark.parametrize("unit_string, expected", [ ("1e+9", "dimensionless"), diff --git a/tests/unit_tests/Objects/variable/test_descriptor_base.py b/tests/unit_tests/variable/test_descriptor_base.py similarity index 93% rename from tests/unit_tests/Objects/variable/test_descriptor_base.py rename to tests/unit_tests/variable/test_descriptor_base.py index 38140a34..aeeb823e 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_base.py +++ b/tests/unit_tests/variable/test_descriptor_base.py @@ -1,7 +1,7 @@ import pytest from easyscience import global_object -from easyscience.Objects.variable.descriptor_base import DescriptorBase +from easyscience.variable.descriptor_base import DescriptorBase class TestDesciptorBase: @@ -140,19 +140,6 @@ def test_copy(self, descriptor: DescriptorBase): assert descriptor_copy._url == descriptor._url assert descriptor_copy._display_name == descriptor._display_name - def test_as_data_dict(self, clear, descriptor: DescriptorBase): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorBase_0", - } - def test_unique_name_generator(self, clear, descriptor: DescriptorBase): # When second_descriptor = DescriptorBase(name="test", unique_name="DescriptorBase_2") diff --git a/tests/unit_tests/Objects/variable/test_descriptor_bool.py b/tests/unit_tests/variable/test_descriptor_bool.py similarity index 81% rename from tests/unit_tests/Objects/variable/test_descriptor_bool.py rename to tests/unit_tests/variable/test_descriptor_bool.py index 63bf484a..4be20657 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_bool.py +++ b/tests/unit_tests/variable/test_descriptor_bool.py @@ -1,6 +1,6 @@ import pytest -from easyscience.Objects.variable.descriptor_bool import DescriptorBool +from easyscience.variable import DescriptorBool from easyscience import global_object class TestDescriptorBool: @@ -74,18 +74,4 @@ def test_copy(self, descriptor: DescriptorBool): # Expect assert type(descriptor_copy) == DescriptorBool - assert descriptor_copy._bool_value == descriptor._bool_value - - def test_as_data_dict(self, clear, descriptor: DescriptorBool): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": True, - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorBool_0" - } \ No newline at end of file + assert descriptor_copy._bool_value == descriptor._bool_value \ No newline at end of file diff --git a/tests/unit_tests/Objects/variable/test_descriptor_number.py b/tests/unit_tests/variable/test_descriptor_number.py similarity index 96% rename from tests/unit_tests/Objects/variable/test_descriptor_number.py rename to tests/unit_tests/variable/test_descriptor_number.py index b9de196a..5dc4e060 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_number.py +++ b/tests/unit_tests/variable/test_descriptor_number.py @@ -3,7 +3,7 @@ import scipp as sc from scipp import UnitError -from easyscience.Objects.variable.descriptor_number import DescriptorNumber +from easyscience import DescriptorNumber from easyscience import global_object class TestDescriptorNumber: @@ -200,22 +200,6 @@ def test_copy(self, descriptor: DescriptorNumber): assert descriptor_copy._scalar.value == descriptor._scalar.value assert descriptor_copy._scalar.unit == descriptor._scalar.unit - def test_as_data_dict(self, clear, descriptor: DescriptorNumber): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": 1.0, - "unit": "m", - "variance": 0.1, - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorNumber_0", - } - @pytest.mark.parametrize("unit_string, expected", [ ("1e+9", "dimensionless"), ("1000", "dimensionless"), diff --git a/tests/unit_tests/Objects/variable/test_descriptor_str.py b/tests/unit_tests/variable/test_descriptor_str.py similarity index 79% rename from tests/unit_tests/Objects/variable/test_descriptor_str.py rename to tests/unit_tests/variable/test_descriptor_str.py index 71c50715..aa593ed9 100644 --- a/tests/unit_tests/Objects/variable/test_descriptor_str.py +++ b/tests/unit_tests/variable/test_descriptor_str.py @@ -1,6 +1,6 @@ import pytest -from easyscience.Objects.variable.descriptor_str import DescriptorStr +from easyscience.variable import DescriptorStr from easyscience import global_object class TestDescriptorStr: @@ -73,18 +73,4 @@ def test_copy(self, descriptor: DescriptorStr): # Expect assert type(descriptor_copy) == DescriptorStr - assert descriptor_copy._string == descriptor._string - - def test_as_data_dict(self, clear, descriptor: DescriptorStr): - # When Then - descriptor_dict = descriptor.as_data_dict() - - # Expect - assert descriptor_dict == { - "name": "name", - "value": "string", - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "DescriptorStr_0" - } \ No newline at end of file + assert descriptor_copy._string == descriptor._string \ No newline at end of file diff --git a/tests/unit_tests/Objects/variable/test_parameter.py b/tests/unit_tests/variable/test_parameter.py similarity index 95% rename from tests/unit_tests/Objects/variable/test_parameter.py rename to tests/unit_tests/variable/test_parameter.py index 8418a3da..356ed83d 100644 --- a/tests/unit_tests/Objects/variable/test_parameter.py +++ b/tests/unit_tests/variable/test_parameter.py @@ -5,10 +5,10 @@ from scipp import UnitError -from easyscience.Objects.variable.parameter import Parameter -from easyscience.Objects.variable.descriptor_number import DescriptorNumber +from easyscience import Parameter +from easyscience import DescriptorNumber from easyscience import global_object -from easyscience.Objects.ObjectClasses import BaseObj +from easyscience import ObjBase class TestParameter: @pytest.fixture @@ -214,7 +214,7 @@ def test_process_dependency_unique_names_exception_unique_name_does_not_exist(se def test_process_dependency_unique_names_exception_not_a_descriptorNumber(self, clear, normal_parameter: Parameter): # When normal_parameter._dependency_map = {} - base_obj = BaseObj(name='BaseObj', unique_name='base_obj') + base_obj = ObjBase(name='ObjBase', unique_name='base_obj') # Then Expect with pytest.raises(ValueError, match='The object with unique_name base_obj is not a Parameter or DescriptorNumber. Please check your dependency expression.'): @@ -224,8 +224,8 @@ def test_process_dependency_unique_names_exception_not_a_descriptorNumber(self, (2, {'a': Parameter(name='a', value=1)}), ('2*a', ['a', Parameter(name='a', value=1)]), ('2*a', {4: Parameter(name='a', value=1)}), - ('2*a', {'a': BaseObj(name='a')}), - ], ids=["dependecy_expression_not_a_string", "dependency_map_not_a_dict", "dependency_map_keys_not_strings", "dependency_map_values_not_descriptor_number"]) + ('2*a', {'a': ObjBase(name='a')}), + ], ids=["dependency_expression_not_a_string", "dependency_map_not_a_dict", "dependency_map_keys_not_strings", "dependency_map_values_not_descriptor_number"]) def test_parameter_from_dependency_input_exceptions(self, dependency_expression, dependency_map): # When Then Expect with pytest.raises(TypeError): @@ -239,15 +239,46 @@ def test_parameter_from_dependency_input_exceptions(self, dependency_expression, ('2*a + b', NameError), ('2*a + 3*', SyntaxError), ('2 + 2', TypeError), - ], ids=["parameter_not_in_map", "invalid_dependency_expression", "result_not_a_descriptor_number"]) - def test_parameter_from_dependency_evaluation_exceptions(self, normal_parameter, dependency_expression, error): - # When Then Expect + ('2*"special_name"', ValueError), + ], ids=["parameter_not_in_map", "invalid_dependency_expression", "result_not_a_descriptor_number", "unique_name_does_not_exist"]) + def test_parameter_make_dependent_on_exceptions_cleanup_previously_dependent(self, normal_parameter, dependency_expression, error): + # When + independent_parameter = Parameter(name='independent', value=10, unit='s', variance=0.02) + dependent_parameter = Parameter.from_dependency( + name= 'dependent', + dependency_expression='best', + dependency_map={'best': independent_parameter} + ) + # Then Expect + # Check that the correct error is raised with pytest.raises(error): - Parameter.from_dependency( - name = 'dependent', + dependent_parameter.make_dependent_on( dependency_expression=dependency_expression, dependency_map={'a': normal_parameter}, - ) + ) + # Check that everything is properly cleaned up + assert normal_parameter._observers == [] + assert dependent_parameter.independent == False + assert dependent_parameter.dependency_expression == 'best' + assert dependent_parameter.dependency_map == {'best': independent_parameter} + independent_parameter.value = 50 + self.compare_parameters(dependent_parameter, independent_parameter) + + def test_parameter_make_dependent_on_exceptions_cleanup_previously_independent(self, normal_parameter): + # When + independent_parameter = Parameter(name='independent', value=10, unit='s', variance=0.02) + # Then Expect + # Check that the correct error is raised + with pytest.raises(NameError): + independent_parameter.make_dependent_on( + dependency_expression='2*a + b', + dependency_map={'a': normal_parameter}, + ) + # Check that everything is properly cleaned up + assert normal_parameter._observers == [] + assert independent_parameter.independent == True + normal_parameter.value = 50 + assert independent_parameter.value == 10 def test_dependent_parameter_updates(self, normal_parameter: Parameter): # When @@ -317,6 +348,12 @@ def test_dependent_parameter_cyclic_dependencies(self, normal_parameter: Paramet # Then Expect with pytest.raises(RuntimeError): normal_parameter.make_dependent_on(dependency_expression='2*c', dependency_map={'c': dependent_parameter_2}) + # Check that everything is properly cleaned up + assert dependent_parameter_2._observers == [] + assert normal_parameter.independent == True + assert normal_parameter.value == 1 + normal_parameter.value = 50 + self.compare_parameters(dependent_parameter_2, 4*normal_parameter) def test_dependent_parameter_logical_dependency(self, normal_parameter: Parameter): # When @@ -663,26 +700,6 @@ def test_copy(self, parameter: Parameter): assert parameter_copy._display_name == parameter._display_name assert parameter_copy._independent == parameter._independent - def test_as_data_dict(self, clear, parameter: Parameter): - # When Then - self.mock_callback.fget.return_value = 1.0 # Ensure fget returns a scalar value - parameter_dict = parameter.as_data_dict() - - # Expect - assert parameter_dict == { - "name": "name", - "value": 1.0, - "unit": "m", - "variance": 0.01, - "min": 0, - "max": 10, - "fixed": False, - "description": "description", - "url": "url", - "display_name": "display_name", - "unique_name": "Parameter_0", - } - @pytest.mark.parametrize("test, expected, expected_reverse", [ (Parameter("test", 2, "m", 0.01, -10, 20), Parameter("name + test", 3, "m", 0.02, -10, 30), Parameter("test + name", 3, "m", 0.02, -10, 30)), (Parameter("test", 2, "m", 0.01), Parameter("name + test", 3, "m", 0.02, min=-np.inf, max=np.inf),Parameter("test + name", 3, "m", 0.02, min=-np.inf, max=np.inf)),