From acdac39b8ef6240f8e613d4f375f20906ab828c9 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Thu, 31 Oct 2024 16:27:33 +0000 Subject: [PATCH 01/10] Fix assembly of Real matrices --- firedrake/assemble.py | 130 ++++++++++++++-------------- tests/regression/test_real_space.py | 16 ++++ 2 files changed, 83 insertions(+), 63 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 181909161e..2ed8d9915f 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -29,6 +29,7 @@ from firedrake.utils import ScalarType, assert_empty, tuplify from pyop2 import op2 from pyop2.exceptions import MapValueError, SparsityFormatError +from pyop2.types.mat import _GlobalMatPayload, _DatMatPayload from pyop2.utils import cached_property @@ -965,22 +966,24 @@ def assemble(self, tensor=None): Result of assembly: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms. """ - self._check_tensor(tensor) - if tensor is None: - tensor = self.allocate() - needs_zeroing = False - else: - needs_zeroing = self._needs_zeroing if annotate_tape(): raise NotImplementedError( "Taping with explicit FormAssembler objects is not supported yet. " "Use assemble instead." ) - if needs_zeroing: - type(self)._as_pyop2_type(tensor).zero() + + if tensor is None: + tensor = self.allocate() + else: + self._check_tensor(tensor) + if self._needs_zeroing: + _as_pyop2_tensor(tensor).zero() + self.execute_parloops(tensor) + for bc in self._bcs: self._apply_bc(tensor, bc) + return self.result(tensor) @abc.abstractmethod @@ -991,11 +994,6 @@ def _apply_bc(self, tensor, bc): def _check_tensor(self, tensor): """Check input tensor.""" - @staticmethod - def _as_pyop2_type(tensor): - """Return tensor as pyop2 type.""" - raise NotImplementedError - def execute_parloops(self, tensor): for parloop in self.parloops(tensor): parloop() @@ -1003,12 +1001,14 @@ def execute_parloops(self, tensor): def parloops(self, tensor): if hasattr(self, "_parloops"): for (lknl, _), parloop in zip(self.local_kernels, self._parloops): - data = _FormHandler.index_tensor(tensor, self._form, lknl.indices, self.diagonal) + data = _as_parloop_type(tensor, lknl.indices) parloop.arguments[0].data = data else: # Make parloops for one concrete output tensor and cache them. - # TODO: Make parloops only with some symbolic information of the output tensor. - self._parloops = tuple(parloop_builder.build(tensor) for parloop_builder in self.parloop_builders) + self._parloops = tuple( + parloop_builder.build(tensor) + for parloop_builder in self.parloop_builders + ) return self._parloops @cached_property @@ -1120,11 +1120,7 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - assert tensor is None - - @staticmethod - def _as_pyop2_type(tensor): - return tensor + pass def result(self, tensor): return tensor.data[0] @@ -1198,15 +1194,8 @@ def _apply_dirichlet_bc(self, tensor, bc): bc.zero(tensor) def _check_tensor(self, tensor): - rank = len(self._form.arguments()) - if rank == 1: - test, = self._form.arguments() - if tensor is not None and test.function_space() != tensor.function_space(): - raise ValueError("Form's argument does not match provided result tensor") - - @staticmethod - def _as_pyop2_type(tensor): - return tensor.dat + if tensor.function_space() != self._form.arguments()[0].function_space(): + raise ValueError("Form's argument does not match provided result tensor") def execute_parloops(self, tensor): # We are repeatedly incrementing into the same Dat so intermediate halo exchanges @@ -1454,13 +1443,9 @@ def _apply_bcs_mat_real_block(op2tensor, i, j, component, node_set): dat.zero(subset=node_set) def _check_tensor(self, tensor): - if tensor is not None and tensor.a.arguments() != self._form.arguments(): + if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") - @staticmethod - def _as_pyop2_type(tensor): - return tensor.M - def result(self, tensor): tensor.M.assemble() return tensor @@ -1471,7 +1456,7 @@ class MatrixFreeAssembler(FormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 2-form. Notes @@ -1498,14 +1483,15 @@ def allocate(self): appctx=self._appctx or {}) def assemble(self, tensor=None): - self._check_tensor(tensor) if tensor is None: tensor = self.allocate() + else: + self._check_tensor(tensor) tensor.assemble() return tensor def _check_tensor(self, tensor): - if tensor is not None and tensor.a.arguments() != self._form.arguments(): + if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") @@ -1939,10 +1925,6 @@ def _integral_type(self): def _indexed_function_spaces(self): return _FormHandler.index_function_spaces(self._form, self._indices) - @property - def _indexed_tensor(self): - return _FormHandler.index_tensor(self._tensor, self._form, self._indices, self._diagonal) - @cached_property def _mesh(self): return self._form.ufl_domains()[self._kinfo.domain_number] @@ -1990,28 +1972,28 @@ def _as_parloop_arg(tsfc_arg, self): @_as_parloop_arg.register(kernel_args.OutputKernelArg) def _as_parloop_arg_output(_, self): rank = len(self._form.arguments()) - tensor = self._indexed_tensor + pyop2_tensor = _as_pyop2_tensor(self._tensor, self._indices) Vs = self._indexed_function_spaces if rank == 0: - return op2.GlobalParloopArg(tensor) + return op2.GlobalParloopArg(pyop2_tensor) elif rank == 1 or rank == 2 and self._diagonal: V, = Vs if V.ufl_element().family() == "Real": - return op2.GlobalParloopArg(tensor) + return op2.GlobalParloopArg(pyop2_tensor) else: - return op2.DatParloopArg(tensor, self._get_map(V)) + return op2.DatParloopArg(pyop2_tensor, self._get_map(V)) elif rank == 2: rmap, cmap = [self._get_map(V) for V in Vs] if all(V.ufl_element().family() == "Real" for V in Vs): assert rmap is None and cmap is None - return op2.GlobalParloopArg(tensor.handle.getPythonContext().global_) + return op2.GlobalParloopArg(pyop2_tensor.handle.getPythonContext().global_) elif any(V.ufl_element().family() == "Real" for V in Vs): m = rmap or cmap - return op2.DatParloopArg(tensor.handle.getPythonContext().dat, m) + return op2.DatParloopArg(pyop2_tensor.handle.getPythonContext().dat, m) else: - return op2.MatParloopArg(tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) + return op2.MatParloopArg(pyop2_tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) else: raise AssertionError @@ -2123,21 +2105,43 @@ def index_function_spaces(form, indices): else: raise AssertionError - @staticmethod - def index_tensor(tensor, form, indices, diagonal): - """Return the PyOP2 data structure tied to ``tensor``, indexed - if necessary. - """ - rank = len(form.arguments()) - is_indexed = any(i is not None for i in indices) - if rank == 0: - return tensor - elif rank == 1 or rank == 2 and diagonal: +def _as_pyop2_tensor(tensor, indices=None): + """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" + if isinstance(tensor, op2.Global): + assert not indices + return tensor + + if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): + if indices: i, = indices - return tensor.dat[i] if is_indexed else tensor.dat - elif rank == 2: + return tensor.dat[i] + else: + return tensor.dat + else: + assert isinstance(tensor, firedrake.Matrix) + if indices: i, j = indices - return tensor.M[i, j] if is_indexed else tensor.M + return tensor.M[i, j] else: - raise AssertionError + return tensor.M + + +def _as_parloop_type(tensor, indices): + """Cast a Firedrake tensor into a PyOP2 data structure suitable for a parloop. + + This function differs from `_as_pyop2_tensor` in that matrices with ``"python"`` + type are unpacked into their underlying `op2.Global` or `op2.Dat`. + + """ + pyop2_tensor = _as_pyop2_tensor(tensor, indices) + + if isinstance(pyop2_tensor, op2.Mat) and pyop2_tensor.handle.getType() == "python": + mat_context = pyop2_tensor.handle.getPythonContext() + if isinstance(mat_context, _GlobalMatPayload): + pyop2_tensor = mat_context.global_ + else: + assert isinstance(mat_context, _DatMatPayload) + pyop2_tensor = mat_context.dat + + return pyop2_tensor diff --git a/tests/regression/test_real_space.py b/tests/regression/test_real_space.py index a8ee58760e..49ef58a882 100644 --- a/tests/regression/test_real_space.py +++ b/tests/regression/test_real_space.py @@ -258,6 +258,22 @@ def poisson(resolution): assert ln(poisson(50)/poisson(100))/ln(2) > 1.99 +def test_real_space_nonlinear_solve(): + M = UnitIntervalMesh(5) + + V = FunctionSpace(M, "CG", 1) + R = FunctionSpace(M, "R", 0) + Z = V * R + + func = Function(Z) + u, l = split(func) + v, w = TestFunctions(Z) + F = (u + u**2 - 8)*v*dx + l*w*dx + + solve(F == 0, func, solver_parameters={"mat_type": "nest"}) + # TODO: Need some accuracy test? + + @pytest.mark.skipcomplex def test_real_space_eq(): mesh = UnitIntervalMesh(4) From 8b086d39dbba4184754c6ef7096c90c2f5342e25 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Thu, 31 Oct 2024 16:28:58 +0000 Subject: [PATCH 02/10] Add skipcomplex --- tests/regression/test_real_space.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/regression/test_real_space.py b/tests/regression/test_real_space.py index 49ef58a882..e7b5fb0261 100644 --- a/tests/regression/test_real_space.py +++ b/tests/regression/test_real_space.py @@ -258,6 +258,7 @@ def poisson(resolution): assert ln(poisson(50)/poisson(100))/ln(2) > 1.99 +@pytest.mark.skipcomplex def test_real_space_nonlinear_solve(): M = UnitIntervalMesh(5) From cfc9fcf75b3a16bfa03da5db9315ca2a0d55a235 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 1 Nov 2024 10:37:23 +0000 Subject: [PATCH 03/10] Fix for None indices --- firedrake/assemble.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 2ed8d9915f..2468493e4e 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -2108,19 +2108,21 @@ def index_function_spaces(form, indices): def _as_pyop2_tensor(tensor, indices=None): """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" + is_indexed = indices and any(index is not None for index in indices) + if isinstance(tensor, op2.Global): - assert not indices + assert not is_indexed return tensor if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): - if indices: + if is_indexed: i, = indices return tensor.dat[i] else: return tensor.dat else: assert isinstance(tensor, firedrake.Matrix) - if indices: + if is_indexed: i, j = indices return tensor.M[i, j] else: From 77cf77fbfe763a1264d7fbff8e31b210b4820236 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 1 Nov 2024 11:14:15 +0000 Subject: [PATCH 04/10] Use a better test --- tests/regression/test_assemble.py | 22 ++++++++++++++++++++++ tests/regression/test_real_space.py | 17 ----------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/regression/test_assemble.py b/tests/regression/test_assemble.py index a80b46d5f0..8051e1a40b 100644 --- a/tests/regression/test_assemble.py +++ b/tests/regression/test_assemble.py @@ -1,6 +1,7 @@ import pytest import numpy as np from firedrake import * +from firedrake.assemble import TwoFormAssembler from firedrake.utils import ScalarType, IntType @@ -125,6 +126,27 @@ def test_assemble_mat_with_tensor(mesh): assert np.allclose(M.M.values, 2*assemble(a).M.values, rtol=1e-14) +@pytest.mark.parametrize("space", ["CG", "CGxR"]) +def test_assembler_reuse_respects_tensor(mesh, space): + if space == "CG": + W = FunctionSpace(mesh, "CG", 1) + else: + assert space == "CGxR" + V = FunctionSpace(mesh, "CG", 1) + R = FunctionSpace(mesh, "R", 0) + W = V * R + + u = TrialFunction(W) + v = TestFunction(W) + a = inner(v, u) * dx + + assembler = TwoFormAssembler(a) + A1 = assembler.assemble() + A2 = assembler.assemble(tensor=A1) + + assert A2.M is A1.M + + def test_assemble_diagonal(mesh): V = FunctionSpace(mesh, "P", 3) u = TrialFunction(V) diff --git a/tests/regression/test_real_space.py b/tests/regression/test_real_space.py index e7b5fb0261..a8ee58760e 100644 --- a/tests/regression/test_real_space.py +++ b/tests/regression/test_real_space.py @@ -258,23 +258,6 @@ def poisson(resolution): assert ln(poisson(50)/poisson(100))/ln(2) > 1.99 -@pytest.mark.skipcomplex -def test_real_space_nonlinear_solve(): - M = UnitIntervalMesh(5) - - V = FunctionSpace(M, "CG", 1) - R = FunctionSpace(M, "R", 0) - Z = V * R - - func = Function(Z) - u, l = split(func) - v, w = TestFunctions(Z) - F = (u + u**2 - 8)*v*dx + l*w*dx - - solve(F == 0, func, solver_parameters={"mat_type": "nest"}) - # TODO: Need some accuracy test? - - @pytest.mark.skipcomplex def test_real_space_eq(): mesh = UnitIntervalMesh(4) From 71ba96debb0f2e48da5d8d24237754d03069818c Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 1 Nov 2024 15:08:09 +0000 Subject: [PATCH 05/10] Update test --- tests/regression/test_assemble.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/regression/test_assemble.py b/tests/regression/test_assemble.py index 8051e1a40b..9ee0e1d9e7 100644 --- a/tests/regression/test_assemble.py +++ b/tests/regression/test_assemble.py @@ -126,21 +126,17 @@ def test_assemble_mat_with_tensor(mesh): assert np.allclose(M.M.values, 2*assemble(a).M.values, rtol=1e-14) -@pytest.mark.parametrize("space", ["CG", "CGxR"]) -def test_assembler_reuse_respects_tensor(mesh, space): - if space == "CG": - W = FunctionSpace(mesh, "CG", 1) - else: - assert space == "CGxR" - V = FunctionSpace(mesh, "CG", 1) - R = FunctionSpace(mesh, "R", 0) - W = V * R +@pytest.mark.skipcomplex +def test_mat_nest_real_block_assembler_correctly_reuses_tensor(mesh): + V = FunctionSpace(mesh, "CG", 1) + R = FunctionSpace(mesh, "R", 0) + W = V * R u = TrialFunction(W) v = TestFunction(W) a = inner(v, u) * dx - assembler = TwoFormAssembler(a) + assembler = TwoFormAssembler(a, mat_type="nest") A1 = assembler.assemble() A2 = assembler.assemble(tensor=A1) From 34a29ed0d95844c503a3034b9ed0cd5771968c82 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Mon, 11 Nov 2024 11:07:31 +0000 Subject: [PATCH 06/10] WIP --- firedrake/assemble.py | 126 ++++++++++++++++++++---------------------- 1 file changed, 60 insertions(+), 66 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 2468493e4e..1b9dfb6e57 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -977,7 +977,7 @@ def assemble(self, tensor=None): else: self._check_tensor(tensor) if self._needs_zeroing: - _as_pyop2_tensor(tensor).zero() + self._as_parloop_type(tensor).zero() self.execute_parloops(tensor) @@ -1001,22 +1001,14 @@ def execute_parloops(self, tensor): def parloops(self, tensor): if hasattr(self, "_parloops"): for (lknl, _), parloop in zip(self.local_kernels, self._parloops): - data = _as_parloop_type(tensor, lknl.indices) + data = self._as_parloop_type(tensor, lknl.indices) parloop.arguments[0].data = data + else: # Make parloops for one concrete output tensor and cache them. - self._parloops = tuple( - parloop_builder.build(tensor) - for parloop_builder in self.parloop_builders - ) - return self._parloops - - @cached_property - def parloop_builders(self): - out = [] - for local_kernel, subdomain_id in self.local_kernels: - out.append( - ParloopBuilder( + parloops_ = [] + for local_kernel, subdomain_id in self.local_kernels: + parloop_builder = ParloopBuilder( self._form, self._bcs, local_kernel, @@ -1024,8 +1016,12 @@ def parloop_builders(self): self.all_integer_subdomain_ids, diagonal=self.diagonal, ) - ) - return tuple(out) + pyop2_tensor = self._as_parloop_tensor(tensor, local_kernel.indices) + parloop = parloop_builder.build(pyop2_tensor) + parloops_.append(parloop) + self._parloops = tuple(parloops_) + + return self._parloops @cached_property def local_kernels(self): @@ -1079,6 +1075,11 @@ def all_integer_subdomain_ids(self): def result(self, tensor): """The result of the assembly operation.""" + @abc.abstractmethod + @classmethod + def _as_pyop2_tensor(cls, tensor, indices=None): + """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" + class ZeroFormAssembler(ParloopFormAssembler): """Class for assembling a 0-form. @@ -1125,6 +1126,11 @@ def _check_tensor(self, tensor): def result(self, tensor): return tensor.data[0] + @classmethod + def _as_pyop2_tensor(cls, tensor, indices=None): + assert indices is None + return tensor + class OneFormAssembler(ParloopFormAssembler): """Class for assembling a 1-form. @@ -1211,6 +1217,14 @@ def diagonal(self): def result(self, tensor): return tensor + @classmethod + def _as_pyop2_tensor(cls, tensor, indices=None): + if indices is not None and any(index is not None for index in indices): + i, = indices + return tensor.dat[i] + else: + return tensor.dat + def TwoFormAssembler(form, *args, **kwargs): assert isinstance(form, (ufl.form.Form, slate.TensorBase)) @@ -1450,6 +1464,24 @@ def result(self, tensor): tensor.M.assemble() return tensor + @classmethod + def _as_pyop2_tensor(cls, tensor, indices=None): + if indices is not None and any(index is not None for index in indices): + i, j = indices + mat = tensor.M[i, j] + else: + mat = tensor.M + + if mat.handle.getType() == "python": + mat_context = mat.handle.getPythonContext() + if isinstance(mat_context, _GlobalMatPayload): + mat = mat_context.global_ + else: + assert isinstance(mat_context, _DatMatPayload) + mat = mat_context.dat + + return mat + class MatrixFreeAssembler(FormAssembler): """Stub class wrapping matrix-free assembly. @@ -1494,6 +1526,10 @@ def _check_tensor(self, tensor): if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") + @classmethod + def _as_pyop2_tensor(cls, tensor, indices=None): + assert False, " dont think this is needed" + def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomain_ids, **kwargs): # N.B. Generating the global kernel is not a collective operation so the @@ -1806,12 +1842,12 @@ def __init__(self, form, bcs, local_knl, subdomain_id, self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo) self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) - def build(self, tensor): + def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop: """Construct the parloop. Parameters ---------- - tensor : op2.Global or firedrake.cofunction.Cofunction or matrix.MatrixBase + tensor : The output tensor. """ @@ -1972,28 +2008,27 @@ def _as_parloop_arg(tsfc_arg, self): @_as_parloop_arg.register(kernel_args.OutputKernelArg) def _as_parloop_arg_output(_, self): rank = len(self._form.arguments()) - pyop2_tensor = _as_pyop2_tensor(self._tensor, self._indices) Vs = self._indexed_function_spaces if rank == 0: - return op2.GlobalParloopArg(pyop2_tensor) + return op2.GlobalParloopArg(self._tensor) elif rank == 1 or rank == 2 and self._diagonal: V, = Vs if V.ufl_element().family() == "Real": - return op2.GlobalParloopArg(pyop2_tensor) + return op2.GlobalParloopArg(self._tensor) else: - return op2.DatParloopArg(pyop2_tensor, self._get_map(V)) + return op2.DatParloopArg(self._tensor, self._get_map(V)) elif rank == 2: rmap, cmap = [self._get_map(V) for V in Vs] if all(V.ufl_element().family() == "Real" for V in Vs): assert rmap is None and cmap is None - return op2.GlobalParloopArg(pyop2_tensor.handle.getPythonContext().global_) + return op2.GlobalParloopArg(self._tensor.handle.getPythonContext().global_) elif any(V.ufl_element().family() == "Real" for V in Vs): m = rmap or cmap - return op2.DatParloopArg(pyop2_tensor.handle.getPythonContext().dat, m) + return op2.DatParloopArg(self._tensor.handle.getPythonContext().dat, m) else: - return op2.MatParloopArg(pyop2_tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) + return op2.MatParloopArg(self._tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) else: raise AssertionError @@ -2106,44 +2141,3 @@ def index_function_spaces(form, indices): raise AssertionError -def _as_pyop2_tensor(tensor, indices=None): - """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" - is_indexed = indices and any(index is not None for index in indices) - - if isinstance(tensor, op2.Global): - assert not is_indexed - return tensor - - if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): - if is_indexed: - i, = indices - return tensor.dat[i] - else: - return tensor.dat - else: - assert isinstance(tensor, firedrake.Matrix) - if is_indexed: - i, j = indices - return tensor.M[i, j] - else: - return tensor.M - - -def _as_parloop_type(tensor, indices): - """Cast a Firedrake tensor into a PyOP2 data structure suitable for a parloop. - - This function differs from `_as_pyop2_tensor` in that matrices with ``"python"`` - type are unpacked into their underlying `op2.Global` or `op2.Dat`. - - """ - pyop2_tensor = _as_pyop2_tensor(tensor, indices) - - if isinstance(pyop2_tensor, op2.Mat) and pyop2_tensor.handle.getType() == "python": - mat_context = pyop2_tensor.handle.getPythonContext() - if isinstance(mat_context, _GlobalMatPayload): - pyop2_tensor = mat_context.global_ - else: - assert isinstance(mat_context, _DatMatPayload) - pyop2_tensor = mat_context.dat - - return pyop2_tensor From c3f6ec834a78b64161a095e50551bbddf49235df Mon Sep 17 00:00:00 2001 From: Connor Date: Mon, 11 Nov 2024 14:33:23 +0000 Subject: [PATCH 07/10] Think this is now right --- firedrake/assemble.py | 64 +++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 1b9dfb6e57..a070e36a73 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -977,7 +977,7 @@ def assemble(self, tensor=None): else: self._check_tensor(tensor) if self._needs_zeroing: - self._as_parloop_type(tensor).zero() + self._as_pyop2_type(tensor).zero() self.execute_parloops(tensor) @@ -994,6 +994,11 @@ def _apply_bc(self, tensor, bc): def _check_tensor(self, tensor): """Check input tensor.""" + @staticmethod + @abc.abstractmethod + def _as_pyop2_type(tensor, indices=None): + """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" + def execute_parloops(self, tensor): for parloop in self.parloops(tensor): parloop() @@ -1001,7 +1006,7 @@ def execute_parloops(self, tensor): def parloops(self, tensor): if hasattr(self, "_parloops"): for (lknl, _), parloop in zip(self.local_kernels, self._parloops): - data = self._as_parloop_type(tensor, lknl.indices) + data = self._as_pyop2_type(tensor, lknl.indices) parloop.arguments[0].data = data else: @@ -1016,7 +1021,7 @@ def parloops(self, tensor): self.all_integer_subdomain_ids, diagonal=self.diagonal, ) - pyop2_tensor = self._as_parloop_tensor(tensor, local_kernel.indices) + pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices) parloop = parloop_builder.build(pyop2_tensor) parloops_.append(parloop) self._parloops = tuple(parloops_) @@ -1075,11 +1080,6 @@ def all_integer_subdomain_ids(self): def result(self, tensor): """The result of the assembly operation.""" - @abc.abstractmethod - @classmethod - def _as_pyop2_tensor(cls, tensor, indices=None): - """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" - class ZeroFormAssembler(ParloopFormAssembler): """Class for assembling a 0-form. @@ -1123,14 +1123,14 @@ def _apply_bc(self, tensor, bc): def _check_tensor(self, tensor): pass + @staticmethod + def _as_pyop2_type(tensor, indices=None): + assert not indices + return tensor + def result(self, tensor): return tensor.data[0] - @classmethod - def _as_pyop2_tensor(cls, tensor, indices=None): - assert indices is None - return tensor - class OneFormAssembler(ParloopFormAssembler): """Class for assembling a 1-form. @@ -1203,6 +1203,14 @@ def _check_tensor(self, tensor): if tensor.function_space() != self._form.arguments()[0].function_space(): raise ValueError("Form's argument does not match provided result tensor") + @staticmethod + def _as_pyop2_type(tensor, indices=None): + if indices is not None and any(index is not None for index in indices): + i, = indices + return tensor.dat[i] + else: + return tensor.dat + def execute_parloops(self, tensor): # We are repeatedly incrementing into the same Dat so intermediate halo exchanges # can be skipped. @@ -1217,14 +1225,6 @@ def diagonal(self): def result(self, tensor): return tensor - @classmethod - def _as_pyop2_tensor(cls, tensor, indices=None): - if indices is not None and any(index is not None for index in indices): - i, = indices - return tensor.dat[i] - else: - return tensor.dat - def TwoFormAssembler(form, *args, **kwargs): assert isinstance(form, (ufl.form.Form, slate.TensorBase)) @@ -1460,12 +1460,8 @@ def _check_tensor(self, tensor): if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") - def result(self, tensor): - tensor.M.assemble() - return tensor - - @classmethod - def _as_pyop2_tensor(cls, tensor, indices=None): + @staticmethod + def _as_pyop2_type(tensor, indices=None): if indices is not None and any(index is not None for index in indices): i, j = indices mat = tensor.M[i, j] @@ -1482,6 +1478,10 @@ def _as_pyop2_tensor(cls, tensor, indices=None): return mat + def result(self, tensor): + tensor.M.assemble() + return tensor + class MatrixFreeAssembler(FormAssembler): """Stub class wrapping matrix-free assembly. @@ -1526,9 +1526,9 @@ def _check_tensor(self, tensor): if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") - @classmethod - def _as_pyop2_tensor(cls, tensor, indices=None): - assert False, " dont think this is needed" + @staticmethod + def _as_pyop2_type(tensor, indices=None): + raise AssertionError("Should not be called for matrix-free assembly") def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomain_ids, **kwargs): @@ -1937,7 +1937,7 @@ def collect_lgmaps(self): lgmaps = [] for i, j in self.get_indicess(): row_bcs, col_bcs = self._filter_bcs(i, j) - rlgmap, clgmap = self._tensor.M[i, j].local_to_global_maps + rlgmap, clgmap = self._tensor[i, j].local_to_global_maps rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) lgmaps.append((rlgmap, clgmap)) @@ -2139,5 +2139,3 @@ def index_function_spaces(form, indices): return tuple(a.ufl_function_space()[i] for i, a in zip(indices, form.arguments())) else: raise AssertionError - - From 88456fda159d2aca81b8def34984540b7794a5d4 Mon Sep 17 00:00:00 2001 From: Connor Date: Mon, 11 Nov 2024 16:57:43 +0000 Subject: [PATCH 08/10] fixup --- firedrake/assemble.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index a070e36a73..d22125f309 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1937,7 +1937,7 @@ def collect_lgmaps(self): lgmaps = [] for i, j in self.get_indicess(): row_bcs, col_bcs = self._filter_bcs(i, j) - rlgmap, clgmap = self._tensor[i, j].local_to_global_maps + rlgmap, clgmap = self._tensor.local_to_global_maps rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) lgmaps.append((rlgmap, clgmap)) @@ -2023,10 +2023,10 @@ def _as_parloop_arg_output(_, self): if all(V.ufl_element().family() == "Real" for V in Vs): assert rmap is None and cmap is None - return op2.GlobalParloopArg(self._tensor.handle.getPythonContext().global_) + return op2.GlobalParloopArg(self._tensor) elif any(V.ufl_element().family() == "Real" for V in Vs): m = rmap or cmap - return op2.DatParloopArg(self._tensor.handle.getPythonContext().dat, m) + return op2.DatParloopArg(self._tensor, m) else: return op2.MatParloopArg(self._tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) else: From fc68a6de66bcdf3b22c69f7223631d714b2fff27 Mon Sep 17 00:00:00 2001 From: Connor Date: Tue, 12 Nov 2024 09:49:35 +0000 Subject: [PATCH 09/10] fixup --- firedrake/assemble.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index d22125f309..6ca7cbb513 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1931,17 +1931,28 @@ def collect_lgmaps(self): :param local_knl: A :class:`tsfc_interface.SplitKernel`. :param bcs: Iterable of boundary conditions. """ + if len(self._form.arguments()) == 2 and not self._diagonal: if not self._bcs: return None - lgmaps = [] - for i, j in self.get_indicess(): + + if any(i is not None for i in self._local_knl.indices): + i, j = self._local_knl.indices row_bcs, col_bcs = self._filter_bcs(i, j) + # the tensor is already indexed rlgmap, clgmap = self._tensor.local_to_global_maps rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) - lgmaps.append((rlgmap, clgmap)) - return tuple(lgmaps) + return ((rlgmap, clgmap),) + else: + lgmaps = [] + for i, j in self.get_indicess(): + row_bcs, col_bcs = self._filter_bcs(i, j) + rlgmap, clgmap = self._tensor[i, j].local_to_global_maps + rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) + clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) + lgmaps.append((rlgmap, clgmap)) + return tuple(lgmaps) else: return None From 454a2dd464b5eb0853cd4f3bd99b0e00b29a9c5d Mon Sep 17 00:00:00 2001 From: Connor Date: Tue, 12 Nov 2024 14:13:44 +0000 Subject: [PATCH 10/10] fixup --- firedrake/assemble.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 6ca7cbb513..447345d563 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1526,10 +1526,6 @@ def _check_tensor(self, tensor): if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") - @staticmethod - def _as_pyop2_type(tensor, indices=None): - raise AssertionError("Should not be called for matrix-free assembly") - def _global_kernel_cache_key(form, local_knl, subdomain_id, all_integer_subdomain_ids, **kwargs): # N.B. Generating the global kernel is not a collective operation so the