diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp index bb0fc5c05851..fee0c3bf6cab 100644 --- a/dpnp/backend/extensions/blas/blas_py.cpp +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -64,13 +64,13 @@ PYBIND11_MODULE(_blas_impl, m) blas_ext::DotContigFactory>( dot_dispatch_vector); - auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, - arrayT dst, const event_vecT &depends = {}) { + auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { return dot_ext::dot_func(exec_q, src1, src2, dst, depends, dot_dispatch_vector); }; - m.def("_dot", dot_pypi, + m.def("_dot", dot_pyapi, "Call `dot` from OneMKL BLAS library to return " "the dot product of two real-valued vectors.", py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"), @@ -82,13 +82,13 @@ PYBIND11_MODULE(_blas_impl, m) blas_ext::DotcContigFactory>( dotc_dispatch_vector); - auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, - arrayT dst, const event_vecT &depends = {}) { + auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { return dot_ext::dot_func(exec_q, src1, src2, dst, depends, dotc_dispatch_vector); }; - m.def("_dotc", dotc_pypi, + m.def("_dotc", dotc_pyapi, "Call `dotc` from OneMKL BLAS library to return " "the dot product of two complex vectors, " "conjugating the first vector.", @@ -101,13 +101,13 @@ PYBIND11_MODULE(_blas_impl, m) blas_ext::DotuContigFactory>( dotu_dispatch_vector); - auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, - arrayT dst, const event_vecT &depends = {}) { + auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { return dot_ext::dot_func(exec_q, src1, src2, dst, depends, dotu_dispatch_vector); }; - m.def("_dotu", dotu_pypi, + m.def("_dotu", dotu_pyapi, "Call `dotu` from OneMKL BLAS library to return " "the dot product of two complex vectors.", py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"), @@ -119,7 +119,7 @@ PYBIND11_MODULE(_blas_impl, m) "Call `gemm` from OneMKL BLAS library to return " "the matrix-matrix product with 2-D matrices.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), - py::arg("result"), py::arg("depends") = py::list()); + py::arg("resultC"), py::arg("depends") = py::list()); } { @@ -127,8 +127,6 @@ PYBIND11_MODULE(_blas_impl, m) "Call `gemm_batch` from OneMKL BLAS library to return " "the matrix-matrix product for a batch of 2-D matrices.", py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), - py::arg("result"), py::arg("batch_size"), py::arg("stridea"), - py::arg("strideb"), py::arg("stridec"), - py::arg("depends") = py::list()); + py::arg("resultC"), py::arg("depends") = py::list()); } } diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index ae20ea9efee7..c1005f797b18 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -59,6 +59,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &, const std::int64_t, char *, const std::int64_t, + bool, const std::vector &); static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types] @@ -77,6 +78,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q, const std::int64_t ldb, char *resultC, const std::int64_t ldc, + bool is_row_major, const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -91,7 +93,25 @@ 
static sycl::event gemm_impl(sycl::queue &exec_q, sycl::event gemm_event; try { - gemm_event = mkl_blas::row_major::gemm( + auto gemm_func = + [&](sycl::queue &q, oneapi::mkl::transpose transA, + oneapi::mkl::transpose transB, std::int64_t m, std::int64_t n, + std::int64_t k, Tab alpha, const Tab *a, std::int64_t lda, + const Tab *b, std::int64_t ldb, Tab beta, Tc *c, + std::int64_t ldc, + const std::vector &deps) -> sycl::event { + if (is_row_major) { + return mkl_blas::row_major::gemm(q, transA, transB, m, n, k, + alpha, a, lda, b, ldb, beta, c, + ldc, deps); + } + else { + return mkl_blas::column_major::gemm(q, transA, transB, m, n, k, + alpha, a, lda, b, ldb, beta, + c, ldc, deps); + } + }; + gemm_event = gemm_func( exec_q, transA, // Defines the transpose operation for matrix A: // 'N' indicates no transpose, 'T' for transpose, @@ -130,7 +150,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q, return gemm_event; } -std::pair +std::tuple gemm(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, @@ -208,16 +228,44 @@ std::pair throw py::value_error( "Result array is not c-contiguous nor f-contiguous."); } - oneapi::mkl::transpose transA = is_matrixA_f_contig - ? oneapi::mkl::transpose::T - : oneapi::mkl::transpose::N; - oneapi::mkl::transpose transB = is_matrixB_f_contig - ? oneapi::mkl::transpose::T - : oneapi::mkl::transpose::N; + bool is_row_major = true; + if (is_matrixA_f_contig && is_matrixB_f_contig) { + is_row_major = false; + } + oneapi::mkl::transpose transA; + oneapi::mkl::transpose transB; + if (is_row_major) { + transA = is_matrixA_f_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + transB = is_matrixB_f_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + } + else { + transA = oneapi::mkl::transpose::N; + transB = oneapi::mkl::transpose::N; + } - const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m; - const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k; - const std::int64_t ldc = n; // always n for row_major + std::int64_t lda; + std::int64_t ldb; + if (is_row_major) { + if (transA == oneapi::mkl::transpose::N) { + lda = k; + } + else { + lda = m; + } + if (transB == oneapi::mkl::transpose::N) { + ldb = n; + } + else { + ldb = k; + } + } + else { + lda = m; + ldb = k; + } + const std::int64_t ldc = is_row_major ? 
n : m; int matrixA_typenum = matrixA.get_typenum(); int matrixB_typenum = matrixB.get_typenum(); @@ -242,14 +290,14 @@ std::pair char *b_typeless_ptr = matrixB.get_data(); char *r_typeless_ptr = resultC.get_data(); - sycl::event gemm_ev = - gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda, - b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends); + sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k, + a_typeless_ptr, lda, b_typeless_ptr, ldb, + r_typeless_ptr, ldc, is_row_major, depends); sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, matrixB, resultC}, {gemm_ev}); - return std::make_pair(args_ev, gemm_ev); + return std::make_tuple(args_ev, gemm_ev, is_row_major); } template diff --git a/dpnp/backend/extensions/blas/gemm.hpp b/dpnp/backend/extensions/blas/gemm.hpp index cd93494ce035..6e3a58402698 100644 --- a/dpnp/backend/extensions/blas/gemm.hpp +++ b/dpnp/backend/extensions/blas/gemm.hpp @@ -38,22 +38,18 @@ namespace ext { namespace blas { -extern std::pair +extern std::tuple gemm(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, const std::vector &depends); -extern std::pair +extern std::tuple gemm_batch(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, - const std::int64_t batch_size, - size_t stridea, - size_t strideb, - size_t stridec, const std::vector &depends); extern void init_gemm_dispatch_table(void); diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp index 5c81ea41c972..6a1247c4c3e6 100644 --- a/dpnp/backend/extensions/blas/gemm_batch.cpp +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -64,6 +64,7 @@ typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( char *, char *, char *, + bool, const std::vector &); static gemm_batch_impl_fn_ptr_t @@ -77,7 +78,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, const std::int64_t batch_size, const std::int64_t lda, const std::int64_t ldb, - const std::int64_t ld_result, + const std::int64_t ldc, size_t stridea, size_t strideb, size_t stridec, @@ -86,6 +87,7 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, char *matrixA, char *matrixB, char *resultC, + bool is_row_major, const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -100,7 +102,26 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, sycl::event gemm_batch_event; try { - gemm_batch_event = mkl_blas::row_major::gemm_batch( + auto gemm_batch_func = + [&](sycl::queue &q, oneapi::mkl::transpose transA, + oneapi::mkl::transpose transB, std::int64_t m, std::int64_t n, + std::int64_t k, Tab alpha, const Tab *a, std::int64_t lda, + std::int64_t stridea, const Tab *b, std::int64_t ldb, + std::int64_t strideb, Tab beta, Tc *c, std::int64_t ldc, + std::int64_t stridec, std::int64_t batch_size, + const std::vector &deps) -> sycl::event { + if (is_row_major) { + return mkl_blas::row_major::gemm_batch( + q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, + strideb, beta, c, ldc, stridec, batch_size, deps); + } + else { + return mkl_blas::column_major::gemm_batch( + q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb, + strideb, beta, c, ldc, stridec, batch_size, deps); + } + }; + gemm_batch_event = gemm_batch_func( exec_q, transA, // Defines the transpose operation for matrix A: // 'N' indicates no transpose, 'T' for transpose, @@ -120,10 +141,10 @@ static sycl::event 
gemm_batch_impl(sycl::queue &exec_q, strideb, // Stride between different B matrices. Tab(0), // Scaling factor for matrix C. res, // Pointer to matrix C, where the result is stored. - ld_result, // Leading dimension of matrix C. + ldc, // Leading dimension of matrix C. stridec, // Stride between different C matrices. - batch_size, // Specifies the number of matrix multiply operations to - // perform. + batch_size, // Specifies the number of matrix multiply + // operations to perform. depends); } catch (oneapi::mkl::exception const &e) { error_msg << "Unexpected MKL exception caught during gemm_batch() " @@ -145,15 +166,11 @@ static sycl::event gemm_batch_impl(sycl::queue &exec_q, return gemm_batch_event; } -std::pair +std::tuple gemm_batch(sycl::queue &exec_q, dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, - const std::int64_t batch_size, - size_t stridea, - size_t strideb, - size_t stridec, const std::vector &depends = {}) { const int matrixA_nd = matrixA.get_ndim(); @@ -185,49 +202,90 @@ std::pair const py::ssize_t *a_shape = matrixA.get_shape_raw(); const py::ssize_t *b_shape = matrixB.get_shape_raw(); const py::ssize_t *c_shape = resultC.get_shape_raw(); - const std::int64_t m = a_shape[matrixA_nd - 2]; - const std::int64_t n = b_shape[matrixB_nd - 1]; - const std::int64_t k = a_shape[matrixA_nd - 1]; - if (a_shape[matrixA_nd - 1] != b_shape[matrixB_nd - 2]) { + const std::int64_t m = a_shape[1]; + const std::int64_t n = b_shape[2]; + const std::int64_t k = a_shape[2]; + const std::int64_t batch_size = c_shape[0]; + if (a_shape[2] != b_shape[1]) { throw py::value_error("The number of columns in A must be equal to " "the number of rows in B."); } - if (a_shape[matrixA_nd - 2] != c_shape[resultC_nd - 2]) { + if (a_shape[1] != c_shape[1]) { throw py::value_error("The number of rows in A must be equal to " "the number of rows in result array."); } - if (b_shape[matrixB_nd - 1] != c_shape[resultC_nd - 1]) { + if (b_shape[2] != c_shape[2]) { throw py::value_error("The number of columns in B must be equal to " "the number of columns in result array."); } - bool shapes_equal = true; - size_t src_nelems = 1; - py::ssize_t lead_dim; - for (int i = 0; i < matrixA_nd - 2; ++i) { - if (a_shape[i] == b_shape[i]) { - lead_dim = a_shape[i]; - } - else if (a_shape[i] == 1 || b_shape[i] == 1) { - lead_dim = std::max(a_shape[i], b_shape[i]); - } - else { - throw py::value_error("Array shapes do not match."); - } - src_nelems *= static_cast(lead_dim); - shapes_equal = shapes_equal && (lead_dim == c_shape[i]); + std::int64_t first_dim; + if (a_shape[0] == b_shape[0]) { + first_dim = a_shape[0]; + } + else if (a_shape[0] == 1 || b_shape[0] == 1) { + first_dim = std::max(a_shape[0], b_shape[0]); } - src_nelems *= (m * n); - if (!shapes_equal) { + else { throw py::value_error("Array shapes do not match."); } + if (first_dim != c_shape[0]) { + throw py::value_error("Array shapes do not match."); + } + std::int64_t src_nelems = first_dim * m * n; dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC); dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC, src_nelems); - // transA and transB are always False - oneapi::mkl::transpose transA = oneapi::mkl::transpose::N; - oneapi::mkl::transpose transB = oneapi::mkl::transpose::N; + std::vector a_stride = matrixA.get_strides_vector(); + std::vector b_stride = matrixB.get_strides_vector(); + std::vector c_stride = resultC.get_strides_vector(); + const std::int64_t stridea = 
a_stride[0]; + const std::int64_t strideb = b_stride[0]; + const std::int64_t stridec = c_stride[0]; + + bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1]; + bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1]; + + bool is_row_major = true; + if (A_base_is_f_contig && B_base_is_f_contig) { + is_row_major = false; + } + + oneapi::mkl::transpose transA; + oneapi::mkl::transpose transB; + if (is_row_major) { + transA = A_base_is_f_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + transB = B_base_is_f_contig ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + } + else { + transA = oneapi::mkl::transpose::N; + transB = oneapi::mkl::transpose::N; + } + + std::int64_t lda; + std::int64_t ldb; + if (is_row_major) { + if (transA == oneapi::mkl::transpose::N) { + lda = k; + } + else { + lda = m; + } + if (transB == oneapi::mkl::transpose::N) { + ldb = n; + } + else { + ldb = k; + } + } + else { + lda = m; + ldb = k; + } + const std::int64_t ldc = is_row_major ? n : m; int matrixA_typenum = matrixA.get_typenum(); int matrixB_typenum = matrixB.get_typenum(); @@ -252,15 +310,15 @@ std::pair char *b_typeless_ptr = matrixB.get_data(); char *r_typeless_ptr = resultC.get_data(); - // Note that lda = k, ldb = n, and ld_result = n - sycl::event gemm_batch_ev = gemm_batch_fn( - exec_q, m, n, k, batch_size, k, n, n, stridea, strideb, stridec, transA, - transB, a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, depends); + sycl::event gemm_batch_ev = + gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea, + strideb, stridec, transA, transB, a_typeless_ptr, + b_typeless_ptr, r_typeless_ptr, is_row_major, depends); sycl::event args_batch_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev}); - return std::make_pair(args_batch_ev, gemm_batch_ev); + return std::make_tuple(args_batch_ev, gemm_batch_ev, is_row_major); } template diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 92ea44e6c925..ed28d8b6785b 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -108,40 +108,60 @@ def _chr(label): return chr(label) -def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): +def _compute_res_dtype(*arrays, dtype, casting, sycl_queue): """ - Create the result array. + Determines the output array data type and an intermediate data type + used in performing calculations related to a specific math function. + If dtype is ``None``, the output array data type of the operation is + determined based on the Promotion Type Rule and device capabilities. + Otherwise, `dtype` is used as output array dtype, if input arrays + can cast to it according to the casting rule determined. If casting + cannot be done, a ``TypeError`` is raised. + The intermediate data type is the data type used for performing the math + function calculations. If output array dtype is a floating-point data type, + it is also used for the intermediate data type. If output array dtype is an + integral data type, the default floating point data type of the device where + input arrays are allocated on are used for intermediate data type. + + Parameters + ---------- + arrays : {dpnp.ndarray, usm_ndarray} + Input arrays. + dtype : dtype + If not ``None``, data type of the output array. + casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional + Controls what kind of data casting may occur. 
+ sycl_queue : {SyclQueue} + A SYCL queue to use for determining default floating point datat type. + + Returns + ------- + compute_dtype, res_dtype : + `compute_dtype` is the data type used in performing math function calculations. + The input arrays of the math function are cast to `compute_dtype` and then + the calculations are performed. + `res_dtype` is the output data type. When the result is obtained, it is cast + to `res_dtype`. - If `out` is not ``None`` and its features match the specified `shape`, `dtype, - `usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and - does not have any memory overlap with `x1` and `x2`, `out` itself is returned. - If these conditions are not satisfied, an empty array is returned with the - specified `shape`, `dtype, `usm_type`, and `sycl_queue`. """ - if out is not None: - x1_usm = dpnp.get_usm_ndarray(x1) - x2_usm = dpnp.get_usm_ndarray(x2) - out_usm = dpnp.get_usm_ndarray(out) + res_dtype = dpnp.result_type(*arrays) + default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) - if ( - out.dtype == dtype - and out.shape == shape - and out.usm_type == usm_type - and out.sycl_queue == sycl_queue - and out.flags.c_contiguous - and not ti._array_overlap(x1_usm, out_usm) - and not ti._array_overlap(x2_usm, out_usm) - ): - return out + if dtype is not None: + if dpnp.can_cast(res_dtype, dtype, casting=casting): + res_dtype = dtype + else: + raise TypeError( + f"Cannot cast from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" + ) - return dpnp.empty( - shape, - dtype=dtype, - usm_type=usm_type, - sycl_queue=sycl_queue, + compute_dtype = ( + res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype ) + return compute_dtype, res_dtype + def _compute_size_by_dict(indices, idx_dict): """ @@ -198,20 +218,20 @@ def _compute_size(start, shape): return ret -def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): +def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None): """ Creating a copy of input array if needed. - If `contig_copy` is ``True``, a C-contiguous copy of input array is returned. + If `copy_flag` is ``True``, a C-contiguous copy of input array is returned. In this case, the copy array has the input array data type unless `dtype` is determined. - If `contig_copy` is ``False`` and input array data type is different than `dtype`, + If `copy_flag` is ``False`` and input array data type is different than `dtype`, a C-contiguous copy of input array with specified `dtype` is returned. """ - if contig_copy: - copy = contig_copy + if copy_flag: + copy = copy_flag else: copy = x.dtype != dtype if dtype is not None else False @@ -228,6 +248,62 @@ def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): return x +def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): + """ + Create the result array. + + If `out` is not ``None`` and its features match the specified `shape`, `dtype, + `usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and + does not have any memory overlap with `x1` and `x2`, `out` itself is returned. + If these conditions are not satisfied, an empty array is returned with the + specified `shape`, `dtype, `usm_type`, and `sycl_queue`. 
+ """ + + if out is not None: + x1_usm = dpnp.get_usm_ndarray(x1) + x2_usm = dpnp.get_usm_ndarray(x2) + out_usm = dpnp.get_usm_ndarray(out) + contig_flag = _define_contig_flag(out) + + if ( + out.dtype == dtype + and out.shape == shape + and out.usm_type == usm_type + and out.sycl_queue == sycl_queue + and contig_flag + and not ti._array_overlap(x1_usm, out_usm) + and not ti._array_overlap(x2_usm, out_usm) + ): + return out + + return dpnp.empty( + shape, + dtype=dtype, + usm_type=usm_type, + sycl_queue=sycl_queue, + ) + + +def _define_contig_flag(x): + """ + Determines if the data in last two dimensions of array `x` are + c_contiguous or f_contiguous. For 2D arrays, it is the same as using + x.flags.c_contiguous or x.flags.f_contiguous. + """ + + flag = False + x_strides = x.strides + x_shape = x.shape + if x.ndim < 2: + return True + + x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1] + x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2] + if x_is_c_contiguous or x_is_f_contiguous: + flag = True + return flag + + def _einsum_diagonals(input_subscripts, operands): """ Adopted from _einsum_diagonals in cupy/core/_einsum.py @@ -489,54 +565,53 @@ def _flop_count(idx_contraction, inner, num_terms, size_dictionary): return overall_size * op_factor -def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list): - # If input array is F-contiguous, we need to change the order to C-contiguous. - # because mkl::gemm_bacth needs each 2D array to be F-contiguous but - # when the input array is F-contiguous, the data of 2D array - # that needs to be called in mkl::gemm_batch are not contiguous. +def _gemm_batch_matmul(exec_q, x1, x2, res, dev_tasks_list): + # arrays here are already at least 3D, make them 3D + x1 = x1.reshape(-1, x1.shape[-2], x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-2], x2.shape[-1]) + orig_shape = res.shape + res = res.reshape(-1, res.shape[-2], res.shape[-1]) + ht_tasks_list = [] - contig_copy = not x1.flags.c_contiguous - x1 = _copy_array(x1, dev_tasks_list, ht_tasks_list, contig_copy=contig_copy) - contig_copy = not x2.flags.c_contiguous - x2 = _copy_array(x2, dev_tasks_list, ht_tasks_list, contig_copy=contig_copy) - - x1_strides = x1.strides - x2_strides = x2.strides - res_strides = res.strides - - # need to standardize to use in ti._contract_iter2 - x1_strides = _standardize_strides(x1_strides, x1_is_2D, x1.shape, x1.ndim) - x2_strides = _standardize_strides(x2_strides, x2_is_2D, x2.shape, x2.ndim) - - batch_size = res.shape[:-2][0] - stridea = x1_strides[0] - strideb = x2_strides[0] - stridec = res_strides[-3] - - if x1.ndim > 3: - iter = ti._contract_iter2( - res.shape[:-2], x1_strides[:-2], x2_strides[:-2] + # gemm_batch does not handle negative strides, make a copy if needed + x1 = _copy_array( + x1, dev_tasks_list, ht_tasks_list, copy_flag=x1.strides[0] < 0 + ) + x2 = _copy_array( + x2, dev_tasks_list, ht_tasks_list, copy_flag=x2.strides[0] < 0 + ) + res = _copy_array( + res, dev_tasks_list, ht_tasks_list, copy_flag=res.strides[0] < 0 + ) + # onemkl::blas::gemm_bacth throws an exception (Provided range is out + # of integer limits) if the batch_size is too large (>=4096*4096), so + # we need to split the batch into smaller chunks + chunk = 2048 * 2048 + batch_size = res.shape[0] + for i in range(0, batch_size, chunk): + x1_usm = dpnp.get_usm_ndarray(x1[i : i + chunk, ...]) + x2_usm = dpnp.get_usm_ndarray(x2[i : i + chunk, ...]) + res_usm = dpnp.get_usm_ndarray(res[i : i + chunk, ...]) + ht_blas_ev, _, 
row_major = bi._gemm_batch( + exec_q, + x1_usm, + x2_usm, + res_usm, + dev_tasks_list, ) + ht_tasks_list.append(ht_blas_ev) + dpctl.SyclEvent.wait_for(ht_tasks_list) + res_shape = res.shape + if not row_major: + res = dpnp.reshape( + res.ravel(), (batch_size, res_shape[2], res_shape[1]) + ).transpose(0, 2, 1) - if len(iter[0]) != 1: - raise ValueError("Input arrays cannot be used in gemm_batch") - batch_size = iter[0][0] - stridea = iter[1][0] - strideb = iter[3][0] - - ht_blas_ev, _ = bi._gemm_batch( - exec_q, - dpnp.get_usm_ndarray(x1), - dpnp.get_usm_ndarray(x2), - dpnp.get_usm_ndarray(res), - batch_size, - stridea, - strideb, - stridec, - dev_tasks_list, - ) + if res_shape != orig_shape: + res = res.reshape(orig_shape) - return ht_blas_ev, ht_tasks_list, res + res = dpnp.ascontiguousarray(res) + return res def _greedy_path(input_sets, output_set, idx_dict, memory_limit): @@ -657,34 +732,6 @@ def _greedy_path(input_sets, output_set, idx_dict, memory_limit): return path -def _iter_path_pairs(path): - """ - Copied from _iter_path_pairs in cupy/core/_einsum.py - - Decompose path into binary path - - Parameters - ---------- - path : sequence of tuples of ints - - Yields - ------ - tuple of ints - pair (idx0, idx1) that represents the operation - {pop(idx0); pop(idx1); append();} - - """ - - for indices in path: - assert all(idx >= 0 for idx in indices) - # [3, 1, 4, 9] -> [(9, 4), (-1, 3), (-1, 1)] - if len(indices) >= 2: - indices = sorted(indices, reverse=True) - yield indices[0], indices[1] - for idx in indices[2:]: - yield -1, idx - - def _index_linear_to_tuple(shape, linear_id): """ Convert a linear index to a tuple of indices in a multi-dimensional array. @@ -713,6 +760,34 @@ def _index_linear_to_tuple(shape, linear_id): return tuple(indices) +def _iter_path_pairs(path): + """ + Copied from _iter_path_pairs in cupy/core/_einsum.py + + Decompose path into binary path + + Parameters + ---------- + path : sequence of tuples of ints + + Yields + ------ + tuple of ints + pair (idx0, idx1) that represents the operation + {pop(idx0); pop(idx1); append();} + + """ + + for indices in path: + assert all(idx >= 0 for idx in indices) + # [3, 1, 4, 9] -> [(9, 4), (-1, 3), (-1, 1)] + if len(indices) >= 2: + indices = sorted(indices, reverse=True) + yield indices[0], indices[1] + for idx in indices[2:]: + yield -1, idx + + def _make_transpose_axes(sub, b_dims, c_dims): """Copied from _make_transpose_axes in cupy/core/_einsum.py""" bs = [] @@ -732,63 +807,6 @@ def _make_transpose_axes(sub, b_dims, c_dims): ) -def _op_res_dtype(*arrays, dtype, casting, sycl_queue): - """ - _op_res_dtype(*arrays, dtype, casting, sycl_queue) - - Determines the output array data type and an intermediate data type - used in performing calculations related to a specific math function. - If dtype is ``None``, the output array data type of the operation is - determined based on the Promotion Type Rule and device capabilities. - Otherwise, `dtype` is used as output array dtype, if input arrays - can cast to it according to the casting rule determined. If casting - cannot be done, a ``TypeError`` is raised. - The intermediate data type is the data type used for performing the math - function calculations. If output array dtype is a floating-point data type, - it is also used for the intermediate data type. If output array dtype is an - integral data type, the default floating point data type of the device where - input arrays are allocated on are used for intermediate data type. 
- - Parameters - ---------- - arrays : {dpnp.ndarray, usm_ndarray} - Input arrays. - dtype : dtype - If not ``None``, data type of the output array. - casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional - Controls what kind of data casting may occur. - sycl_queue : {SyclQueue} - A SYCL queue to use for determining default floating point datat type. - - Returns - ------- - op_dtype, res_dtype : - `op_dtype` is the data type used in performing math function calculations. - The input arrays of the math function are cast to `op_dtype` and then - the calculations are performed. - `res_dtype` is the output data type. When the result is obtained, it is cast - to `res_dtype`. - - """ - - res_dtype = dpnp.result_type(*arrays) - default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) - - if dtype is not None: - if dpnp.can_cast(res_dtype, dtype, casting=casting): - res_dtype = dtype - else: - raise TypeError( - f"Cannot cast from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" - ) - - op_dtype = ( - res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype - ) - - return op_dtype, res_dtype - - def _optimal_path(input_sets, output_set, idx_dict, memory_limit): """ Copied from _optimal_path in numpy/core/einsumfunc.py @@ -1119,31 +1137,6 @@ def _parse_possible_contraction( return [sort, positions, new_input_sets] -def _shape_error(a, b, core_dim, err_msg): - if err_msg == 0: - raise ValueError( - "Input arrays have a mismatch in their core dimensions. " - "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " - f"(size {a} is different from {b})" - ) - elif err_msg == 1: - raise ValueError( - f"Output array has a mismatch in its core dimension {core_dim}. " - "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " - f"(size {a} is different from {b})" - ) - elif err_msg == 2: - raise ValueError( - "Input arrays could not be broadcast together with remapped shapes, " - f"{a} is different from {b}." - ) - elif err_msg == 3: - raise ValueError( - "Output array could not be broadcast to input arrays with remapped shapes, " - f"{a} is different from {b}." - ) - - def _reduced_binary_einsum(arr0, sub0, arr1, sub1, sub_others): """Copied from _reduced_binary_einsum in cupy/core/_einsum.py""" @@ -1188,40 +1181,29 @@ def _reduced_binary_einsum(arr0, sub0, arr1, sub1, sub_others): return arr_out, sub_out -def _standardize_strides(strides, inherently_2D, shape, ndim): - """ - Standardizing the strides. - - When shape of an array along any particular dimension is 1, the stride - along that dimension is undefined. This functions standardize the strides - in the following way: - For N-D arrays that are inherently 2D (all dimesnsion are one except for two of them), - we use zero as the stride for dimensions equal one. - For other N-D arrays, the non-zero value of strides is calculated and used. - - """ - - if inherently_2D: - stndrd_strides = tuple( - str_i if sh_i > 1 else 0 for sh_i, str_i in zip(shape, strides) +def _shape_error(a, b, core_dim, err_msg): + if err_msg == 0: + raise ValueError( + "Input arrays have a mismatch in their core dimensions. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) 
" + f"(size {a} is different from {b})" ) - else: - stndrd_strides = [ - numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i] - for i in range(ndim - 1) - ] - # last dimension - stndrd_strides.append( - 1 if strides[ndim - 1] == 0 else strides[ndim - 1] + elif err_msg == 1: + raise ValueError( + f"Output array has a mismatch in its core dimension {core_dim}. " + "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) " + f"(size {a} is different from {b})" + ) + elif err_msg == 2: + raise ValueError( + "Input arrays could not be broadcast together with remapped shapes, " + f"{a} is different from {b}." + ) + elif err_msg == 3: + raise ValueError( + "Output array could not be broadcast to input arrays with remapped shapes, " + f"{a} is different from {b}." ) - stndrd_strides = tuple(stndrd_strides) - - return stndrd_strides - - -def _tuple_sorted_by_0(zs): - """Copied from _tuple_sorted_by_0 in cupy/core/_einsum.py""" - return tuple(i for _, i in sorted(zs)) def _transpose_ex(a, axeses): @@ -1256,6 +1238,11 @@ def _transpose_ex(a, axeses): return a +def _tuple_sorted_by_0(zs): + """Copied from _tuple_sorted_by_0 in cupy/core/_einsum.py""" + return tuple(i for _, i in sorted(zs)) + + def _update_other_results(results, best): """ Copied from _update_other_results in numpy/core/einsumfunc.py @@ -1541,6 +1528,79 @@ def dpnp_cross(a, b, cp, exec_q): return cp +def dpnp_dot(a, b, /, out=None, *, conjugate=False): + """ + Return the dot product of two arrays. + + The routine that is used to perform the main calculation + depends on input arrays data type: 1) For integer and boolean data types, + `dpctl.tensor.vecdot` form the Data Parallel Control library is used, + 2) For real-valued floating point data types, `dot` routines from + BLAS library of OneMKL are used, and 3) For complex data types, + `dotu` or `dotc` routines from BLAS library of OneMKL are used. + If `conjugate` is ``False``, `dotu` is used. Otherwise, `dotc` is used, + for which the first array is conjugated before calculating the dot product. + + """ + + if a.size != b.size: + raise ValueError( + "Input arrays have a mismatch in their size. 
" + f"(size {a.size} is different from {b.size})" + ) + + res_usm_type, exec_q = get_usm_allocations([a, b]) + + # Determine the appropriate data types + # casting is irrelevant here since dtype is `None` + dot_dtype, res_dtype = _compute_res_dtype( + a, b, dtype=None, casting="no", sycl_queue=exec_q + ) + + result = _create_result_array( + a, b, out, (), dot_dtype, res_usm_type, exec_q + ) + # input arrays should have the proper data type + dep_events_list = [] + host_tasks_list = [] + if dpnp.issubdtype(res_dtype, dpnp.inexact): + # copying is needed if dtypes of input arrays are different + a = _copy_array(a, dep_events_list, host_tasks_list, dtype=dot_dtype) + b = _copy_array(b, dep_events_list, host_tasks_list, dtype=dot_dtype) + if dpnp.issubdtype(res_dtype, dpnp.complexfloating): + if conjugate: + dot_func = "_dotc" + else: + dot_func = "_dotu" + ht_ev, _ = getattr(bi, dot_func)( + exec_q, + dpnp.get_usm_ndarray(a), + dpnp.get_usm_ndarray(b), + dpnp.get_usm_ndarray(result), + dep_events_list, + ) + else: + ht_ev, _ = bi._dot( + exec_q, + dpnp.get_usm_ndarray(a), + dpnp.get_usm_ndarray(b), + dpnp.get_usm_ndarray(result), + dep_events_list, + ) + host_tasks_list.append(ht_ev) + dpctl.SyclEvent.wait_for(host_tasks_list) + else: + dpt_a = dpnp.get_usm_ndarray(a) + dpt_b = dpnp.get_usm_ndarray(b) + result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b)) + + if dot_dtype != res_dtype: + result = result.astype(res_dtype, copy=False) + + # numpy.dot does not allow casting even if it is safe + return dpnp.get_result_array(result, out, casting="no") + + def dpnp_einsum( *operands, out=None, dtype=None, order="K", casting="safe", optimize=False ): @@ -1780,79 +1840,6 @@ def dpnp_kron(a, b, a_ndim, b_ndim): return result.reshape(tuple(numpy.multiply(a_shape, b_shape))) -def dpnp_dot(a, b, /, out=None, *, conjugate=False): - """ - Return the dot product of two arrays. - - The routine that is used to perform the main calculation - depends on input arrays data type: 1) For integer and boolean data types, - `dpctl.tensor.vecdot` form the Data Parallel Control library is used, - 2) For real-valued floating point data types, `dot` routines from - BLAS library of OneMKL are used, and 3) For complex data types, - `dotu` or `dotc` routines from BLAS library of OneMKL are used. - If `conjugate` is ``False``, `dotu` is used. Otherwise, `dotc` is used, - for which the first array is conjugated before calculating the dot product. - - """ - - if a.size != b.size: - raise ValueError( - "Input arrays have a mismatch in their size. 
" - f"(size {a.size} is different from {b.size})" - ) - - res_usm_type, exec_q = get_usm_allocations([a, b]) - - # Determine the appropriate data types - # casting is irrelevant here since dtype is `None` - dot_dtype, res_dtype = _op_res_dtype( - a, b, dtype=None, casting="no", sycl_queue=exec_q - ) - - result = _create_result_array( - a, b, out, (), dot_dtype, res_usm_type, exec_q - ) - # input arrays should have the proper data type - dep_events_list = [] - host_tasks_list = [] - if dpnp.issubdtype(res_dtype, dpnp.inexact): - # copying is needed if dtypes of input arrays are different - a = _copy_array(a, dep_events_list, host_tasks_list, dtype=dot_dtype) - b = _copy_array(b, dep_events_list, host_tasks_list, dtype=dot_dtype) - if dpnp.issubdtype(res_dtype, dpnp.complexfloating): - if conjugate: - dot_func = "_dotc" - else: - dot_func = "_dotu" - ht_ev, _ = getattr(bi, dot_func)( - exec_q, - dpnp.get_usm_ndarray(a), - dpnp.get_usm_ndarray(b), - dpnp.get_usm_ndarray(result), - dep_events_list, - ) - else: - ht_ev, _ = bi._dot( - exec_q, - dpnp.get_usm_ndarray(a), - dpnp.get_usm_ndarray(b), - dpnp.get_usm_ndarray(result), - dep_events_list, - ) - host_tasks_list.append(ht_ev) - dpctl.SyclEvent.wait_for(host_tasks_list) - else: - dpt_a = dpnp.get_usm_ndarray(a) - dpt_b = dpnp.get_usm_ndarray(b) - result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b)) - - if dot_dtype != res_dtype: - result = result.astype(res_dtype, copy=False) - - # numpy.dot does not allow casting even if it is safe - return dpnp.get_result_array(result, out, casting="no") - - def dpnp_matmul( x1, x2, @@ -1939,7 +1926,7 @@ def dpnp_matmul( _shape_error(out_shape[-1], x2_shape[-1], 0, 1) # Determine the appropriate data types - gemm_dtype, res_dtype = _op_res_dtype( + gemm_dtype, res_dtype = _compute_res_dtype( x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q ) @@ -2010,24 +1997,24 @@ def dpnp_matmul( elif x1.size == 0 or x2.size == 0: result.fill(0) else: - # input arrays should have the proper data type - # and be C_CONTIGUOUS or F_CONTIGUOUS + # input arrays should have the proper data type and + # their base (last 2-dimensions) to be c-contiguous or f-contiguous dep_events_list = [] host_tasks_list = [] - contig_copy = not (x1.flags.c_contiguous or x1.flags.f_contiguous) + contig_flag = _define_contig_flag(x1) x1 = _copy_array( x1, dep_events_list, host_tasks_list, - contig_copy=contig_copy, + copy_flag=not contig_flag, dtype=gemm_dtype, ) - contig_copy = not (x2.flags.c_contiguous or x2.flags.f_contiguous) + contig_flag = _define_contig_flag(x2) x2 = _copy_array( x2, dep_events_list, host_tasks_list, - contig_copy=contig_copy, + copy_flag=not contig_flag, dtype=gemm_dtype, ) @@ -2036,33 +2023,32 @@ def dpnp_matmul( # gain performance. # TODO: investigate usage of syrk function from BLAS in # case of a.T @ a and a @ a.T to gain performance. 
+ row_major = True if x1_is_2D and x2_is_2D: - ht_blas_ev, _ = bi._gemm( + ht_blas_ev, _, row_major = bi._gemm( exec_q, dpnp.get_usm_ndarray(x1), dpnp.get_usm_ndarray(x2), dpnp.get_usm_ndarray(result), dep_events_list, ) + host_tasks_list.append(ht_blas_ev) else: - ( - ht_blas_ev, - ht_copy_ev, - result, - ) = _gemm_batch_matmul( + result = _gemm_batch_matmul( exec_q, x1, x2, result, - x1_is_2D, - x2_is_2D, dep_events_list, ) - host_tasks_list += ht_copy_ev - host_tasks_list.append(ht_blas_ev) dpctl.SyclEvent.wait_for(host_tasks_list) - + if not row_major: + # TODO: investigate the possibility of defining result + # array with "F" order for this case + result = dpnp.ascontiguousarray( + dpnp.reshape(result.ravel(), result.shape, order="F") + ) if appended_axes: result = dpnp.squeeze(result, tuple(appended_axes)) if len(appended_axes) == 2 and out is not None: @@ -2093,6 +2079,10 @@ def dpnp_matmul( else: return result else: + # TODO: There is an opportunity to improve performance when the out keyword + # is present. For some cases, out is NOT result but they have the same + # base (they are views of the same data). In this case, we can avoid + # copying result to out. result = dpnp.get_result_array(result, out, casting=casting) if axes is not None and out is result: # out and out_orig contain the same data but they have different shape diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index dc2dfbd92fe5..854fa0c7f6a9 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2578,9 +2578,14 @@ def setup_method(self): "shape_pair", [ ((4,), (4,)), + ((1, 4), (4, 1)), ((4,), (4, 2)), + ((1, 4), (4, 2)), ((2, 4), (4,)), + ((2, 4), (4, 1)), ((1, 4), (4,)), # output should be 1-d not 0-d + ((4,), (4, 1)), + ((1, 4), (4, 1)), ((2, 4), (4, 3)), ((1, 2, 3), (1, 3, 5)), ((4, 2, 3), (4, 3, 5)), @@ -2605,11 +2610,15 @@ def setup_method(self): ((1, 5, 3, 2), (6, 5, 2, 4)), ((5, 3, 2), (6, 5, 2, 4)), ((1, 3, 3), (10, 1, 3, 1)), + ((2, 3, 3), (10, 1, 3, 1)), + ((10, 2, 3, 3), (10, 1, 3, 1)), ], ) def test_matmul(self, order_pair, shape_pair): order1, order2 = order_pair shape1, shape2 = shape_pair + # input should be float type, otherwise they are copied to c-contiguous arrays + # so testing order becomes meaningless dtype = dpnp.default_float_type() a1 = numpy.arange(numpy.prod(shape1), dtype=dtype).reshape(shape1) a2 = numpy.arange(numpy.prod(shape2), dtype=dtype).reshape(shape2) @@ -2652,8 +2661,9 @@ def test_matmul(self, order_pair, shape_pair): def test_matmul_empty(self, order_pair, shape_pair): order1, order2 = order_pair shape1, shape2 = shape_pair - a1 = numpy.arange(numpy.prod(shape1)).reshape(shape1) - a2 = numpy.arange(numpy.prod(shape2)).reshape(shape2) + dtype = dpnp.default_float_type() + a1 = numpy.arange(numpy.prod(shape1), dtype=dtype).reshape(shape1) + a2 = numpy.arange(numpy.prod(shape2), dtype=dtype).reshape(shape2) a1 = numpy.array(a1, order=order1) a2 = numpy.array(a2, order=order2) @@ -2928,18 +2938,84 @@ def test_matmul_order(self, order, shape_pair): [(-2, -2, -2, -2), (2, 2, 2, 2), (-2, 2, -2, 2), (2, -2, 2, -2)], ids=["-2", "2", "(-2, 2)", "(2, -2)"], ) - def test_matmul_strided(self, stride): + def test_matmul_strided1(self, stride): for dim in [1, 2, 3, 4]: - A = numpy.random.rand(*([20] * dim)) - B = dpnp.asarray(A) + shape = tuple(20 for _ in range(dim)) + A = numpy.random.rand(*shape) + A_dp = dpnp.asarray(A) slices = tuple(slice(None, None, stride[i]) for i in range(dim)) a = A[slices] - b = B[slices] - - result = dpnp.matmul(b, b) + a_dp = 
A_dp[slices] + # input arrays will be copied into c-contiguous arrays + # the 2D base is neither c-contiguous nor f-contiguous + result = dpnp.matmul(a_dp, a_dp) expected = numpy.matmul(a, a) assert_dtype_allclose(result, expected) + OUT = dpnp.empty(shape, dtype=result.dtype) + out = OUT[slices] + result = dpnp.matmul(a_dp, a_dp, out=out) + assert result is out + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "shape", [(10, 3, 3), (12, 10, 3, 3)], ids=["3D", "4D"] + ) + @pytest.mark.parametrize("stride", [-1, -2, 2], ids=["-1", "-2", "2"]) + @pytest.mark.parametrize("transpose", [False, True], ids=["False", "True"]) + def test_matmul_strided2(self, shape, stride, transpose): + # one dimension (-3) is strided + # if negative stride, copy is needed and the base becomes c-contiguous + # otherwise the base remains the same as input in gemm_batch + A = numpy.random.rand(*shape) + A_dp = dpnp.asarray(A) + if transpose: + A = numpy.moveaxis(A, (-2, -1), (-1, -2)) + A_dp = dpnp.moveaxis(A_dp, (-2, -1), (-1, -2)) + index = [slice(None)] * len(shape) + index[-3] = slice(None, None, stride) + index = tuple(index) + a = A[index] + a_dp = A_dp[index] + result = dpnp.matmul(a_dp, a_dp) + expected = numpy.matmul(a, a) + assert_dtype_allclose(result, expected) + + OUT = dpnp.empty(shape, dtype=result.dtype) + out = OUT[index] + result = dpnp.matmul(a_dp, a_dp, out=out) + assert result is out + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "stride", + [(-2, -2), (2, 2), (-2, 2), (2, -2)], + ids=["(-2, -2)", "(2, 2)", "(-2, 2)", "(2, -2)"], + ) + @pytest.mark.parametrize("transpose", [False, True], ids=["False", "True"]) + def test_matmul_strided3(self, stride, transpose): + # 4D case, the 1st and 2nd dimensions are strided + # For negative stride, copy is needed and the base becomes c-contiguous. + # For positive stride, no copy but reshape makes the base c-contiguous. 
+ stride0, stride1 = stride + shape = (12, 10, 3, 3) # 4D array + A = numpy.random.rand(*shape) + A_dp = dpnp.asarray(A) + if transpose: + A = numpy.moveaxis(A, (-2, -1), (-1, -2)) + A_dp = dpnp.moveaxis(A_dp, (-2, -1), (-1, -2)) + a = A[::stride0, ::stride1] + a_dp = A_dp[::stride0, ::stride1] + result = dpnp.matmul(a_dp, a_dp) + expected = numpy.matmul(a, a) + assert_dtype_allclose(result, expected) + + OUT = dpnp.empty(shape, dtype=result.dtype) + out = OUT[::stride0, ::stride1] + result = dpnp.matmul(a_dp, a_dp, out=out) + assert result is out + assert_dtype_allclose(result, expected) + @pytest.mark.parametrize( "dtype", get_all_dtypes(no_none=True, no_bool=True) ) @@ -2975,6 +3051,30 @@ def test_matmul_out_0D(self, out_shape): assert result is dpnp_out assert_dtype_allclose(result, expected) + @pytest.mark.skipif(is_cpu_device(), reason="large size") + @pytest.mark.parametrize( + "shape", + [ + ((4096, 4096, 4, 4)), + ((2048, 2048, 8, 8)), + ], + ) + def test_matmul_large(self, shape): + size = numpy.prod(shape, dtype=int) + a = numpy.array(numpy.random.uniform(-5, 5, size)).reshape(shape) + a_dp = dpnp.asarray(a) + + result = dpnp.matmul(a_dp, a_dp) + expected = numpy.matmul(a, a) + assert_dtype_allclose(result, expected, factor=24) + + # make the 2-d base f-contiguous + a = a.transpose(0, 1, 3, 2) + a_dp = a_dp.transpose(0, 1, 3, 2) + result = dpnp.matmul(a_dp, a_dp) + expected = numpy.matmul(a, a) + assert_dtype_allclose(result, expected, factor=24) + class TestMatmulInvalidCases: @pytest.mark.parametrize( @@ -2996,10 +3096,12 @@ def test_zero_dim(self, shape_pair): @pytest.mark.parametrize( "shape_pair", [ - ((5, 3, 1), (3, 1, 4)), - ((3, 2, 3), (3, 2, 4)), - ((3, 2), (1,)), - ((1, 2), (3, 1)), + ((2, 3), (4, 5)), + ((2, 4), (3, 5)), + ((2, 3), (4,)), + ((3,), (4, 5)), + ((2, 2, 3), (2, 4, 5)), + ((3, 2, 3), (2, 4, 5)), ((4, 3, 2), (6, 5, 2, 4)), ((6, 5, 3, 2), (3, 2, 4)), ], @@ -3020,8 +3122,8 @@ def test_invalid_shape(self, shape_pair): ((5, 4, 3), (3, 1), (5, 4, 2)), ((5, 4, 3), (3,), (5, 3)), ((5, 4, 3), (3,), (6, 4)), - ((3,), (3, 4, 5), (3, 5)), - ((3,), (3, 4, 5), (4, 6)), + ((4,), (3, 4, 5), (4, 5)), + ((4,), (3, 4, 5), (3, 6)), ], ) def test_invalid_shape_out(self, shape_pair):
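For reference, the layout selection introduced in gemm.cpp and gemm_batch.cpp above can be summarized with a small standalone sketch. This is only an illustration of the row_major/column_major choice and the lda/ldb/ldc values derived in the hunks above; the helper name choose_gemm_layout is hypothetical and is not part of this patch.

# Hypothetical helper (illustration only, not part of the patch). It mirrors
# the selection added above: oneMKL's column_major gemm is used only when both
# 2-D operands are F-contiguous; otherwise row_major gemm is used and any
# F-contiguous operand is passed with a transpose flag.
def choose_gemm_layout(m, n, k, a_is_f_contig, b_is_f_contig):
    is_row_major = not (a_is_f_contig and b_is_f_contig)

    if is_row_major:
        transA = "T" if a_is_f_contig else "N"
        transB = "T" if b_is_f_contig else "N"
        lda = k if transA == "N" else m
        ldb = n if transB == "N" else k
    else:
        transA = transB = "N"
        lda, ldb = m, k

    # ldc follows the chosen layout; this is why dpnp_matmul re-interprets the
    # result buffer (reshape/transpose) when the returned row_major flag is False.
    ldc = n if is_row_major else m
    return is_row_major, transA, transB, lda, ldb, ldc


# Example: a C-contiguous (2, 3) A times an F-contiguous (3, 5) B stays on the
# row_major path with B flagged as transposed.
print(choose_gemm_layout(m=2, n=5, k=3, a_is_f_contig=False, b_is_f_contig=True))
# -> (True, 'N', 'T', 3, 3, 5)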