Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/dpnp_algo_mathematical.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ and the rest of the library
# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file

__all__ += [
"dpnp_cumprod",
"dpnp_ediff1d",
"dpnp_fabs",
"dpnp_fmod",
Expand Down
11 changes: 10 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,16 @@ def copy(self, order="C"):
return dpnp.copy(self, order=order)

# 'ctypes',
# 'cumprod',

def cumprod(self, axis=None, dtype=None, out=None):
"""
Return the cumulative product of the elements along the given axis.

Refer to :obj:`dpnp.cumprod` for full documentation.

"""

return dpnp.cumprod(self, axis=axis, dtype=dtype, out=out)

def cumsum(self, axis=None, dtype=None, out=None):
"""
Expand Down
84 changes: 65 additions & 19 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@

from .backend.extensions.sycl_ext import _sycl_ext_impl
from .dpnp_algo import (
dpnp_cumprod,
dpnp_ediff1d,
dpnp_fabs,
dpnp_fmax,
Expand Down Expand Up @@ -806,38 +805,85 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
return dpnp.moveaxis(cp, -1, axisc)


def cumprod(x1, **kwargs):
def cumprod(a, axis=None, dtype=None, out=None):
"""
Return the cumulative product of elements along a given axis.

For full documentation refer to :obj:`numpy.cumprod`.

Limitations
-----------
Parameter `x` is supported as :class:`dpnp.ndarray`.
Keyword argument `kwargs` is currently unsupported.
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Input array.
axis : {None, int}, optional
Axis along which the cumulative product is computed. It defaults to
compute the cumulative product over the flattened array.
Default: ``None``.
dtype : {None, dtype}, optional
Type of the returned array and of the accumulator in which the elements
are multiplied. If `dtype` is not specified, it defaults to the dtype
of `a`, unless `a` has an integer dtype with a precision less than that
of the default platform integer. In that case, the default platform
integer is used.
Default: ``None``.
out : {None, dpnp.ndarray, usm_ndarray}, optional
Alternative output array in which to place the result. It must have the
same shape and buffer length as the expected output but the type will
be cast if necessary.
Default: ``None``.

Returns
-------
out : dpnp.ndarray
A new array holding the result is returned unless `out` is specified as
:class:`dpnp.ndarray`, in which case a reference to `out` is returned.
The result has the same size as `a`, and the same shape as `a` if `axis`
is not ``None`` or `a` is a 1-d array.

See Also
--------
:obj:`dpnp.prod` : Product array elements.

Examples
--------
>>> import dpnp as np
>>> a = np.array([1, 2, 3])
>>> result = np.cumprod(a)
>>> [x for x in result]
[1, 2, 6]
>>> b = np.array([[1, 2, 3], [4, 5, 6]])
>>> result = np.cumprod(b)
>>> [x for x in result]
[1, 2, 6, 24, 120, 720]
>>> np.cumprod(a) # intermediate results 1, 1*2
... # total product 1*2*3 = 6
array([1, 2, 6])
>>> a = np.array([[1, 2, 3], [4, 5, 6]])
>>> np.cumprod(a, dtype=np.float32) # specify type of output
array([ 1., 2., 6., 24., 120., 720.], dtype=float32)

The cumulative product for each column (i.e., over the rows) of `a`:

>>> np.cumprod(a, axis=0)
array([[ 1, 2, 3],
[ 4, 10, 18]])

The cumulative product for each row (i.e. over the columns) of `a`:

>>> np.cumprod(a, axis=1)
array([[ 1, 2, 6],
[ 4, 20, 120]])

"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
if x1_desc and not kwargs:
return dpnp_cumprod(x1_desc).get_pyobj()
dpnp.check_supported_arrays_type(a)
if a.ndim > 1 and axis is None:
usm_a = dpnp.ravel(a).get_array()
else:
usm_a = dpnp.get_usm_ndarray(a)

return call_origin(numpy.cumprod, x1, **kwargs)
return dpnp_wrap_reduction_call(
a,
out,
dpt.cumulative_prod,
_get_reduction_res_dt,
usm_a,
axis=axis,
dtype=dtype,
)


def cumsum(a, axis=None, dtype=None, out=None):
Expand Down
4 changes: 2 additions & 2 deletions dpnp/dpnp_iface_nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ def nancumprod(x1, **kwargs):

See Also
--------
:obj:`dpnp.cumprod` : Return the cumulative product of elements
along a given axis.
:obj:`dpnp.cumprod` : Cumulative product across array propagating NaNs.
:obj:`dpnp.isnan` : Show which elements are NaN.

Examples
--------
Expand Down
69 changes: 69 additions & 0 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,75 @@ def test_not_implemented_kwargs(self, kwargs):
dpnp.clip(a, 1, 5, **kwargs)


class TestCumProd:
@pytest.mark.parametrize(
"arr, axis",
[
pytest.param([1, 2, 10, 11, 6, 5, 4], -1),
pytest.param([[1, 2, 3, 4], [5, 6, 7, 9], [10, 3, 4, 5]], 0),
pytest.param([[1, 2, 3, 4], [5, 6, 7, 9], [10, 3, 4, 5]], -1),
],
)
@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_axis(self, arr, axis, dtype):
a = numpy.array(arr, dtype=dtype)
ia = dpnp.array(a)

result = dpnp.cumprod(ia, axis=axis)
expected = numpy.cumprod(a, axis=axis)
assert_array_equal(expected, result)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_ndarray_method(self, dtype):
a = numpy.arange(1, 10).astype(dtype=dtype)
ia = dpnp.array(a)

result = ia.cumprod()
expected = a.cumprod()
assert_array_equal(expected, result)

@pytest.mark.parametrize("sh", [(10,), (2, 5)])
@pytest.mark.parametrize(
"xp_in, xp_out, check",
[
pytest.param(dpt, dpt, False),
pytest.param(dpt, dpnp, True),
pytest.param(dpnp, dpt, False),
],
)
def test_usm_ndarray(self, sh, xp_in, xp_out, check):
a = numpy.arange(-12, -2).reshape(sh)
ia = xp_in.asarray(a)

result = dpnp.cumprod(ia)
expected = numpy.cumprod(a)
assert_array_equal(expected, result)

out = numpy.empty((10,))
iout = xp_out.asarray(out)

result = dpnp.cumprod(ia, out=iout)
expected = numpy.cumprod(a, out=out)
assert_array_equal(expected, result)
assert (result is iout) is check

@pytest.mark.usefixtures("suppress_complex_warning")
@pytest.mark.parametrize("arr_dt", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_out_dtype(self, arr_dt, out_dt, dtype):
a = numpy.arange(5, 10).astype(dtype=arr_dt)
out = numpy.zeros_like(a, dtype=out_dt)

ia = dpnp.array(a)
iout = dpnp.array(out)

result = ia.cumprod(out=iout, dtype=dtype)
expected = a.cumprod(out=out, dtype=dtype)
assert_array_equal(expected, result)
assert result is iout


class TestCumSum:
@pytest.mark.parametrize(
"arr, axis",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def test_meshgrid(device_x, device_y):
),
pytest.param("cosh", [-5.0, -3.5, 0.0, 3.5, 5.0]),
pytest.param("count_nonzero", [3, 0, 2, -1.2]),
pytest.param("cumprod", [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
pytest.param("cumprod", [[1, 2, 3], [4, 5, 6]]),
pytest.param("cumsum", [[1, 2, 3], [4, 5, 6]]),
pytest.param("diff", [1.0, 2.0, 4.0, 7.0, 0.0]),
pytest.param("ediff1d", [1.0, 2.0, 4.0, 7.0, 0.0]),
Expand Down
1 change: 1 addition & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def test_norm(usm_type, ord, axis):
),
pytest.param("cosh", [-5.0, -3.5, 0.0, 3.5, 5.0]),
pytest.param("count_nonzero", [0, 1, 7, 0]),
pytest.param("cumprod", [[1, 2, 3], [4, 5, 6]]),
pytest.param("cumsum", [[1, 2, 3], [4, 5, 6]]),
pytest.param("diff", [1.0, 2.0, 4.0, 7.0, 0.0]),
pytest.param("exp", [1.0, 2.0, 4.0, 7.0]),
Expand Down
22 changes: 2 additions & 20 deletions tests/third_party/cupy/math_tests/test_sumprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,13 +483,11 @@ def _cumprod(self, xp, a, *args, **kwargs):
return res

@testing.for_all_dtypes()
# TODO: remove type_check once proper cumprod is implemented
@testing.numpy_cupy_allclose(type_check=(not is_win_platform()))
@testing.numpy_cupy_allclose()
def test_cumprod_1dim(self, xp, dtype):
a = testing.shaped_arange((5,), xp, dtype)
return self._cumprod(xp, a)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose()
def test_cumprod_out(self, xp, dtype):
Expand All @@ -498,7 +496,6 @@ def test_cumprod_out(self, xp, dtype):
self._cumprod(xp, a, out=out)
return out

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose()
def test_cumprod_out_noncontiguous(self, xp, dtype):
Expand All @@ -507,24 +504,18 @@ def test_cumprod_out_noncontiguous(self, xp, dtype):
self._cumprod(xp, a, out=out)
return out

# TODO: remove skip once proper cumprod is implemented
@pytest.mark.skipif(
is_win_platform(), reason="numpy has another default integral dtype"
)
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-6)
def test_cumprod_2dim_without_axis(self, xp, dtype):
a = testing.shaped_arange((4, 5), xp, dtype)
return self._cumprod(xp, a)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose()
def test_cumprod_2dim_with_axis(self, xp, dtype):
a = testing.shaped_arange((4, 5), xp, dtype)
return self._cumprod(xp, a, axis=1)

@pytest.mark.skip("ndarray.cumprod() is not implemented yet")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose()
def test_ndarray_cumprod_2dim_with_axis(self, xp, dtype):
Expand All @@ -535,53 +526,44 @@ def test_ndarray_cumprod_2dim_with_axis(self, xp, dtype):
@testing.slow
def test_cumprod_huge_array(self):
size = 2**32
# Free huge memory for slow test
cupy.get_default_memory_pool().free_all_blocks()
a = cupy.ones(size, "b")
a = cupy.ones(size, dtype="b")
result = cupy.cumprod(a, dtype="b")
del a
assert (result == 1).all()
# Free huge memory for slow test
del result
cupy.get_default_memory_pool().free_all_blocks()

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.for_all_dtypes()
def test_invalid_axis_lower1(self, dtype):
for xp in (numpy, cupy):
a = testing.shaped_arange((4, 5), xp, dtype)
with pytest.raises(numpy.AxisError):
xp.cumprod(a, axis=-a.ndim - 1)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.for_all_dtypes()
def test_invalid_axis_lower2(self, dtype):
for xp in (numpy, cupy):
a = testing.shaped_arange((4, 5), xp, dtype)
with pytest.raises(numpy.AxisError):
xp.cumprod(a, axis=-a.ndim - 1)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.for_all_dtypes()
def test_invalid_axis_upper1(self, dtype):
for xp in (numpy, cupy):
a = testing.shaped_arange((4, 5), xp, dtype)
with pytest.raises(numpy.AxisError):
return xp.cumprod(a, axis=a.ndim)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.for_all_dtypes()
def test_invalid_axis_upper2(self, dtype):
a = testing.shaped_arange((4, 5), cupy, dtype)
with pytest.raises(numpy.AxisError):
return cupy.cumprod(a, axis=a.ndim)

@pytest.mark.skip("no exception is raised by numpy")
def test_cumprod_arraylike(self):
with pytest.raises(TypeError):
return cupy.cumprod((1, 2, 3))

@pytest.mark.skip("no exception is raised by numpy")
@testing.for_float_dtypes()
def test_cumprod_numpy_array(self, dtype):
a_numpy = numpy.arange(1, 6, dtype=dtype)
Expand Down