diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 181909161e..447345d563 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -29,6 +29,7 @@ from firedrake.utils import ScalarType, assert_empty, tuplify from pyop2 import op2 from pyop2.exceptions import MapValueError, SparsityFormatError +from pyop2.types.mat import _GlobalMatPayload, _DatMatPayload from pyop2.utils import cached_property @@ -965,22 +966,24 @@ def assemble(self, tensor=None): Result of assembly: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms. """ - self._check_tensor(tensor) - if tensor is None: - tensor = self.allocate() - needs_zeroing = False - else: - needs_zeroing = self._needs_zeroing if annotate_tape(): raise NotImplementedError( "Taping with explicit FormAssembler objects is not supported yet. " "Use assemble instead." ) - if needs_zeroing: - type(self)._as_pyop2_type(tensor).zero() + + if tensor is None: + tensor = self.allocate() + else: + self._check_tensor(tensor) + if self._needs_zeroing: + self._as_pyop2_type(tensor).zero() + self.execute_parloops(tensor) + for bc in self._bcs: self._apply_bc(tensor, bc) + return self.result(tensor) @abc.abstractmethod @@ -992,9 +995,9 @@ def _check_tensor(self, tensor): """Check input tensor.""" @staticmethod - def _as_pyop2_type(tensor): - """Return tensor as pyop2 type.""" - raise NotImplementedError + @abc.abstractmethod + def _as_pyop2_type(tensor, indices=None): + """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" def execute_parloops(self, tensor): for parloop in self.parloops(tensor): @@ -1003,20 +1006,14 @@ def execute_parloops(self, tensor): def parloops(self, tensor): if hasattr(self, "_parloops"): for (lknl, _), parloop in zip(self.local_kernels, self._parloops): - data = _FormHandler.index_tensor(tensor, self._form, lknl.indices, self.diagonal) + data = self._as_pyop2_type(tensor, lknl.indices) parloop.arguments[0].data = data + else: # Make parloops for one concrete output tensor and cache them. - # TODO: Make parloops only with some symbolic information of the output tensor. - self._parloops = tuple(parloop_builder.build(tensor) for parloop_builder in self.parloop_builders) - return self._parloops - - @cached_property - def parloop_builders(self): - out = [] - for local_kernel, subdomain_id in self.local_kernels: - out.append( - ParloopBuilder( + parloops_ = [] + for local_kernel, subdomain_id in self.local_kernels: + parloop_builder = ParloopBuilder( self._form, self._bcs, local_kernel, @@ -1024,8 +1021,12 @@ def parloop_builders(self): self.all_integer_subdomain_ids, diagonal=self.diagonal, ) - ) - return tuple(out) + pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices) + parloop = parloop_builder.build(pyop2_tensor) + parloops_.append(parloop) + self._parloops = tuple(parloops_) + + return self._parloops @cached_property def local_kernels(self): @@ -1120,10 +1121,11 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - assert tensor is None + pass @staticmethod - def _as_pyop2_type(tensor): + def _as_pyop2_type(tensor, indices=None): + assert not indices return tensor def result(self, tensor): @@ -1198,15 +1200,16 @@ def _apply_dirichlet_bc(self, tensor, bc): bc.zero(tensor) def _check_tensor(self, tensor): - rank = len(self._form.arguments()) - if rank == 1: - test, = self._form.arguments() - if tensor is not None and test.function_space() != tensor.function_space(): - raise ValueError("Form's argument does not match provided result tensor") + if tensor.function_space() != self._form.arguments()[0].function_space(): + raise ValueError("Form's argument does not match provided result tensor") @staticmethod - def _as_pyop2_type(tensor): - return tensor.dat + def _as_pyop2_type(tensor, indices=None): + if indices is not None and any(index is not None for index in indices): + i, = indices + return tensor.dat[i] + else: + return tensor.dat def execute_parloops(self, tensor): # We are repeatedly incrementing into the same Dat so intermediate halo exchanges @@ -1454,12 +1457,26 @@ def _apply_bcs_mat_real_block(op2tensor, i, j, component, node_set): dat.zero(subset=node_set) def _check_tensor(self, tensor): - if tensor is not None and tensor.a.arguments() != self._form.arguments(): + if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") @staticmethod - def _as_pyop2_type(tensor): - return tensor.M + def _as_pyop2_type(tensor, indices=None): + if indices is not None and any(index is not None for index in indices): + i, j = indices + mat = tensor.M[i, j] + else: + mat = tensor.M + + if mat.handle.getType() == "python": + mat_context = mat.handle.getPythonContext() + if isinstance(mat_context, _GlobalMatPayload): + mat = mat_context.global_ + else: + assert isinstance(mat_context, _DatMatPayload) + mat = mat_context.dat + + return mat def result(self, tensor): tensor.M.assemble() @@ -1471,7 +1488,7 @@ class MatrixFreeAssembler(FormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 2-form. Notes @@ -1498,14 +1515,15 @@ def allocate(self): appctx=self._appctx or {}) def assemble(self, tensor=None): - self._check_tensor(tensor) if tensor is None: tensor = self.allocate() + else: + self._check_tensor(tensor) tensor.assemble() return tensor def _check_tensor(self, tensor): - if tensor is not None and tensor.a.arguments() != self._form.arguments(): + if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") @@ -1820,12 +1838,12 @@ def __init__(self, form, bcs, local_knl, subdomain_id, self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo) self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) - def build(self, tensor): + def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop: """Construct the parloop. Parameters ---------- - tensor : op2.Global or firedrake.cofunction.Cofunction or matrix.MatrixBase + tensor : The output tensor. """ @@ -1909,17 +1927,28 @@ def collect_lgmaps(self): :param local_knl: A :class:`tsfc_interface.SplitKernel`. :param bcs: Iterable of boundary conditions. """ + if len(self._form.arguments()) == 2 and not self._diagonal: if not self._bcs: return None - lgmaps = [] - for i, j in self.get_indicess(): + + if any(i is not None for i in self._local_knl.indices): + i, j = self._local_knl.indices row_bcs, col_bcs = self._filter_bcs(i, j) - rlgmap, clgmap = self._tensor.M[i, j].local_to_global_maps + # the tensor is already indexed + rlgmap, clgmap = self._tensor.local_to_global_maps rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) - lgmaps.append((rlgmap, clgmap)) - return tuple(lgmaps) + return ((rlgmap, clgmap),) + else: + lgmaps = [] + for i, j in self.get_indicess(): + row_bcs, col_bcs = self._filter_bcs(i, j) + rlgmap, clgmap = self._tensor[i, j].local_to_global_maps + rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) + clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) + lgmaps.append((rlgmap, clgmap)) + return tuple(lgmaps) else: return None @@ -1939,10 +1968,6 @@ def _integral_type(self): def _indexed_function_spaces(self): return _FormHandler.index_function_spaces(self._form, self._indices) - @property - def _indexed_tensor(self): - return _FormHandler.index_tensor(self._tensor, self._form, self._indices, self._diagonal) - @cached_property def _mesh(self): return self._form.ufl_domains()[self._kinfo.domain_number] @@ -1990,28 +2015,27 @@ def _as_parloop_arg(tsfc_arg, self): @_as_parloop_arg.register(kernel_args.OutputKernelArg) def _as_parloop_arg_output(_, self): rank = len(self._form.arguments()) - tensor = self._indexed_tensor Vs = self._indexed_function_spaces if rank == 0: - return op2.GlobalParloopArg(tensor) + return op2.GlobalParloopArg(self._tensor) elif rank == 1 or rank == 2 and self._diagonal: V, = Vs if V.ufl_element().family() == "Real": - return op2.GlobalParloopArg(tensor) + return op2.GlobalParloopArg(self._tensor) else: - return op2.DatParloopArg(tensor, self._get_map(V)) + return op2.DatParloopArg(self._tensor, self._get_map(V)) elif rank == 2: rmap, cmap = [self._get_map(V) for V in Vs] if all(V.ufl_element().family() == "Real" for V in Vs): assert rmap is None and cmap is None - return op2.GlobalParloopArg(tensor.handle.getPythonContext().global_) + return op2.GlobalParloopArg(self._tensor) elif any(V.ufl_element().family() == "Real" for V in Vs): m = rmap or cmap - return op2.DatParloopArg(tensor.handle.getPythonContext().dat, m) + return op2.DatParloopArg(self._tensor, m) else: - return op2.MatParloopArg(tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) + return op2.MatParloopArg(self._tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) else: raise AssertionError @@ -2122,22 +2146,3 @@ def index_function_spaces(form, indices): return tuple(a.ufl_function_space()[i] for i, a in zip(indices, form.arguments())) else: raise AssertionError - - @staticmethod - def index_tensor(tensor, form, indices, diagonal): - """Return the PyOP2 data structure tied to ``tensor``, indexed - if necessary. - """ - rank = len(form.arguments()) - is_indexed = any(i is not None for i in indices) - - if rank == 0: - return tensor - elif rank == 1 or rank == 2 and diagonal: - i, = indices - return tensor.dat[i] if is_indexed else tensor.dat - elif rank == 2: - i, j = indices - return tensor.M[i, j] if is_indexed else tensor.M - else: - raise AssertionError diff --git a/tests/regression/test_assemble.py b/tests/regression/test_assemble.py index a80b46d5f0..9ee0e1d9e7 100644 --- a/tests/regression/test_assemble.py +++ b/tests/regression/test_assemble.py @@ -1,6 +1,7 @@ import pytest import numpy as np from firedrake import * +from firedrake.assemble import TwoFormAssembler from firedrake.utils import ScalarType, IntType @@ -125,6 +126,23 @@ def test_assemble_mat_with_tensor(mesh): assert np.allclose(M.M.values, 2*assemble(a).M.values, rtol=1e-14) +@pytest.mark.skipcomplex +def test_mat_nest_real_block_assembler_correctly_reuses_tensor(mesh): + V = FunctionSpace(mesh, "CG", 1) + R = FunctionSpace(mesh, "R", 0) + W = V * R + + u = TrialFunction(W) + v = TestFunction(W) + a = inner(v, u) * dx + + assembler = TwoFormAssembler(a, mat_type="nest") + A1 = assembler.assemble() + A2 = assembler.assemble(tensor=A1) + + assert A2.M is A1.M + + def test_assemble_diagonal(mesh): V = FunctionSpace(mesh, "P", 3) u = TrialFunction(V)