Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
6719eb2
evaluate for physics condition
ndem0 Mar 19, 2026
a873e28
input target evaluate
ndem0 Mar 19, 2026
1450330
first new generation solver
ndem0 Mar 19, 2026
9dac226
pinn and supervised
ndem0 Mar 20, 2026
b9a094c
autoregressive
ndem0 Mar 20, 2026
a6c1b0e
multi model
ndem0 Apr 2, 2026
24fec4f
ensemble solver
ndem0 Apr 16, 2026
9b549f0
labelize_forward on ensemble
ndem0 Apr 16, 2026
37f0755
Clean inheritance
ndem0 May 7, 2026
05f3059
Fix conflicts
ndem0 May 8, 2026
3d4c6e5
fix import
ndem0 May 11, 2026
1da0445
minor
ndem0 May 12, 2026
a10da4d
remove print statements
GiovanniCanali May 12, 2026
103d747
delete tests/data_manager.py (duplicate)
GiovanniCanali May 12, 2026
64835e6
fix loss-related imports
GiovanniCanali May 12, 2026
c5325a9
use BaseProblem instead of AbstractProblem
GiovanniCanali May 12, 2026
74ccfda
emptied __init__ files in _src
GiovanniCanali May 12, 2026
f84fb29
fix optimizers and schedulers type
GiovanniCanali May 12, 2026
a7ab130
delete __init__ from interface
GiovanniCanali May 12, 2026
b185859
fix import in test_solver
GiovanniCanali May 12, 2026
2079405
pylint
GiovanniCanali May 12, 2026
2adb38b
fix bugs
ndem0 May 12, 2026
14695ea
fix test_input_target_condition
GiovanniCanali May 12, 2026
7677059
fix conditions and tests
GiovanniCanali May 14, 2026
00116b8
minor bug fixing
GiovanniCanali May 14, 2026
78fb2a6
fix autoregressive solver
GiovanniCanali May 15, 2026
0f0d7e9
fix autoregressive tests + small bugs
GiovanniCanali May 15, 2026
2c396a5
remove aliasing
GiovanniCanali May 18, 2026
d1166b6
fix bug in equation conditions
GiovanniCanali May 18, 2026
4519b36
enable gradient via decorators
GiovanniCanali May 19, 2026
2fc1018
simplify reduction logic
GiovanniCanali May 19, 2026
b16e8f5
fix supervised ensemble
GiovanniCanali May 19, 2026
35f66a7
temporarily disable old solver tests
GiovanniCanali May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pina/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
"Condition",
"PinaDataModule",
"Graph",
"SolverInterface",
"MultiSolverInterface",
]

from pina._src.core.label_tensor import LabelTensor
from pina._src.core.graph import Graph
from pina._src.solver.solver import SolverInterface, MultiSolverInterface
from pina._src.core.trainer import Trainer
from pina._src.condition.condition import Condition
from pina._src.data.data_module import PinaDataModule
6 changes: 3 additions & 3 deletions pina/_src/callback/refinement/r3_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.utils import check_consistency
from pina._src.loss.loss_interface import LossInterface
from pina._src.loss.loss_interface import DualLossInterface


