diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 97951bc30800..e8df6fec9068 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -1174,7 +1174,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"): v : {scalar, array_like} Values to be put into `a`. Must be broadcastable to the result shape ``a.shape[:axis] + ind.shape + a.shape[axis+1:]``. - axis {None, int}, optional + axis : {None, int}, optional The axis along which the values will be placed. If `a` is 1-D array, this argument is optional. Default: ``None``. @@ -1502,7 +1502,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"): return dpnp.get_result_array(result, out) -def take_along_axis(a, indices, axis): +def take_along_axis(a, indices, axis, mode="wrap"): """ Take values from the input array by matching 1d index and data slices. @@ -1523,15 +1523,24 @@ def take_along_axis(a, indices, axis): Indices to take along each 1d slice of `a`. This must match the dimension of the input array, but dimensions ``Ni`` and ``Nj`` only need to broadcast against `a`. - axis : int + axis : {None, int} The axis to take 1d slices along. If axis is ``None``, the input array is treated as if it had first been flattened to 1d, for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`. + mode : {"wrap", "clip"}, optional + Specifies how out-of-bounds indices will be handled. Possible values + are: + + - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps + negative indices. + - ``"clip"``: clips indices to (``0 <= i < n``). + + Default: ``"wrap"``. Returns ------- out : dpnp.ndarray - The indexed result. + The indexed result of the same data type as `a`. See Also -------- @@ -1591,12 +1600,21 @@ def take_along_axis(a, indices, axis): """ - dpnp.check_supported_arrays_type(a, indices) - if axis is None: - a = a.ravel() + dpnp.check_supported_arrays_type(indices) + if indices.ndim != 1: + raise ValueError( + "when axis=None, `indices` must have a single dimension." + ) - return a[_build_along_axis_index(a, indices, axis)] + a = dpnp.ravel(a) + axis = 0 + + usm_a = dpnp.get_usm_ndarray(a) + usm_ind = dpnp.get_usm_ndarray(indices) + + usm_res = dpt.take_along_axis(usm_a, usm_ind, axis=axis, mode=mode) + return dpnp_array._create_from_usm_ndarray(usm_res) def tril_indices( diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 90d27a01dbd6..bed48bce3985 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -544,6 +544,13 @@ def test_values(self, arr_dt, idx_dt, ndim, values): dpnp.put_along_axis(dp_a, dp_ai, values, axis) assert_array_equal(np_a, dp_a) + @pytest.mark.parametrize("xp", [numpy, dpnp]) + @pytest.mark.parametrize("dt", [bool, numpy.float32]) + def test_invalid_indices_dtype(self, xp, dt): + a = xp.ones((10, 10)) + ind = xp.ones(10, dtype=dt) + assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1) + @pytest.mark.parametrize("arr_dt", get_all_dtypes()) @pytest.mark.parametrize("idx_dt", get_integer_dtypes()) def test_broadcast(self, arr_dt, idx_dt): @@ -673,66 +680,80 @@ def test_argequivalent(self, func, argfunc, kwargs): @pytest.mark.parametrize("idx_dt", get_integer_dtypes()) @pytest.mark.parametrize("ndim", list(range(1, 4))) def test_multi_dimensions(self, arr_dt, idx_dt, ndim): - np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim) - np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape( + a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim) + ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape( (1,) * (ndim - 1) + (4,) ) - - dp_a = dpnp.array(np_a, dtype=arr_dt) - dp_ai = dpnp.array(np_ai, dtype=idx_dt) + ia, iind = dpnp.array(a), dpnp.array(ind) for axis in range(ndim): - expected = numpy.take_along_axis(np_a, np_ai, axis) - result = dpnp.take_along_axis(dp_a, dp_ai, axis) + result = dpnp.take_along_axis(ia, iind, axis) + expected = numpy.take_along_axis(a, ind, axis) assert_array_equal(expected, result) @pytest.mark.parametrize("xp", [numpy, dpnp]) - def test_invalid(self, xp): + def test_not_enough_indices(self, xp): a = xp.ones((10, 10)) - ai = xp.ones((10, 2), dtype=xp.intp) - - # not enough indices assert_raises(ValueError, xp.take_along_axis, a, xp.array(1), axis=1) - # bool arrays not allowed - assert_raises( - IndexError, xp.take_along_axis, a, ai.astype(bool), axis=1 - ) + @pytest.mark.parametrize("xp", [numpy, dpnp]) + @pytest.mark.parametrize("dt", [bool, numpy.float32]) + def test_invalid_indices_dtype(self, xp, dt): + a = xp.ones((10, 10)) + ind = xp.ones((10, 2), dtype=dt) + assert_raises(IndexError, xp.take_along_axis, a, ind, axis=1) - # float arrays not allowed - assert_raises( - IndexError, xp.take_along_axis, a, ai.astype(numpy.float32), axis=1 - ) + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_invalid_axis(self, xp): + a = xp.ones((10, 10)) + ind = xp.ones((10, 2), dtype=xp.intp) + assert_raises(AxisError, xp.take_along_axis, a, ind, axis=10) - # invalid axis - assert_raises(AxisError, xp.take_along_axis, a, ai, axis=10) + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_indices_ndim_axis_none(self, xp): + a = xp.ones((10, 10)) + ind = xp.ones((10, 2), dtype=xp.intp) + assert_raises(ValueError, xp.take_along_axis, a, ind, axis=None) - @pytest.mark.parametrize("arr_dt", get_all_dtypes()) + @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True)) @pytest.mark.parametrize("idx_dt", get_integer_dtypes()) - def test_empty(self, arr_dt, idx_dt): - np_a = numpy.ones((3, 4, 5), dtype=arr_dt) - np_ai = numpy.ones((3, 0, 5), dtype=idx_dt) - - dp_a = dpnp.array(np_a, dtype=arr_dt) - dp_ai = dpnp.array(np_ai, dtype=idx_dt) + def test_empty(self, a_dt, idx_dt): + a = numpy.ones((3, 4, 5), dtype=a_dt) + ind = numpy.ones((3, 0, 5), dtype=idx_dt) + ia, iind = dpnp.array(a), dpnp.array(ind) - expected = numpy.take_along_axis(np_a, np_ai, axis=1) - result = dpnp.take_along_axis(dp_a, dp_ai, axis=1) + result = dpnp.take_along_axis(ia, iind, axis=1) + expected = numpy.take_along_axis(a, ind, axis=1) assert_array_equal(expected, result) - @pytest.mark.parametrize("arr_dt", get_all_dtypes()) + @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True)) @pytest.mark.parametrize("idx_dt", get_integer_dtypes()) - def test_broadcast(self, arr_dt, idx_dt): - np_a = numpy.ones((3, 4, 1), dtype=arr_dt) - np_ai = numpy.ones((1, 2, 5), dtype=idx_dt) - - dp_a = dpnp.array(np_a, dtype=arr_dt) - dp_ai = dpnp.array(np_ai, dtype=idx_dt) + def test_broadcast(self, a_dt, idx_dt): + a = numpy.ones((3, 4, 1), dtype=a_dt) + ind = numpy.ones((1, 2, 5), dtype=idx_dt) + ia, iind = dpnp.array(a), dpnp.array(ind) - expected = numpy.take_along_axis(np_a, np_ai, axis=1) - result = dpnp.take_along_axis(dp_a, dp_ai, axis=1) + result = dpnp.take_along_axis(ia, iind, axis=1) + expected = numpy.take_along_axis(a, ind, axis=1) assert_array_equal(expected, result) + def test_mode_wrap(self): + a = numpy.array([-2, -1, 0, 1, 2]) + ind = numpy.array([-2, 2, -5, 4]) + ia, iind = dpnp.array(a), dpnp.array(ind) + + result = dpnp.take_along_axis(ia, iind, axis=0, mode="wrap") + expected = numpy.take_along_axis(a, ind, axis=0) + assert_array_equal(result, expected) + + def test_mode_clip(self): + a = dpnp.array([-2, -1, 0, 1, 2]) + ind = dpnp.array([-2, 2, -5, 4]) + + # numpy does not support keyword `mode` + result = dpnp.take_along_axis(a, ind, axis=0, mode="clip") + assert (result == dpnp.array([-2, 0, -2, 2])).all() + @pytest.mark.usefixtures("allow_fall_back_on_numpy") def test_choose():