From cd3574ca890f5fa872ea816a4005199a94cd6b63 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 24 Jul 2024 13:30:28 +0200 Subject: [PATCH] Implement dpnp.trim_zeros() --- doc/reference/manipulation.rst | 1 + dpnp/dpnp_iface_manipulation.py | 66 ++++ tests/test_manipulation.py | 71 ++++ tests/test_sycl_queue.py | 1 + tests/test_usm_type.py | 1 + .../manipulation_tests/test_add_remove.py | 360 ++++++++++++++++++ 6 files changed, 500 insertions(+) create mode 100644 tests/third_party/cupy/manipulation_tests/test_add_remove.py diff --git a/doc/reference/manipulation.rst b/doc/reference/manipulation.rst index 3f6c0aabfc27..c6aeb65de92c 100644 --- a/doc/reference/manipulation.rst +++ b/doc/reference/manipulation.rst @@ -131,6 +131,7 @@ Adding and removing elements dpnp.resize dpnp.unique dpnp.trim_zeros + dpnp.pad Rearranging elements diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index b262e314440c..2b5f0bb34e06 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -82,6 +82,7 @@ "swapaxes", "tile", "transpose", + "trim_zeros", "unique", "vstack", ] @@ -1927,6 +1928,71 @@ def transpose(a, axes=None): return array.transpose(*axes) +def trim_zeros(filt, trim="fb"): + """ + Trim the leading and/or trailing zeros from a 1-D array. + + For full documentation refer to :obj:`numpy.trim_zeros`. + + Parameters + ---------- + filt : {dpnp.ndarray, usm_ndarray} + Input 1-D array. + trim : str, optional + A string with 'f' representing trim from front and 'b' to trim from + back. By defaults, trim zeros from both front and back of the array. + Default: ``"fb"``. + + Returns + ------- + out : dpnp.ndarray + The result of trimming the input. + + Examples + -------- + >>> import dpnp as np + >>> a = np.array((0, 0, 0, 1, 2, 3, 0, 2, 1, 0)) + >>> np.trim_zeros(a) + array([1, 2, 3, 0, 2, 1]) + + >>> np.trim_zeros(a, 'b') + array([0, 0, 0, 1, 2, 3, 0, 2, 1]) + + """ + + dpnp.check_supported_arrays_type(filt) + if filt.ndim == 0: + raise TypeError("0-d array cannot be trimmed") + if filt.ndim > 1: + raise ValueError("Multi-dimensional trim is not supported") + + if not isinstance(trim, str): + raise TypeError("only string trim is supported") + + trim = trim.upper() + if not any(x in trim for x in "FB"): + return filt # no trim rule is specified + + if filt.size == 0: + return filt # no trailing zeros in empty array + + a = dpnp.nonzero(filt)[0] + a_size = a.size + if a_size == 0: + # 'filt' is array of zeros + return dpnp.empty_like(filt, shape=(0,)) + + first = 0 + if "F" in trim: + first = a[0] + + last = filt.size + if "B" in trim: + last = a[-1] + 1 + + return filt[first:last] + + def unique(ar, **kwargs): """ Find the unique elements of an array. diff --git a/tests/test_manipulation.py b/tests/test_manipulation.py index 0178ff9a28b9..817f48835b94 100644 --- a/tests/test_manipulation.py +++ b/tests/test_manipulation.py @@ -378,3 +378,74 @@ def test_ndarray_axes_n_int(self): expected = na.transpose(1, 0, 2) result = da.transpose(1, 0, 2) assert_array_equal(expected, result) + + +class TestTrimZeros: + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + def test_basic(self, dtype): + a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0], dtype=dtype) + ia = dpnp.array(a) + + result = dpnp.trim_zeros(ia) + expected = numpy.trim_zeros(a) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("trim", ["F", "B"]) + def test_trim(self, dtype, trim): + a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0], dtype=dtype) + ia = dpnp.array(a) + + result = dpnp.trim_zeros(ia, trim) + expected = numpy.trim_zeros(a, trim) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("trim", ["F", "B"]) + def test_all_zero(self, dtype, trim): + a = numpy.zeros((8,), dtype=dtype) + ia = dpnp.array(a) + + result = dpnp.trim_zeros(ia, trim) + expected = numpy.trim_zeros(a, trim) + assert_array_equal(expected, result) + + def test_size_zero(self): + a = numpy.zeros(0) + ia = dpnp.array(a) + + result = dpnp.trim_zeros(ia) + expected = numpy.trim_zeros(a) + assert_array_equal(expected, result) + + @pytest.mark.parametrize( + "a", [numpy.array([0, 2**62, 0]), numpy.array([0, 2**63, 0])] + ) + def test_overflow(self, a): + ia = dpnp.array(a) + + result = dpnp.trim_zeros(ia) + expected = numpy.trim_zeros(a) + assert_array_equal(expected, result) + + def test_trim_no_rule(self): + a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0]) + ia = dpnp.array(a) + trim = "ADE" # no "F" or "B" in trim string + + result = dpnp.trim_zeros(ia, trim) + expected = numpy.trim_zeros(a, trim) + assert_array_equal(expected, result) + + def test_list_array(self): + assert_raises(TypeError, dpnp.trim_zeros, [0, 0, 1, 0, 2, 3, 4, 0]) + + @pytest.mark.parametrize( + "trim", [1, ["F"], numpy.array("B")], ids=["int", "list", "array"] + ) + def test_unsupported_trim(self, trim): + a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0]) + ia = dpnp.array(a) + + assert_raises(TypeError, dpnp.trim_zeros, ia, trim) + assert_raises(AttributeError, numpy.trim_zeros, a, trim) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 528d33084d0c..0449ac39812a 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -475,6 +475,7 @@ def test_meshgrid(device): "trace", [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] ), pytest.param("trapz", [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), + pytest.param("trim_zeros", [0, 0, 0, 1, 2, 3, 0, 2, 1, 0]), pytest.param("trunc", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("var", [1.0, 2.0, 4.0, 7.0]), ], diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 8abc0aaf3220..3071ffa638e0 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -610,6 +610,7 @@ def test_norm(usm_type, ord, axis): pytest.param( "trace", [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] ), + pytest.param("trim_zeros", [0, 0, 0, 1, 2, 3, 0, 2, 1, 0]), pytest.param("trunc", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("var", [1.0, 2.0, 4.0, 7.0]), ], diff --git a/tests/third_party/cupy/manipulation_tests/test_add_remove.py b/tests/third_party/cupy/manipulation_tests/test_add_remove.py new file mode 100644 index 000000000000..404215b7562a --- /dev/null +++ b/tests/third_party/cupy/manipulation_tests/test_add_remove.py @@ -0,0 +1,360 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from tests.third_party.cupy import testing +from tests.third_party.cupy.testing._loops import ( + _complex_dtypes, + _regular_float_dtypes, +) + + +@pytest.mark.skip("delete() is not implemented yet") +class TestDelete(unittest.TestCase): + @testing.numpy_cupy_array_equal() + def test_delete_with_no_axis(self, xp): + arr = xp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + indices = xp.array([0, 2, 4, 6, 8]) + + return xp.delete(arr, indices) + + @testing.numpy_cupy_array_equal() + def test_delete_with_axis_zero(self, xp): + arr = xp.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + indices = xp.array([0, 2]) + + return xp.delete(arr, indices, axis=0) + + @testing.numpy_cupy_array_equal() + def test_delete_with_axis_one(self, xp): + arr = xp.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + indices = xp.array([0, 2, 4]) + + return xp.delete(arr, indices, axis=1) + + @testing.numpy_cupy_array_equal() + def test_delete_with_indices_as_bool_array(self, xp): + arr = xp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + indices = xp.array( + [True, False, True, False, True, False, True, False, True, False] + ) + + return xp.delete(arr, indices) + + @testing.numpy_cupy_array_equal() + def test_delete_with_indices_as_slice(self, xp): + arr = xp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + indices = slice(None, None, 2) + return xp.delete(arr, indices) + + @testing.numpy_cupy_array_equal() + def test_delete_with_indices_as_int(self, xp): + arr = xp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + indices = 5 + if cupy.cuda.runtime.is_hip: + pytest.xfail("HIP may have a bug") + return xp.delete(arr, indices) + + +@pytest.mark.skip("append() is not implemented yet") +class TestAppend(unittest.TestCase): + @testing.for_all_dtypes_combination( + names=["dtype1", "dtype2"], no_bool=True + ) + @testing.numpy_cupy_array_equal() + def test(self, xp, dtype1, dtype2): + a = testing.shaped_random((3, 4, 5), xp, dtype1) + b = testing.shaped_random((6, 7), xp, dtype2) + return xp.append(a, b) + + @testing.for_all_dtypes_combination( + names=["dtype1", "dtype2"], no_bool=True + ) + @testing.numpy_cupy_array_equal() + def test_scalar_lhs(self, xp, dtype1, dtype2): + scalar = xp.dtype(dtype1).type(10).item() + return xp.append(scalar, xp.arange(20, dtype=dtype2)) + + @testing.for_all_dtypes_combination( + names=["dtype1", "dtype2"], no_bool=True + ) + @testing.numpy_cupy_array_equal() + def test_scalar_rhs(self, xp, dtype1, dtype2): + scalar = xp.dtype(dtype2).type(10).item() + return xp.append(xp.arange(20, dtype=dtype1), scalar) + + @testing.for_all_dtypes_combination( + names=["dtype1", "dtype2"], no_bool=True + ) + @testing.numpy_cupy_array_equal() + def test_numpy_scalar_lhs(self, xp, dtype1, dtype2): + scalar = xp.dtype(dtype1).type(10) + return xp.append(scalar, xp.arange(20, dtype=dtype2)) + + @testing.for_all_dtypes_combination( + names=["dtype1", "dtype2"], no_bool=True + ) + @testing.numpy_cupy_array_equal() + def test_numpy_scalar_rhs(self, xp, dtype1, dtype2): + scalar = xp.dtype(dtype2).type(10) + return xp.append(xp.arange(20, dtype=dtype1), scalar) + + @testing.numpy_cupy_array_equal() + def test_scalar_both(self, xp): + return xp.append(10, 10) + + @testing.numpy_cupy_array_equal() + def test_axis(self, xp): + a = testing.shaped_random((3, 4, 5), xp, xp.float32) + b = testing.shaped_random((3, 10, 5), xp, xp.float32) + return xp.append(a, b, axis=1) + + @testing.numpy_cupy_array_equal() + def test_zerodim(self, xp): + return xp.append(xp.array(0), xp.arange(10)) + + @testing.numpy_cupy_array_equal() + def test_empty(self, xp): + return xp.append(xp.array([]), xp.arange(10)) + + +@pytest.mark.skip("resize() is not implemented yet") +class TestResize(unittest.TestCase): + @testing.numpy_cupy_array_equal() + def test(self, xp): + return xp.resize(xp.arange(10), (10, 10)) + + @testing.numpy_cupy_array_equal() + def test_remainder(self, xp): + return xp.resize(xp.arange(8), (10, 10)) + + @testing.numpy_cupy_array_equal() + def test_shape_int(self, xp): + return xp.resize(xp.arange(10), 15) + + @testing.numpy_cupy_array_equal() + def test_scalar(self, xp): + return xp.resize(2, (10, 10)) + + @testing.numpy_cupy_array_equal() + def test_scalar_shape_int(self, xp): + return xp.resize(2, 10) + + @testing.numpy_cupy_array_equal() + def test_typed_scalar(self, xp): + return xp.resize(xp.float32(10.0), (10, 10)) + + @testing.numpy_cupy_array_equal() + def test_zerodim(self, xp): + return xp.resize(xp.array(0), (10, 10)) + + @testing.numpy_cupy_array_equal() + def test_empty(self, xp): + return xp.resize(xp.array([]), (10, 10)) + + +@pytest.mark.skip("unique() is not implemented yet") +class TestUnique: + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_no_axis(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique(a) + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique(a, axis=1) + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_index_no_axis(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique(a, return_index=True)[1] + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_index(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique(a, return_index=True, axis=0)[1] + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_inverse_no_axis(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + res = xp.unique(a, return_inverse=True)[1] + if xp is numpy and numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": + res = res.reshape(a.shape) + return res + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_inverse(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique(a, return_inverse=True, axis=1)[1] + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_counts_no_axis(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique(a, return_counts=True)[1] + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_counts(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique(a, return_counts=True, axis=0)[1] + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_return_all_no_axis(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + res = xp.unique( + a, return_index=True, return_inverse=True, return_counts=True + ) + if xp is numpy and numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": + res = res[:2] + (res[2].reshape(a.shape),) + res[3:] + return res + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_return_all(self, xp, dtype): + a = testing.shaped_random((100, 100), xp, dtype) + return xp.unique( + a, + return_index=True, + return_inverse=True, + return_counts=True, + axis=1, + ) + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_empty_no_axis(self, xp, dtype): + a = xp.empty((0,), dtype=dtype) + return xp.unique(a) + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_empty(self, xp, dtype): + a = xp.empty((0,), dtype=dtype) + return xp.unique(a, axis=0) + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_empty_return_all_no_axis(self, xp, dtype): + a = xp.empty((3, 0, 2), dtype=dtype) + res = xp.unique( + a, return_index=True, return_inverse=True, return_counts=True + ) + if xp is numpy and numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": + res = res[:2] + (res[2].reshape(a.shape),) + res[3:] + return res + + @testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True) + @testing.numpy_cupy_array_equal() + def test_unique_empty_return_all(self, xp, dtype): + a = xp.empty((3, 0, 2), dtype=dtype) + return xp.unique( + a, + return_index=True, + return_inverse=True, + return_counts=True, + axis=2, + ) + + @pytest.mark.parametrize("equal_nan", [True, False]) + @testing.for_dtypes_combination(_regular_float_dtypes + _complex_dtypes) + @testing.numpy_cupy_array_equal() + @testing.with_requires("numpy>=1.23.1") + def test_unique_equal_nan_no_axis(self, xp, dtype, equal_nan): + if xp.dtype(dtype).kind == "c": + # Nan and Nan+Nan*1j are collapsed when equal_nan=True + a = xp.array( + [ + complex(xp.nan, 3), + 2, + complex(7, xp.nan), + xp.nan, + complex(xp.nan, xp.nan), + 2, + xp.nan, + 1, + ], + dtype=dtype, + ) + else: + a = xp.array([2, xp.nan, 2, xp.nan, 1], dtype=dtype) + return xp.unique(a, equal_nan=equal_nan) + + @pytest.mark.parametrize("equal_nan", [True, False]) + @testing.for_dtypes_combination(_regular_float_dtypes + _complex_dtypes) + @testing.numpy_cupy_array_equal() + @testing.with_requires("numpy>=1.23.1") + def test_unique_equal_nan(self, xp, dtype, equal_nan): + if xp.dtype(dtype).kind == "c": + # Nan and Nan+Nan*1j are collapsed when equal_nan=True + a = xp.array( + [ + [complex(xp.nan, 3), 2, complex(7, xp.nan)], + [xp.nan, complex(xp.nan, xp.nan), 2], + [xp.nan, 1, complex(xp.nan, -1)], + ], + dtype=dtype, + ) + else: + a = xp.array( + [[2, xp.nan, 2], [xp.nan, 1, xp.nan], [xp.nan, 1, xp.nan]], + dtype=dtype, + ) + return xp.unique(a, axis=0, equal_nan=equal_nan) + + +@testing.parameterize(*testing.product({"trim": ["fb", "f", "b"]})) +class TestTrim_zeros(unittest.TestCase): + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_trim_non_zeros(self, xp, dtype): + a = xp.array([-1, 2, -3, 7]).astype(dtype) + return xp.trim_zeros(a, trim=self.trim) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_trim_trimmed(self, xp, dtype): + a = xp.array([1, 0, 2, 3, 0, 5], dtype=dtype) + return xp.trim_zeros(a, trim=self.trim) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_trim_all_zeros(self, xp, dtype): + a = xp.zeros(shape=(1000,), dtype=dtype) + return xp.trim_zeros(a, trim=self.trim) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_trim_front_zeros(self, xp, dtype): + a = xp.array([0, 0, 4, 1, 0, 2, 3, 0, 5], dtype=dtype) + return xp.trim_zeros(a, trim=self.trim) + + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_trim_back_zeros(self, xp, dtype): + a = xp.array([1, 0, 2, 3, 0, 5, 0, 0, 0], dtype=dtype) + return xp.trim_zeros(a, trim=self.trim) + + @testing.for_all_dtypes() + def test_trim_zero_dim(self, dtype): + for xp in (numpy, cupy): + a = testing.shaped_arange((), xp, dtype) + with pytest.raises(TypeError): + xp.trim_zeros(a, trim=self.trim) + + @testing.for_all_dtypes() + def test_trim_ndim(self, dtype): + for xp in (numpy, cupy): + a = testing.shaped_arange((2, 3), xp, dtype=dtype) + with pytest.raises(ValueError): + xp.trim_zeros(a, trim=self.trim)