diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 833138ba13be..374981a63031 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -341,6 +341,107 @@ def __call__( return out return dpnp_array._create_from_usm_ndarray(res_usm) + def outer( + self, + x1, + x2, + out=None, + where=True, + order="K", + dtype=None, + subok=True, + **kwargs, + ): + """ + Apply the ufunc op to all pairs (a, b) with a in A and b in B. + + Parameters + ---------- + x1 : {dpnp.ndarray, usm_ndarray} + First input array. + x2 : {dpnp.ndarray, usm_ndarray} + Second input array. + out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. + Array must have the correct shape and the expected data type. + order : {None, "C", "F", "A", "K"}, optional + Memory layout of the newly output array, Cannot be provided + together with `out`. Default: ``"K"``. + dtype : {None, dtype}, optional + If provided, the destination array will have this dtype. Cannot be + provided together with `out`. Default: ``None``. + + Returns + ------- + out : dpnp.ndarray + Output array. The data type of the returned array is determined by + the Type Promotion Rules. + + Limitations + ----------- + Parameters `where` and `subok` are supported with their default values. + Keyword argument `kwargs` is currently unsupported. + Otherwise ``NotImplementedError`` exception will be raised. + + See also + -------- + :obj:`dpnp.outer` : A less powerful version of dpnp.multiply.outer + that ravels all inputs to 1D. This exists primarily + for compatibility with old code. + + :obj:`dpnp.tensordot` : dpnp.tensordot(a, b, axes=((), ())) and + dpnp.multiply.outer(a, b) behave same for all + dimensions of a and b. + + Examples + -------- + >>> import dpnp as np + >>> A = np.array([1, 2, 3]) + >>> B = np.array([4, 5, 6]) + >>> np.multiply.outer(A, B) + array([[ 4, 5, 6], + [ 8, 10, 12], + [12, 15, 18]]) + + A multi-dimensional example: + >>> A = np.array([[1, 2, 3], [4, 5, 6]]) + >>> A.shape + (2, 3) + >>> B = np.array([[1, 2, 3, 4]]) + >>> B.shape + (1, 4) + >>> C = np.multiply.outer(A, B) + >>> C.shape; C + (2, 3, 1, 4) + array([[[[ 1, 2, 3, 4]], + [[ 2, 4, 6, 8]], + [[ 3, 6, 9, 12]]], + [[[ 4, 8, 12, 16]], + [[ 5, 10, 15, 20]], + [[ 6, 12, 18, 24]]]]) + + """ + + dpnp.check_supported_arrays_type( + x1, x2, scalar_type=True, all_scalars=False + ) + if dpnp.isscalar(x1) or dpnp.isscalar(x2): + _x1 = x1 + _x2 = x2 + else: + _x1 = x1[(Ellipsis,) + (None,) * x2.ndim] + _x2 = x2[(None,) * x1.ndim + (Ellipsis,)] + return self.__call__( + _x1, + _x2, + out=out, + where=where, + order=order, + dtype=dtype, + subok=subok, + **kwargs, + ) + class DPNPAngle(DPNPUnaryFunc): """Class that implements dpnp.angle unary element-wise functions.""" diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index aac7940af505..4bf0d2ab524e 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -43,10 +43,6 @@ import dpnp -# pylint: disable=no-name-in-module -from .dpnp_utils import ( - call_origin, -) from .dpnp_utils.dpnp_utils_linearalgebra import ( dpnp_dot, dpnp_einsum, @@ -851,62 +847,58 @@ def matmul( ) -def outer(x1, x2, out=None): +def outer(a, b, out=None): """ Returns the outer product of two arrays. For full documentation refer to :obj:`numpy.outer`. - Limitations - ----------- - Parameters `x1` and `x2` are supported as either scalar, - :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`, but both - `x1` and `x2` can not be scalars at the same time. Otherwise - the functions will be executed sequentially on CPU. - Input array data types are limited by supported DPNP :ref:`Data types`. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + First input vector. Input is flattened if not already 1-dimensional. + b : {dpnp.ndarray, usm_ndarray} + Second input vector. Input is flattened if not already 1-dimensional. + out : {None, dpnp.ndarray, usm_ndarray}, optional + A location where the result is stored + + Returns + ------- + out : dpnp.ndarray + out[i, j] = a[i] * b[j] See Also -------- :obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands. :obj:`dpnp.inner` : Returns the inner product of two arrays. + :obj:`dpnp.tensordot` : dpnp.tensordot(a.ravel(), b.ravel(), axes=((), ())) + is the equivalent. Examples -------- >>> import dpnp as np >>> a = np.array([1, 1, 1]) >>> b = np.array([1, 2, 3]) - >>> result = np.outer(a, b) - >>> [x for x in result] + >>> np.outer(a, b) array([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) """ - x1_is_scalar = dpnp.isscalar(x1) - x2_is_scalar = dpnp.isscalar(x2) - - if x1_is_scalar and x2_is_scalar: - pass - elif not dpnp.is_supported_array_or_scalar(x1): - pass - elif not dpnp.is_supported_array_or_scalar(x2): - pass + dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False) + if dpnp.isscalar(a): + x1 = a + x2 = b.ravel()[None, :] + elif dpnp.isscalar(b): + x1 = a.ravel()[:, None] + x2 = b else: - x1_in = ( - x1 - if x1_is_scalar - else (x1.reshape(-1) if x1.ndim > 1 else x1)[:, None] - ) - x2_in = ( - x2 - if x2_is_scalar - else (x2.reshape(-1) if x2.ndim > 1 else x2)[None, :] - ) - return dpnp.multiply(x1_in, x2_in, out=out) + x1 = a.ravel() + x2 = b.ravel() - return call_origin(numpy.outer, x1, x2, out=out) + return dpnp.multiply.outer(x1, x2, out=out) def tensordot(a, b, axes=2): diff --git a/tests/test_flipping.py b/tests/test_flipping.py index 36365be1be71..cbda60fb2123 100644 --- a/tests/test_flipping.py +++ b/tests/test_flipping.py @@ -60,7 +60,7 @@ def test_arange_4d(self, axis, dtype): ) def test_lr_equivalent(self, dtype): dp_a = dpnp.arange(4, dtype=dtype) - dp_a = dp_a[:, dpnp.newaxis] + dp_a[dpnp.newaxis, :] + dp_a = dpnp.add.outer(dp_a, dp_a) assert_equal(dpnp.flip(dp_a, 1), dpnp.fliplr(dp_a)) np_a = numpy.arange(4, dtype=dtype) @@ -72,7 +72,7 @@ def test_lr_equivalent(self, dtype): ) def test_ud_equivalent(self, dtype): dp_a = dpnp.arange(4, dtype=dtype) - dp_a = dp_a[:, dpnp.newaxis] + dp_a[dpnp.newaxis, :] + dp_a = dpnp.add.outer(dp_a, dp_a) assert_equal(dpnp.flip(dp_a, 0), dpnp.flipud(dp_a)) np_a = numpy.arange(4, dtype=dtype) diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 966d30a1149a..935cc9e1f439 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2875,3 +2875,45 @@ def test_bitwise_1array_input(): result = dpnp.add(1, x, dtype="f4") expected = numpy.add(1, x_np, dtype="f4") assert_dtype_allclose(result, expected) + + +@pytest.mark.parametrize( + "x_shape", + [ + (), + (2), + (3, 4), + (3, 4, 5), + ], +) +@pytest.mark.parametrize( + "y_shape", + [ + (), + (2), + (3, 4), + (3, 4, 5), + ], +) +def test_elemenwise_outer(x_shape, y_shape): + x_np = numpy.random.random(x_shape) + y_np = numpy.random.random(y_shape) + expected = numpy.multiply.outer(x_np, y_np) + + x = dpnp.asarray(x_np) + y = dpnp.asarray(y_np) + result = dpnp.multiply.outer(x, y) + + assert_dtype_allclose(result, expected) + + result_outer = dpnp.outer(x, y) + assert dpnp.allclose(result.flatten(), result_outer.flatten()) + + +def test_elemenwise_outer_scalar(): + s = 5 + x = dpnp.asarray([1, 2, 3]) + y = dpnp.asarray(s) + expected = dpnp.add.outer(x, y) + result = dpnp.add.outer(x, s) + assert_dtype_allclose(result, expected) diff --git a/tests/test_outer.py b/tests/test_outer.py index 5ab062ae90c8..3df5f4f4100c 100644 --- a/tests/test_outer.py +++ b/tests/test_outer.py @@ -42,24 +42,17 @@ class TestScalarOuter(unittest.TestCase): @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=False) def test_first_is_scalar(self, xp, dtype): - scalar = xp.int64(4) + scalar = 4 a = xp.arange(5**3, dtype=dtype).reshape(5, 5, 5) return xp.outer(scalar, a) @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=False) def test_second_is_scalar(self, xp, dtype): - scalar = xp.int32(7) + scalar = 7 a = xp.arange(5**3, dtype=dtype).reshape(5, 5, 5) return xp.outer(a, scalar) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") - @testing.numpy_cupy_array_equal() - def test_both_inputs_as_scalar(self, xp): - a = xp.int64(4) - b = xp.int32(17) - return xp.outer(a, b) - class TestListOuter(unittest.TestCase): def test_list(self): @@ -67,7 +60,7 @@ def test_list(self): b: list[list[list[int]]] = a.tolist() dp_a = dp.array(a) - with assert_raises(NotImplementedError): + with assert_raises(TypeError): dp.outer(b, dp_a) dp.outer(dp_a, b) dp.outer(b, b)