class R3Refinement(RefinementInterface):
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
:param int sample_every: The sampling frequency.
:param loss: The loss function to compute the residuals.
Default is :class:`~torch.nn.L1Loss`.
:type loss: LossInterface | :class:`~torch.nn.modules.loss._Loss`
:type loss: DualLossInterface | :class:`~torch.nn.modules.loss._Loss`
:param condition_to_update: The conditions to update during the
refinement process. If None, all conditions will be updated.
Default is None.
Expand All @@ -59,7 +59,7 @@ def __init__(
# Check consistency
check_consistency(
residual_loss,
(LossInterface, torch.nn.modules.loss._Loss),
(DualLossInterface, torch.nn.modules.loss._Loss),
subclass=True,
)

Expand Down
16 changes: 7 additions & 9 deletions pina/_src/callback/refinement/refinement_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from abc import ABCMeta, abstractmethod
from lightning.pytorch import Callback
from pina._src.core.utils import check_consistency
from pina._src.solver.physics_informed_solver.pinn_interface import (
PINNInterface,
)
from pina._src.solver.pinn import PINN


class RefinementInterface(Callback, metaclass=ABCMeta):
Expand Down Expand Up @@ -52,7 +50,7 @@ def on_train_start(self, trainer, solver):
object.
:param ~pina.solver.solver.SolverInterface solver: The solver
object associated with the trainer.
:raises RuntimeError: If the solver is not a PINNInterface.
:raises RuntimeError: If the solver is not a PINN.
:raises RuntimeError: If the conditions do not have a domain to sample
from.
"""
Expand All @@ -76,11 +74,11 @@ def on_train_start(self, trainer, solver):
"sample from."
)
# check solver
if not isinstance(solver, PINNInterface):
if not isinstance(solver, PINN):
raise RuntimeError(
"Refinment strategies are currently implemented only "
"for physics informed based solvers. Please use a Solver "
"inheriting from 'PINNInterface'."
"inheriting from 'PINN'."
)
# store dataset
self._dataset = trainer.datamodule.train_dataset
Expand All @@ -95,7 +93,7 @@ def on_train_epoch_end(self, trainer, solver):
Performs the refinement at the end of each training epoch (if needed).

:param ~lightning.pytorch.trainer.trainer.Trainer: The trainer object.
:param PINNInterface solver: The solver object.
:param PINN solver: The solver object.
"""
if (trainer.current_epoch % self.sample_every == 0) and (
trainer.current_epoch != 0
Expand All @@ -110,7 +108,7 @@ def sample(self, current_points, condition_name, solver):

:param current_points: Current points in the domain.
:param condition_name: Name of the condition to update.
:param PINNInterface solver: The solver object.
:param PINN solver: The solver object.
:return: New points sampled based on the R3 strategy.
:rtype: LabelTensor
"""
Expand All @@ -133,7 +131,7 @@ def _update_points(self, solver):
"""
Performs the refinement of the points.

:param PINNInterface solver: The solver object.
:param PINN solver: The solver object.
"""
new_points = {}
for name in self._condition_to_update:
Expand Down
97 changes: 58 additions & 39 deletions pina/_src/condition/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.condition.data_condition import DataCondition
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
Expand Down Expand Up @@ -45,20 +46,29 @@ class Condition:
represents a general physics-informed condition defined by ``input``
points and an ``equation``. The model learns to minimize the equation
residual through evaluations performed at the provided ``input``.
Supported data types for the ``input`` include
:class:`~pina.label_tensor.LabelTensor` or :class:`~pina.graph.Graph`.
Supported data types for the ``input`` include :class:`~pina.graph.Graph`
or :class:`~pina.label_tensor.LabelTensor`. The class automatically
selects the appropriate implementation based on the types of the
``input``.

- :class:`~pina.condition.time_series_condition.TimeSeriesCondition`:
represents a condition designed for time series data, where the model is
trained to capture temporal dependencies and dynamics. It is defined by an
``input`` tensor of shape ``[trajectories, time_steps, *features]``
containing time series data. Supported data types for the ``input``
include class:`~pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
The class automatically selects the appropriate implementation based on
the types of the ``input``.
the type of the ``input``.

- :class:`~pina.condition.data_condition.DataCondition`: represents an
unsupervised, data-driven condition defined by the ``input`` only.
The model is trained using a custom unsupervised loss determined by the
chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging the
provided data during training. Optional ``conditional_variables`` can be
specified when the model depends on additional parameters.
Supported data types include :class:`torch.Tensor`,
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
:class:`~torch_geometric.data.Data`. The class automatically selects the
Supported data types include :class:`~pina.label_tensor.LabelTensor`,
:class:`torch.Tensor`, :class:`~torch_geometric.data.Data`, or
:class:`~pina.graph.Graph`. The class automatically selects the
appropriate implementation based on the type of the ``input``.

.. note::
Expand All @@ -80,20 +90,32 @@ class Condition:
>>> # Example of InputEquationCondition signature
>>> condition = Condition(input=input, equation=equation)

>>> # Example of TimeSeriesCondition signature
>>> condition = Condition(
... input=input, n_windows=n_windows, unroll_length=unroll_length
... )

>>> # Example of DataCondition signature
>>> condition = Condition(input=data, conditional_variables=cond_vars)
"""

# Combine all possible keyword arguments from the different Condition types
available_kwargs = list(
set(
InputTargetCondition.__fields__
+ InputEquationCondition.__fields__
+ DomainEquationCondition.__fields__
+ DataCondition.__fields__
)
# Internal specifications for condition types, used for dispatching
# Each tuple contains: (condition class, required kwargs, optional kwargs)
_SPECS = (
(InputTargetCondition, {"input", "target"}, set()),
(InputEquationCondition, {"input", "equation"}, set()),
(DomainEquationCondition, {"domain", "equation"}, set()),
(DataCondition, {"input"}, {"conditional_variables"}),
(
TimeSeriesCondition,
{"input", "n_windows", "unroll_length"},
{"randomize"},
),
)

# Compute the set of all available keyword arguments (optional + required)
available_kwargs = sorted(set().union(*(rq | op for _, rq, op in _SPECS)))

def __new__(cls, *args, **kwargs):
"""
Instantiate the appropriate :class:`Condition` object based on the
Expand All @@ -103,38 +125,35 @@ def __new__(cls, *args, **kwargs):
:param dict kwargs: The keyword arguments corresponding to the
parameters of the specific :class:`Condition` type to instantiate.
:raises ValueError: If unexpected positional arguments are provided.
:raises ValueError: If the keyword arguments are invalid.
:raises ValueError: If the keyword arguments do not match any valid
signature for the available condition types.
:return: The appropriate :class:`Condition` object.
:rtype: ConditionInterface
"""
# Check keyword arguments
if len(args) != 0:
# Ensure no positional arguments are provided
if args:
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.available_kwargs}."
"Condition takes only keyword arguments. "
f"Available arguments are: {cls.available_kwargs}."
)

# Class specialization based on keyword arguments
sorted_keys = sorted(kwargs.keys())

# Input - Target Condition
if sorted_keys == sorted(InputTargetCondition.__fields__):
return InputTargetCondition(**kwargs)
# Iterate through the specifications to find a matching condition type
for condition_cls, required, optional in cls._SPECS:

# Input - Equation Condition
if sorted_keys == sorted(InputEquationCondition.__fields__):
return InputEquationCondition(**kwargs)
# Find allowed keys for condition type
allowed = required | optional

# Domain - Equation Condition
if sorted_keys == sorted(DomainEquationCondition.__fields__):
return DomainEquationCondition(**kwargs)
# Check if the provided keys match the required and optional keys
if required <= set(kwargs) <= allowed:
return condition_cls(**kwargs)

# Data Condition
if (
sorted_keys == sorted(DataCondition.__fields__)
or sorted_keys[0] == DataCondition.__fields__[0]
):
return DataCondition(**kwargs)
# If no valid signature is found, prepare a list of valid signatures
valid_signatures = [
sorted(required | optional) for _, required, optional in cls._SPECS
]

# Invalid keyword arguments
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# If no valid signature is found, raise an error
raise ValueError(
f"Invalid keyword arguments {sorted(set(kwargs))}. "
f"Valid signatures are: {valid_signatures}."
)
25 changes: 25 additions & 0 deletions pina/_src/condition/condition_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,31 @@ def create_dataloader(
:rtype: torch.utils.data.DataLoader
"""

@abstractmethod
def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.
This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.
The returned tensor is not reduced, preserving the per-sample residual
values.
:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:return: The non-aggregated residual tensor.
:rtype: torch.Tensor | LabelTensor
"""

@abstractmethod
def switch_dataloader_fn(self, create_dataloader_fn):
"""
Expand Down
26 changes: 26 additions & 0 deletions pina/_src/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,32 @@ def store_data(self, **kwargs):

return _DataManager(**data_dict)

def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.

The returned tensor is not reduced, preserving the per-sample residual
values.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:return: The non-aggregated residual tensor.
:rtype: torch.Tensor | LabelTensor
"""
output_ = solver.forward(batch["input"])
return loss(output_, torch.zeros_like(output_))

@property
def conditional_variables(self):
"""
Expand Down
32 changes: 32 additions & 0 deletions pina/_src/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,38 @@ def store_data(self, **kwargs):
setattr(self, "domain", kwargs.get("domain"))
setattr(self, "equation", kwargs.get("equation"))

def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.

This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.

The returned tensor is not reduced, preserving the per-sample residual
values.

:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:raises NotImplementedError: Always raised since any domain-equation
condition is transformed into an input-equation condition before
evaluation, and the residual is computed using the input-equation
condition's evaluation method.
"""
raise NotImplementedError(
"Domain-equation conditions are transformed into input-equation "
"conditions before evaluation, and the residual is computed using "
"the input-equation condition's evaluation method. Therefore, the "
"evaluate method is not implemented for domain-equation conditions."
)

@property
def equation(self):
"""
Expand Down
32 changes: 32 additions & 0 deletions pina/_src/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module for the Input-Equation Condition class."""

import torch
from pina._src.condition.base_condition import BaseCondition
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.graph import Graph
Expand Down Expand Up @@ -107,3 +108,34 @@ def equation(self, value):
# Check consistency
check_consistency(value, self._avail_equation_cls)
self._equation = value

def evaluate(self, batch, solver, loss):
"""
Evaluate the residual of the condition on the given batch using the
solver.
This method computes the non-aggregated, element-wise residual of the
condition. A forward pass of the solver's model is performed on the
input samples, and the condition residual is evaluated accordingly.
The returned tensor is not reduced, preserving the per-sample residual
values.
:param dict batch: The batch containing the data required by the
condition evaluation.
:param SolverInterface solver: The solver used to perform the forward
pass and compute the residual. The solver provides access to the
model and its parameters, which may be necessary for evaluating the
condition residual.
:param torch.nn.Module loss: The non-aggregating loss function used to
compare the condition residual against its reference value.
:return: The non-aggregated residual tensor.
:rtype: LabelTensor
"""
# Compute residuals
samples = batch["input"].requires_grad_(True)
residual = self.equation.residual(
samples, solver.forward(samples), solver._params
)

return loss(residual, torch.zeros_like(residual))
Loading
Loading