Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions cunumeric/_ufunc/ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,6 @@ def reduce(
raise NotImplementedError(
f"reduction for {self} is not yet implemented"
)
if out is not None:
raise NotImplementedError(
"reduction for {self} does not take an `out` argument"
)
if not isinstance(where, bool) or not where:
raise NotImplementedError(
"the 'where' keyword is not yet supported"
Expand All @@ -682,7 +678,7 @@ def reduce(
array,
axis=axis,
dtype=dtype,
# out=out,
out=out,
keepdims=keepdims,
initial=initial,
where=where,
Expand Down
244 changes: 100 additions & 144 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,8 @@ def __contains__(self, item):
UnaryRedCode.CONTAINS,
self,
axis=None,
dtype=np.dtype(np.bool_),
res_dtype=bool,
args=args,
check_types=False,
)

def __copy__(self):
Expand Down Expand Up @@ -1501,10 +1500,9 @@ def all(
UnaryRedCode.ALL,
self,
axis=axis,
dst=out,
res_dtype=bool,
out=out,
keepdims=keepdims,
dtype=np.dtype(np.bool_),
check_types=False,
initial=initial,
where=where,
)
Expand Down Expand Up @@ -1537,15 +1535,15 @@ def any(
UnaryRedCode.ANY,
self,
axis=axis,
dst=out,
res_dtype=bool,
out=out,
keepdims=keepdims,
dtype=np.dtype(np.bool_),
check_types=False,
initial=initial,
where=where,
)

def argmax(self, axis=None, out=None):
@add_boilerplate()
def argmax(self, axis=None, out=None, keepdims=False):
"""a.argmax(axis=None, out=None)

Return indices of the maximum values along the given axis.
Expand All @@ -1561,24 +1559,21 @@ def argmax(self, axis=None, out=None):
Multiple GPUs, Multiple CPUs

"""
if self.size == 1:
return 0
if axis is None:
axis = self.ndim - 1
elif type(axis) != int:
raise TypeError("'axis' argument for argmax must be an 'int'")
elif axis < 0 or axis >= self.ndim:
raise TypeError("invalid 'axis' argument for argmax " + str(axis))
if out is not None and out.dtype != np.int64:
raise ValueError("output array must have int64 dtype")
if axis is not None and not isinstance(axis, int):
raise ValueError("axis must be an integer")
return self._perform_unary_reduction(
UnaryRedCode.ARGMAX,
self,
axis=axis,
dtype=np.dtype(np.int64),
dst=out,
check_types=False,
res_dtype=np.dtype(np.int64),
out=out,
keepdims=keepdims,
)

def argmin(self, axis=None, out=None):
@add_boilerplate()
def argmin(self, axis=None, out=None, keepdims=False):
"""a.argmin(axis=None, out=None)

Return indices of the minimum values along the given axis.
Expand All @@ -1594,21 +1589,17 @@ def argmin(self, axis=None, out=None):
Multiple GPUs, Multiple CPUs

"""
if self.size == 1:
return 0
if axis is None:
axis = self.ndim - 1
elif type(axis) != int:
raise TypeError("'axis' argument for argmin must be an 'int'")
elif axis < 0 or axis >= self.ndim:
raise TypeError("invalid 'axis' argument for argmin " + str(axis))
if out is not None and out.dtype != np.int64:
raise ValueError("output array must have int64 dtype")
if axis is not None and not isinstance(axis, int):
raise ValueError("axis must be an integer")
return self._perform_unary_reduction(
UnaryRedCode.ARGMIN,
self,
axis=axis,
dtype=np.dtype(np.int64),
dst=out,
check_types=False,
res_dtype=np.dtype(np.int64),
out=out,
keepdims=keepdims,
)

def astype(
Expand Down Expand Up @@ -2590,7 +2581,7 @@ def max(
UnaryRedCode.MAX,
self,
axis=axis,
dst=out,
out=out,
keepdims=keepdims,
initial=initial,
where=where,
Expand Down Expand Up @@ -2686,7 +2677,7 @@ def min(
UnaryRedCode.MIN,
self,
axis=axis,
dst=out,
out=out,
keepdims=keepdims,
initial=initial,
where=where,
Expand Down Expand Up @@ -2781,7 +2772,8 @@ def prod(
UnaryRedCode.PROD,
self_array,
axis=axis,
dst=out,
dtype=dtype,
out=out,
keepdims=keepdims,
initial=initial,
where=where,
Expand Down Expand Up @@ -3054,7 +3046,8 @@ def sum(
UnaryRedCode.SUM,
self_array,
axis=axis,
dst=out,
dtype=dtype,
out=out,
keepdims=keepdims,
initial=initial,
where=where,
Expand Down Expand Up @@ -3486,13 +3479,31 @@ def _perform_unary_reduction(
src,
axis=None,
dtype=None,
dst=None,
res_dtype=None,
out=None,
keepdims=False,
args=None,
check_types=True,
initial=None,
where=True,
):
# When 'res_dtype' is not None, the input and output of the reduction
# have different types. Such reduction operators don't take a dtype of
# the accumulator
if res_dtype is not None:
assert dtype is None
dtype = src.dtype
else:
# If 'dtype' exists, that determines both the accumulation dtype
# and the output dtype
if dtype is not None:
res_dtype = dtype
elif out is not None:
dtype = out.dtype
res_dtype = out.dtype
else:
dtype = src.dtype
res_dtype = src.dtype

# TODO: Need to require initial to be given when the array is empty
# or a where mask is given.
if isinstance(where, ndarray):
Expand All @@ -3516,121 +3527,66 @@ def _perform_unary_reduction(
"(arg)max/min not supported for complex-type arrays"
)
# Compute the output shape
if axis is not None:
to_reduce = set()
if type(axis) == int:
if axis < 0:
axis = len(src.shape) + axis
if axis < 0:
raise ValueError("Illegal 'axis' value")
elif axis >= src.ndim:
raise ValueError("Illegal 'axis' value")
to_reduce.add(axis)
axes = (axis,)
elif type(axis) == tuple:
for ax in axis:
if ax < 0:
ax = len(src.shape) + ax
if ax < 0:
raise ValueError("Illegal 'axis' value")
elif ax >= src.ndim:
raise ValueError("Illegal 'axis' value")
to_reduce.add(ax)
axes = axis
else:
raise TypeError(
"Illegal type passed for 'axis' argument "
+ str(type(axis))
)
out_shape = ()
for dim in range(len(src.shape)):
if dim in to_reduce:
if keepdims:
out_shape += (1,)
else:
out_shape += (src.shape[dim],)
else:
# Collapsing down to a single value in this case
out_shape = ()
axes = None
# if src.size == 0:
# return nd
if dst is None:
if dtype is not None:
dst = ndarray(
shape=out_shape,
dtype=dtype,
inputs=(src, where),
)
else:
dst = ndarray(
shape=out_shape,
dtype=src.dtype,
inputs=(src, where),
)
else:
if dtype is not None and dtype != dst.dtype:
raise TypeError(
"Output array type does not match requested dtype"
)
if dst.shape != out_shape:
raise TypeError(
"Output array shape "
+ str(dst.shape)
+ " does not match expected shape "
+ str(out_shape)
)
# Quick exit
if where is False:
return dst
if check_types and src.dtype != dst.dtype:
out_dtype = cls.find_common_type(src, dst)
if src.dtype != out_dtype:
temp = ndarray(
src.shape,
dtype=out_dtype,
inputs=(src, where),
)
temp._thunk.convert(src._thunk)
src = temp
if dst.dtype != out_dtype:
temp = ndarray(
dst.shape,
dtype=out_dtype,
inputs=(src, where),
)
axes = axis
if axes is None:
axes = tuple(range(src.ndim))
elif not isinstance(axes, tuple):
axes = (axes,)

temp._thunk.unary_reduction(
op,
src._thunk,
cls._get_where_thunk(where, dst.shape),
axes,
keepdims,
args,
initial,
)
dst._thunk.convert(temp._thunk)
else:
dst._thunk.unary_reduction(
op,
src._thunk,
cls._get_where_thunk(where, dst.shape),
axes,
keepdims,
args,
initial,
)
if any(type(ax) != int for ax in axes):
raise TypeError(
"'axis' must be an integer or a tuple of integers, "
f"but got {axis}"
)

axes = tuple(ax + src.ndim if ax < 0 else ax for ax in axes)

if any(ax < 0 for ax in axes):
raise ValueError(f"Invalid 'axis' value {axis}")

out_shape = ()
for dim in range(src.ndim):
if dim not in axes:
out_shape += (src.shape[dim],)
elif keepdims:
out_shape += (1,)

if out is None:
out = ndarray(
shape=out_shape, dtype=res_dtype, inputs=(src, where)
)
elif out.shape != out_shape:
raise ValueError(
f"the output shape mismatch: expected {out_shape} but got "
f"{out.shape}"
)

if dtype != src.dtype:
src = src.astype(dtype)

if out.dtype == res_dtype:
result = out
else:
dst._thunk.unary_reduction(
result = ndarray(
shape=out_shape, dtype=res_dtype, inputs=(src, where)
)

if where:
result._thunk.unary_reduction(
op,
src._thunk,
cls._get_where_thunk(where, dst.shape),
cls._get_where_thunk(where, result.shape),
axis,
axes,
keepdims,
args,
initial,
)
return dst

if result is not out:
out._thunk.convert(result._thunk)

return out

@classmethod
def _perform_binary_reduction(
Expand Down
Loading