Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 19 additions & 67 deletions pytential/symbolic/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,44 +100,6 @@ def __init__(self, bound_expr, actx: PyOpenCLArrayContext, context=None,

self.queue = actx.queue

# {{{ TODO: remove when device scalar broadcasting is fixed

# NOTE:
# * awaiting resolution of https://github.com/inducer/arraycontext/issues/49
# * these are only the operations required to pass tests

def _force_host_scalar(self, arg):
if isinstance(arg, cl.array.Array) and arg.shape == ():
return self.array_context.to_numpy(arg)[()]
elif isinstance(arg, np.ndarray) and arg.shape == ():
return arg[()]
else:
return arg

def _map_device_scalar_reduction(self, func, expr):
return func(
self._force_host_scalar(self.rec(child))
for child in expr.children)

def map_sum(self, expr):
return self._map_device_scalar_reduction(sum, expr)

def map_product(self, expr):
from pytools import product
return self._map_device_scalar_reduction(product, expr)

def _map_device_scalar_op(self, op, arg1, arg2):
return op(
self._force_host_scalar(self.rec(arg1)),
self._force_host_scalar(self.rec(arg2)))

def map_quotient(self, expr):
import operator
return self._map_device_scalar_op(
operator.truediv, expr.numerator, expr.denominator)

# }}}

# {{{ map_XXX

def _map_minmax(self, func, inherited_func, expr):
Expand All @@ -163,27 +125,21 @@ def map_min(self, expr):

def map_node_sum(self, expr):
actx = self.array_context
result = sum(actx.np.sum(grp_ary) for grp_ary in self.rec(expr.operand))
if not actx._force_device_scalars:
result = actx.to_numpy(result)[()]

return result
return sum(actx.np.sum(grp_ary) for grp_ary in self.rec(expr.operand))

def map_node_max(self, expr):
from functools import reduce
actx = self.array_context
result = max(actx.np.max(grp_ary) for grp_ary in self.rec(expr.operand))
if not actx._force_device_scalars:
result = actx.to_numpy(result)[()]

return result
return reduce(
actx.np.maximum,
(actx.np.max(grp_ary) for grp_ary in self.rec(expr.operand)))

def map_node_min(self, expr):
from functools import reduce
actx = self.array_context
result = min(actx.np.min(grp_ary) for grp_ary in self.rec(expr.operand))
if not actx._force_device_scalars:
result = actx.to_numpy(result)[()]

return result
return reduce(
actx.np.minimum,
(actx.np.min(grp_ary) for grp_ary in self.rec(expr.operand)))

def _map_elementwise_reduction(self, reduction_name, expr):
import loopy as lp
Expand Down Expand Up @@ -544,23 +500,19 @@ def matvec(self, x):
# => output is a flat PyOpenCL array
# * structured arrays (object arrays/DOFArrays)
# => output has same structure as input
if isinstance(x, np.ndarray) and x.dtype.char != "O":
x = self.array_context.from_numpy(x)
flat = True
host = True
assert x.shape == (self.total_dofs,)
if isinstance(x, DOFArray):
flat, host = False, False
elif isinstance(x, np.ndarray) and x.dtype.char == "O":
flat, host = False, False
elif isinstance(x, cl.array.Array):
flat = True
host = False
flat, host = True, False
assert x.shape == (self.total_dofs,)
elif isinstance(x, np.ndarray) and x.dtype.char != "O":
x = self.array_context.from_numpy(x)
flat, host = True, True
assert x.shape == (self.total_dofs,)
elif isinstance(x, np.ndarray) and x.dtype.char == "O":
flat = False
host = False
elif isinstance(x, DOFArray):
flat = False
host = False
else:
raise ValueError("unsupported input type")
raise ValueError(f"unsupported input type: {type(x).__name__}")
Comment thread
inducer marked this conversation as resolved.

args = self.extra_args.copy()
args[self.arg_name] = self.unflatten(x) if flat else x
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def write_git_revision(package_name):
install_requires=[
"pytools>=2018.2",
"modepy>=2013.3",
"pyopencl>=2013.1",
"pyopencl>=2021.2.6",
"boxtree>=2019.1",
"pymbolic>=2013.2",
"loopy>=2020.2",
Expand Down