diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt
index 03fceb05482e..683bb34009bd 100644
--- a/dpnp/backend/extensions/lapack/CMakeLists.txt
+++ b/dpnp/backend/extensions/lapack/CMakeLists.txt
@@ -30,6 +30,7 @@ set(_module_src
     ${CMAKE_CURRENT_SOURCE_DIR}/geqrf.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/geqrf_batch.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/gesv_batch.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
diff --git a/dpnp/backend/extensions/lapack/common_helpers.hpp b/dpnp/backend/extensions/lapack/common_helpers.hpp
index df2b87028e48..88e791256dce 100644
--- a/dpnp/backend/extensions/lapack/common_helpers.hpp
+++ b/dpnp/backend/extensions/lapack/common_helpers.hpp
@@ -24,9 +24,11 @@
 //*****************************************************************************
 #pragma once
+#include
+#include
+
 #include
 #include
-#include
 #include
 namespace dpnp::extensions::lapack::helper
@@ -63,4 +65,89 @@ inline bool check_zeros_shape(int ndim, const py::ssize_t *shape)
     }
     return src_nelems == 0;
 }
+
+/**
+ * @brief Allocate device memory for `n` pivot indices.
+ *
+ * @param n       Number of `std::int64_t` elements to allocate.
+ * @param exec_q  Queue whose device/context owns the allocation.
+ * @return Pointer to the device allocation; the caller releases it with
+ *         `sycl::free`.
+ * @throws std::runtime_error if the allocation returns a null pointer or
+ *         if SYCL raises an exception while allocating.
+ */
+inline std::int64_t *alloc_ipiv(const std::int64_t n, sycl::queue &exec_q)
+{
+    std::int64_t *ipiv = nullptr;
+
+    try {
+        ipiv = sycl::malloc_device<std::int64_t>(n, exec_q);
+    } catch (sycl::exception const &e) {
+        // Nothing to release here: the allocation call itself threw, so
+        // `ipiv` is still null
+        throw std::runtime_error(
+            std::string(
+                "Unexpected SYCL exception caught during ipiv allocation: ") +
+            e.what());
+    }
+
+    if (!ipiv) {
+        throw std::runtime_error("Device allocation for ipiv failed");
+    }
+
+    return ipiv;
+}
+
+/**
+ * @brief Allocate one pivot-index buffer shared by all linear streams of a
+ *        batch implementation.
+ *
+ * The total element count `n_linear_streams * n` is rounded up so that the
+ * buffer size is a multiple of 256 bytes.
+ *
+ * @tparam T  Element type of the matrices being processed. It is kept for
+ *            interface symmetry with `alloc_scratchpad_batch`; the padding
+ *            is computed from the pivot element type because the buffer
+ *            stores `std::int64_t` values.
+ * @param n                 Pivot indices needed by a single stream.
+ * @param n_linear_streams  Number of independent linear streams.
+ * @param exec_q            Queue whose device/context owns the allocation.
+ * @return Pointer to the device allocation (see `alloc_ipiv`).
+ * @throws std::runtime_error on allocation failure (see `alloc_ipiv`).
+ */
+template <typename T>
+inline std::int64_t *alloc_ipiv_batch(const std::int64_t n,
+                                      std::int64_t n_linear_streams,
+                                      sycl::queue &exec_q)
+{
+    // Number of pivot elements that make up 256 bytes
+    const std::int64_t padding = 256 / sizeof(std::int64_t);
+
+    // Total number of pivot indices for all linear streams, rounded up to
+    // the padding granularity
+    const size_t alloc_ipiv_size =
+        round_up_mult(n_linear_streams * n, padding);
+
+    return alloc_ipiv(alloc_ipiv_size, exec_q);
+}
+
+/**
+ * @brief Allocate device memory for a scratchpad of `scratchpad_size`
+ *        elements of type `T`.
+ *
+ * @param scratchpad_size  Number of elements; a non-positive value means
+ *                         there is nothing to allocate.
+ * @param exec_q           Queue whose device/context owns the allocation.
+ * @return Pointer to the device allocation, or `nullptr` when
+ *         `scratchpad_size <= 0`. The caller releases it with `sycl::free`.
+ * @throws std::runtime_error if the allocation returns a null pointer or
+ *         if SYCL raises an exception while allocating.
+ */
+template <typename T>
+inline T *alloc_scratchpad(std::int64_t scratchpad_size, sycl::queue &exec_q)
+{
+    if (scratchpad_size <= 0) {
+        return nullptr;
+    }
+
+    T *scratchpad = nullptr;
+
+    try {
+        scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
+    } catch (sycl::exception const &e) {
+        // Nothing to release here: the allocation call itself threw, so
+        // `scratchpad` is still null
+        throw std::runtime_error(std::string("Unexpected SYCL exception caught "
+                                             "during scratchpad allocation: ") +
+                                 e.what());
+    }
+
+    if (!scratchpad) {
+        throw std::runtime_error("Device allocation for scratchpad failed");
+    }
+
+    return scratchpad;
+}
+
+/**
+ * @brief Allocate one scratchpad buffer shared by all linear streams of a
+ *        batch implementation.
+ *
+ * The total element count `n_linear_streams * scratchpad_size` is rounded
+ * up so that the buffer size is a multiple of 256 bytes.
+ *
+ * @param scratchpad_size   Scratchpad elements needed by a single stream.
+ * @param n_linear_streams  Number of independent linear streams.
+ * @param exec_q            Queue whose device/context owns the allocation.
+ * @return Pointer to the device allocation, or `nullptr` when the rounded
+ *         total is not positive (see `alloc_scratchpad`).
+ * @throws std::runtime_error on allocation failure (see `alloc_scratchpad`).
+ */
+template <typename T>
+inline T *alloc_scratchpad_batch(std::int64_t scratchpad_size,
+                                 std::int64_t n_linear_streams,
+                                 sycl::queue &exec_q)
+{
+    // Number of `T` elements that make up 256 bytes
+    const std::int64_t padding = 256 / sizeof(T);
+
+    // Total scratchpad elements for all linear streams, rounded up to the
+    // padding granularity
+    const size_t alloc_scratch_size =
+        round_up_mult(n_linear_streams * scratchpad_size, padding);
+
+    return alloc_scratchpad<T>(alloc_scratch_size, exec_q);
+}
 } // namespace dpnp::extensions::lapack::helper
diff --git a/dpnp/backend/extensions/lapack/evd_batch_common.hpp b/dpnp/backend/extensions/lapack/evd_batch_common.hpp
index 5a7d985ff821..3fa5f98214a4 100644
--- a/dpnp/backend/extensions/lapack/evd_batch_common.hpp
+++ b/dpnp/backend/extensions/lapack/evd_batch_common.hpp
@@ -119,34 +119,4 @@ std::pair
     return std::make_pair(ht_ev,
evd_batch_ev); } - -template -inline T *alloc_scratchpad(std::int64_t scratchpad_size, - std::int64_t n_linear_streams, - sycl::queue &exec_q) -{ - // Get padding size to ensure memory allocations are aligned to 256 bytes - // for better performance - const std::int64_t padding = 256 / sizeof(T); - - if (scratchpad_size <= 0) { - throw std::runtime_error( - "Invalid scratchpad size: must be greater than zero." - " Calculated scratchpad size: " + - std::to_string(scratchpad_size)); - } - - // Calculate the total scratchpad memory size needed for all linear - // streams with proper alignment - const size_t alloc_scratch_size = - helper::round_up_mult(n_linear_streams * scratchpad_size, padding); - - // Allocate memory for the total scratchpad - T *scratchpad = sycl::malloc_device(alloc_scratch_size, exec_q); - if (!scratchpad) { - throw std::runtime_error("Device allocation for scratchpad failed"); - } - - return scratchpad; -} } // namespace dpnp::extensions::lapack::evd diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 17b5b9c60f8b..0cc6eb1f008e 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -26,43 +26,34 @@ #include // dpctl tensor headers -#include "utils/memory_overlap.hpp" #include "utils/type_utils.hpp" #include "common_helpers.hpp" #include "gesv.hpp" -#include "linalg_exceptions.hpp" +#include "gesv_common_utils.hpp" #include "types_matrix.hpp" -#include "dpnp_utils.hpp" - namespace dpnp::extensions::lapack { namespace mkl_lapack = oneapi::mkl::lapack; namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; -typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue, +typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue &, const std::int64_t, const std::int64_t, char *, - std::int64_t, char *, - std::int64_t, - std::vector &, const std::vector &); static gesv_impl_fn_ptr_t gesv_dispatch_vector[dpctl_td_ns::num_types]; template -static 
sycl::event gesv_impl(sycl::queue exec_q, +static sycl::event gesv_impl(sycl::queue &exec_q, const std::int64_t n, const std::int64_t nrhs, char *in_a, - std::int64_t lda, char *in_b, - std::int64_t ldb, - std::vector &host_task_events, const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -70,21 +61,28 @@ static sycl::event gesv_impl(sycl::queue exec_q, T *a = reinterpret_cast(in_a); T *b = reinterpret_cast(in_b); + const std::int64_t lda = std::max(1UL, n); + const std::int64_t ldb = std::max(1UL, n); + const std::int64_t scratchpad_size = mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); - T *scratchpad = nullptr; + + T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); std::int64_t *ipiv = nullptr; + try { + ipiv = helper::alloc_ipiv(n, exec_q); + } catch (const std::exception &e) { + if (scratchpad != nullptr) + sycl::free(scratchpad, exec_q); + throw; + } std::stringstream error_msg; - std::int64_t info = 0; bool is_exception_caught = false; sycl::event gesv_event; try { - scratchpad = sycl::malloc_device(scratchpad_size, exec_q); - ipiv = sycl::malloc_device(n, exec_q); - gesv_event = mkl_lapack::gesv( exec_q, n, // The order of the square matrix A @@ -104,41 +102,8 @@ static sycl::event gesv_impl(sycl::queue exec_q, scratchpad_size, depends); } catch (mkl_lapack::exception const &e) { is_exception_caught = true; - info = e.info(); - - if (info < 0) { - error_msg << "Parameter number " << -info - << " had an illegal value."; - } - else if (info == scratchpad_size && e.detail() != 0) { - error_msg - << "Insufficient scratchpad size. 
Required size is at least " - << e.detail(); - } - else if (info > 0) { - T host_U; - exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T)) - .wait(); - - using ThresholdType = typename helper::value_type_of::type; - - const auto threshold = - std::numeric_limits::epsilon() * 100; - if (std::abs(host_U) < threshold) { - sycl::free(scratchpad, exec_q); - throw LinAlgError("The input coefficient matrix is singular."); - } - else { - error_msg << "Unexpected MKL exception caught during gesv() " - "call:\nreason: " - << e.what() << "\ninfo: " << e.info(); - } - } - else { - error_msg << "Unexpected MKL exception caught during gesv() " - "call:\nreason: " - << e.what() << "\ninfo: " << e.info(); - } + gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size, + scratchpad, ipiv, e, error_msg); } catch (sycl::exception const &e) { is_exception_caught = true; error_msg << "Unexpected SYCL exception caught during gesv() call:\n" @@ -147,16 +112,14 @@ static sycl::event gesv_impl(sycl::queue exec_q, if (is_exception_caught) // an unexpected error occurs { - if (scratchpad != nullptr) { + if (scratchpad != nullptr) sycl::free(scratchpad, exec_q); - } - if (ipiv != nullptr) { + if (ipiv != nullptr) sycl::free(ipiv, exec_q); - } throw std::runtime_error(error_msg.str()); } - sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(gesv_event); auto ctx = exec_q.get_context(); cgh.host_task([ctx, scratchpad, ipiv]() { @@ -164,13 +127,12 @@ static sycl::event gesv_impl(sycl::queue exec_q, sycl::free(ipiv, ctx); }); }); - host_task_events.push_back(clean_up_event); - return gesv_event; + return ht_ev; } std::pair - gesv(sycl::queue exec_q, + gesv(sycl::queue &exec_q, dpctl::tensor::usm_ndarray coeff_matrix, dpctl::tensor::usm_ndarray dependent_vals, const std::vector &depends) @@ -178,66 +140,30 @@ std::pair const int coeff_matrix_nd = coeff_matrix.get_ndim(); const int 
dependent_vals_nd = dependent_vals.get_ndim(); - if (coeff_matrix_nd != 2) { - throw py::value_error("The coefficient matrix has ndim=" + - std::to_string(coeff_matrix_nd) + - ", but a 2-dimensional array is expected."); - } - - if (dependent_vals_nd > 2) { - throw py::value_error( - "The dependent values array has ndim=" + - std::to_string(dependent_vals_nd) + - ", but a 1-dimensional or a 2-dimensional array is expected."); - } - const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw(); const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw(); - if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) { - throw py::value_error("The coefficient matrix must be square," - " but got a shape of (" + - std::to_string(coeff_matrix_shape[0]) + ", " + - std::to_string(coeff_matrix_shape[1]) + ")."); - } + constexpr int expected_coeff_matrix_ndim = 2; + constexpr int min_dependent_vals_ndim = 1; + constexpr int max_dependent_vals_ndim = 2; - // check compatibility of execution queue and allocation queue - if (!dpctl::utils::queues_are_compatible(exec_q, - {coeff_matrix, dependent_vals})) - { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } + gesv_utils::common_gesv_checks( + exec_q, coeff_matrix, dependent_vals, coeff_matrix_shape, + dependent_vals_shape, expected_coeff_matrix_ndim, + min_dependent_vals_ndim, max_dependent_vals_ndim); - auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(coeff_matrix, dependent_vals)) { - throw py::value_error( - "The arrays of coefficients and dependent variables " - "are overlapping segments of memory"); - } - - bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous(); - if (!is_coeff_matrix_f_contig) { - throw py::value_error("The coefficient matrix " - "must be F-contiguous"); - } - - bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous(); - if (!is_dependent_vals_f_contig) { - throw py::value_error("The array of dependent 
variables " - "must be F-contiguous"); + // Ensure `batch_size`, `n` and 'nrhs' are non-zero, otherwise return empty + // events + if (helper::check_zeros_shape(coeff_matrix_nd, coeff_matrix_shape) || + helper::check_zeros_shape(dependent_vals_nd, dependent_vals_shape)) + { + // nothing to do + return std::make_pair(sycl::event(), sycl::event()); } auto array_types = dpctl_td_ns::usm_ndarray_types(); - int coeff_matrix_type_id = + const int coeff_matrix_type_id = array_types.typenum_to_lookup_id(coeff_matrix.get_typenum()); - int dependent_vals_type_id = - array_types.typenum_to_lookup_id(dependent_vals.get_typenum()); - - if (coeff_matrix_type_id != dependent_vals_type_id) { - throw py::value_error("The types of the coefficient matrix and " - "dependent variables are mismatched"); - } gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id]; if (gesv_fn == nullptr) { @@ -253,18 +179,13 @@ std::pair const std::int64_t nrhs = (dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1; - const std::int64_t lda = std::max(1UL, n); - const std::int64_t ldb = std::max(1UL, n); - - std::vector host_task_events; - sycl::event gesv_ev = - gesv_fn(exec_q, n, nrhs, coeff_matrix_data, lda, dependent_vals_data, - ldb, host_task_events, depends); + sycl::event gesv_ev = gesv_fn(exec_q, n, nrhs, coeff_matrix_data, + dependent_vals_data, depends); - sycl::event args_ev = dpctl::utils::keep_args_alive( - exec_q, {coeff_matrix, dependent_vals}, host_task_events); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {coeff_matrix, dependent_vals}, {gesv_ev}); - return std::make_pair(args_ev, gesv_ev); + return std::make_pair(ht_ev, gesv_ev); } template diff --git a/dpnp/backend/extensions/lapack/gesv.hpp b/dpnp/backend/extensions/lapack/gesv.hpp index 401c486eb13f..67a4d1a52d5f 100644 --- a/dpnp/backend/extensions/lapack/gesv.hpp +++ b/dpnp/backend/extensions/lapack/gesv.hpp @@ -33,10 +33,26 @@ namespace dpnp::extensions::lapack { extern std::pair - 
gesv(sycl::queue exec_q, + gesv(sycl::queue &exec_q, dpctl::tensor::usm_ndarray coeff_matrix, dpctl::tensor::usm_ndarray dependent_vals, const std::vector &depends); +extern std::pair + gesv_batch(sycl::queue &exec_q, + dpctl::tensor::usm_ndarray coeff_matrix, + dpctl::tensor::usm_ndarray dependent_vals, + const std::vector &depends); + +extern void common_gesv_checks(sycl::queue &exec_q, + dpctl::tensor::usm_ndarray coeff_matrix, + dpctl::tensor::usm_ndarray dependent_vals, + const py::ssize_t *coeff_matrix_shape, + const py::ssize_t *dependent_vals_shape, + const int expected_coeff_matrix_ndim, + const int min_dependent_vals_ndim, + const int max_dependent_vals_ndim); + extern void init_gesv_dispatch_vector(void); +extern void init_gesv_batch_dispatch_vector(void); } // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/gesv_batch.cpp b/dpnp/backend/extensions/lapack/gesv_batch.cpp new file mode 100644 index 000000000000..90aaf8ebcfd5 --- /dev/null +++ b/dpnp/backend/extensions/lapack/gesv_batch.cpp @@ -0,0 +1,277 @@ +//***************************************************************************** +// Copyright (c) 2023-2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+
+#include <pybind11/pybind11.h>
+
+// dpctl tensor headers
+#include "utils/type_utils.hpp"
+
+#include "common_helpers.hpp"
+#include "gesv.hpp"
+#include "gesv_common_utils.hpp"
+#include "types_matrix.hpp"
+
+namespace dpnp::extensions::lapack
+{
+namespace mkl_lapack = oneapi::mkl::lapack;
+namespace py = pybind11;
+namespace type_utils = dpctl::tensor::type_utils;
+
+// Signature shared by every type-specialized batched implementation:
+// (queue, n, nrhs, batch_size, raw data of A, raw data of B, dependencies)
+typedef sycl::event (*gesv_batch_impl_fn_ptr_t)(
+    sycl::queue &,
+    const std::int64_t,
+    const std::int64_t,
+    const std::int64_t,
+    char *,
+    char *,
+    const std::vector<sycl::event> &);
+
+// Per-type table of implementations, filled in by
+// init_gesv_batch_dispatch_vector()
+static gesv_batch_impl_fn_ptr_t
+    gesv_batch_dispatch_vector[dpctl_td_ns::num_types];
+
+/**
+ * @brief Solve `batch_size` linear systems by calling oneMKL `gesv` once
+ *        per matrix, spreading the calls over a few dependency chains
+ *        ("linear streams") that share one scratchpad and one pivot buffer.
+ *
+ * Matrix `i` of A is read at `in_a + i * n * n` elements and matrix `i` of
+ * B at `in_b + i * n * nrhs` elements.
+ *
+ * @param exec_q      Queue used for allocations and submissions.
+ * @param n           Order of each coefficient matrix.
+ * @param nrhs        Number of right-hand sides per system.
+ * @param batch_size  Number of systems to solve.
+ * @param in_a        Raw pointer to the coefficient matrices.
+ * @param in_b        Raw pointer to the right-hand sides.
+ * @param depends     Events every stream waits on before its first solve.
+ * @return Event of the host task that frees the temporary buffers; it
+ *         depends on the last solve of every stream.
+ * @throws LinAlgError (raised by gesv_utils::handle_lapack_exc) when a
+ *         coefficient matrix is found to be singular.
+ * @throws std::runtime_error for any other oneMKL or SYCL exception, or
+ *         when the temporary buffers cannot be allocated.
+ */
+template <typename T>
+static sycl::event gesv_batch_impl(sycl::queue &exec_q,
+                                   const std::int64_t n,
+                                   const std::int64_t nrhs,
+                                   const std::int64_t batch_size,
+                                   char *in_a,
+                                   char *in_b,
+                                   const std::vector<sycl::event> &depends)
+{
+    type_utils::validate_type_for_device<T>(exec_q);
+
+    T *a = reinterpret_cast<T *>(in_a);
+    T *b = reinterpret_cast<T *>(in_b);
+
+    const std::int64_t a_size = n * n;
+    const std::int64_t b_size = n * nrhs;
+
+    const std::int64_t lda = std::max<size_t>(1UL, n);
+    const std::int64_t ldb = std::max<size_t>(1UL, n);
+
+    // Get the number of independent linear streams
+    const std::int64_t n_linear_streams =
+        (batch_size > 16) ? 4 : ((batch_size > 4 ? 2 : 1));
+
+    const std::int64_t scratchpad_size =
+        mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);
+
+    T *scratchpad = helper::alloc_scratchpad_batch<T>(scratchpad_size,
+                                                      n_linear_streams, exec_q);
+
+    std::int64_t *ipiv = nullptr;
+    try {
+        ipiv = helper::alloc_ipiv_batch<T>(n, n_linear_streams, exec_q);
+    } catch (const std::exception &e) {
+        // Do not leak the scratchpad when the second allocation fails
+        if (scratchpad != nullptr)
+            sycl::free(scratchpad, exec_q);
+        throw;
+    }
+
+    // Computation events to manage dependencies for each linear stream
+    std::vector<std::vector<sycl::event>> comp_evs(n_linear_streams, depends);
+
+    // Blocks until the work already queued on every stream has finished.
+    // It must run before `scratchpad`/`ipiv` are freed on an error path,
+    // because solves submitted earlier still use those shared buffers.
+    const auto wait_for_streams = [&comp_evs]() {
+        for (const auto &stream_evs : comp_evs) {
+            sycl::event::wait(stream_evs);
+        }
+    };
+
+    std::stringstream error_msg;
+    bool is_exception_caught = false;
+
+    // Release GIL to avoid serialization of host task
+    // submissions to the same queue in OneMKL
+    py::gil_scoped_release release;
+
+    for (std::int64_t batch_id = 0; batch_id < batch_size; ++batch_id) {
+        T *a_batch = a + batch_id * a_size;
+        T *b_batch = b + batch_id * b_size;
+
+        std::int64_t stream_id = (batch_id % n_linear_streams);
+
+        T *current_scratch_gesv = scratchpad + stream_id * scratchpad_size;
+        std::int64_t *current_ipiv = ipiv + stream_id * n;
+
+        // Get the event dependencies for the current stream
+        const auto &current_dep = comp_evs[stream_id];
+
+        sycl::event gesv_event;
+
+        try {
+            gesv_event = mkl_lapack::gesv(
+                exec_q,
+                n,    // The order of the square matrix A
+                      // and the number of rows in matrix B (0 ≤ n).
+                nrhs, // The number of right-hand sides,
+                      // i.e., the number of columns in matrix B (0 ≤ nrhs).
+                a_batch, // Pointer to the square coefficient matrix A (n x n).
+                lda,     // The leading dimension of a, must be at least max(1, n).
+                current_ipiv, // The pivot indices that define the permutation
+                              // matrix P; row i of the matrix was interchanged
+                              // with row ipiv(i), must be at least max(1, n).
+                b_batch, // Pointer to the right hand side matrix B (n x nrhs).
+                ldb,     // The leading dimension of matrix B,
+                         // must be at least max(1, n).
+                current_scratch_gesv, // Pointer to scratchpad memory to be used
+                                      // by MKL routine for storing intermediate
+                                      // results.
+                scratchpad_size, current_dep);
+        } catch (mkl_lapack::exception const &e) {
+            is_exception_caught = true;
+            // The handler may free the shared buffers and throw, so drain
+            // the solves that are still in flight first
+            wait_for_streams();
+            // `info` reported by the exception refers to the matrix passed
+            // to the failing call, hence `a_batch` rather than `a`
+            gesv_utils::handle_lapack_exc(exec_q, lda, a_batch,
+                                          scratchpad_size, scratchpad, ipiv, e,
+                                          error_msg);
+        } catch (sycl::exception const &e) {
+            is_exception_caught = true;
+            error_msg
+                << "Unexpected SYCL exception caught during gesv() call:\n"
+                << e.what();
+        }
+
+        // Stop at the first failure instead of queuing further solves that
+        // would only be discarded
+        if (is_exception_caught) {
+            break;
+        }
+
+        // Update the event dependencies for the current stream
+        comp_evs[stream_id] = {gesv_event};
+    }
+
+    if (is_exception_caught) // an unexpected error occurs
+    {
+        // Solves submitted before the failure may still be running
+        wait_for_streams();
+        if (scratchpad != nullptr)
+            sycl::free(scratchpad, exec_q);
+        if (ipiv != nullptr)
+            sycl::free(ipiv, exec_q);
+        throw std::runtime_error(error_msg.str());
+    }
+
+    // Free the temporary buffers once the last solve of every stream is done
+    sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
+        for (const auto &ev : comp_evs) {
+            cgh.depends_on(ev);
+        }
+        auto ctx = exec_q.get_context();
+        cgh.host_task([ctx, scratchpad, ipiv]() {
+            sycl::free(scratchpad, ctx);
+            sycl::free(ipiv, ctx);
+        });
+    });
+
+    return ht_ev;
+}
+
+/**
+ * @brief Validate the arguments and dispatch the batched solve to the
+ *        implementation matching the element type of `coeff_matrix`.
+ *
+ * @param exec_q          Execution queue.
+ * @param coeff_matrix    3-dimensional array of coefficient matrices; the
+ *                        last axis is used as the batch axis.
+ * @param dependent_vals  2- or 3-dimensional array of right-hand sides; the
+ *                        last axis is used as the batch axis.
+ * @param depends         Events the computation has to wait for.
+ * @return Pair of (event keeping the arguments alive, event returned by the
+ *         implementation). Two default-constructed events are returned when
+ *         either array has a zero-sized dimension.
+ * @throws py::value_error when validation fails or when no implementation
+ *         exists for the element type.
+ */
+std::pair<sycl::event, sycl::event>
+    gesv_batch(sycl::queue &exec_q,
+               dpctl::tensor::usm_ndarray coeff_matrix,
+               dpctl::tensor::usm_ndarray dependent_vals,
+               const std::vector<sycl::event> &depends)
+{
+    const int coeff_matrix_nd = coeff_matrix.get_ndim();
+    const int dependent_vals_nd = dependent_vals.get_ndim();
+
+    const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw();
+    const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw();
+
+    constexpr int expected_coeff_matrix_ndim = 3;
+    constexpr int min_dependent_vals_ndim = 2;
+    constexpr int max_dependent_vals_ndim = 3;
+
+    gesv_utils::common_gesv_checks(
+        exec_q, coeff_matrix, dependent_vals, coeff_matrix_shape,
+        dependent_vals_shape, expected_coeff_matrix_ndim,
+        min_dependent_vals_ndim, max_dependent_vals_ndim);
+
+    // Ensure `batch_size`, `n` and `nrhs` are non-zero, otherwise return empty
+    // events
+    if (helper::check_zeros_shape(coeff_matrix_nd, coeff_matrix_shape) ||
+        helper::check_zeros_shape(dependent_vals_nd, dependent_vals_shape))
+    {
+        // nothing to do
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    // The batch axis is the last one of both arrays (the number of
+    // dimensions of `dependent_vals` was limited to 2 or 3 above)
+    const py::ssize_t coeff_matrix_batch = coeff_matrix_shape[2];
+    const py::ssize_t dependent_vals_batch =
+        dependent_vals_shape[dependent_vals_nd - 1];
+    if (coeff_matrix_batch != dependent_vals_batch) {
+        throw py::value_error(
+            "The batch_size of"
+            " coeff_matrix and dependent_vals must be"
+            " the same, but got " +
+            std::to_string(coeff_matrix_batch) + " and " +
+            std::to_string(dependent_vals_batch) + ".");
+    }
+
+    auto array_types = dpctl_td_ns::usm_ndarray_types();
+    const int coeff_matrix_type_id =
+        array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
+
+    gesv_batch_impl_fn_ptr_t gesv_batch_fn =
+        gesv_batch_dispatch_vector[coeff_matrix_type_id];
+    if (gesv_batch_fn == nullptr) {
+        throw py::value_error(
+            "No gesv implementation defined for the provided type "
+            "of the coefficient matrix.");
+    }
+
+    char *coeff_matrix_data = coeff_matrix.get_data();
+    char *dependent_vals_data = dependent_vals.get_data();
+
+    const std::int64_t batch_size = coeff_matrix_shape[2];
+    const std::int64_t n = coeff_matrix_shape[1];
+    // A 2-dimensional `dependent_vals` holds one right-hand side per system
+    const std::int64_t nrhs =
+        (dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1;
+
+    sycl::event gesv_ev =
+        gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data,
+                      dependent_vals_data, depends);
+
+    sycl::event ht_ev = dpctl::utils::keep_args_alive(
+        exec_q, {coeff_matrix, dependent_vals}, {gesv_ev});
+
+    return std::make_pair(ht_ev, gesv_ev);
+}
+
+// Factory used to populate the dispatch vector: yields the implementation
+// for `T` when the type is marked as supported, otherwise a null entry
+template <typename fnT, typename T>
+struct GesvBatchContigFactory
+{
+    fnT get()
+    {
+        if constexpr (types::GesvTypePairSupportFactory<T>::is_defined) {
+            return gesv_batch_impl<T>;
+        }
+        else {
+            return nullptr;
+        }
+    }
+};
+
+// Fill gesv_batch_dispatch_vector with one entry per type id
+void init_gesv_batch_dispatch_vector(void)
+{
+    dpctl_td_ns::DispatchVectorBuilder<gesv_batch_impl_fn_ptr_t,
+                                       GesvBatchContigFactory,
+                                       dpctl_td_ns::num_types>
+        contig;
+    contig.populate_dispatch_vector(gesv_batch_dispatch_vector);
+}
+} // namespace dpnp::extensions::lapack
diff --git a/dpnp/backend/extensions/lapack/gesv_common_utils.hpp b/dpnp/backend/extensions/lapack/gesv_common_utils.hpp
new file mode 100644
index 000000000000..4b4df013aab5
--- /dev/null
+++ b/dpnp/backend/extensions/lapack/gesv_common_utils.hpp
@@ -0,0 +1,180 @@
+//*****************************************************************************
+// Copyright (c) 2024, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+// dpctl tensor headers
+#include "utils/memory_overlap.hpp"
+#include "utils/output_validation.hpp"
+#include "utils/type_dispatch.hpp"
+
+#include "common_helpers.hpp"
+#include "linalg_exceptions.hpp"
+
+namespace dpnp::extensions::lapack::gesv_utils
+{
+namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
+namespace py = pybind11;
+
+/**
+ * @brief Argument validation shared by `gesv` and `gesv_batch`.
+ *
+ * Checks, in this order: the number of dimensions of both arrays, that the
+ * first two extents of `coeff_matrix` are equal, that the first extents of
+ * both arrays match, queue compatibility, absence of memory overlap, that
+ * `dependent_vals` is writable, F-contiguity of both arrays and equality of
+ * their element types.
+ *
+ * @param exec_q                      Execution queue.
+ * @param coeff_matrix                Array of coefficients.
+ * @param dependent_vals              Array of dependent values.
+ * @param coeff_matrix_shape          Raw shape of `coeff_matrix`.
+ * @param dependent_vals_shape        Raw shape of `dependent_vals`.
+ * @param expected_coeff_matrix_ndim  Required ndim of `coeff_matrix`.
+ * @param min_dependent_vals_ndim     Smallest accepted ndim of
+ *                                    `dependent_vals`.
+ * @param max_dependent_vals_ndim     Largest accepted ndim of
+ *                                    `dependent_vals`.
+ * @throws py::value_error when one of the checks performed here fails;
+ *         the writability check raises from within dpctl.
+ */
+inline void common_gesv_checks(sycl::queue &exec_q,
+                               dpctl::tensor::usm_ndarray coeff_matrix,
+                               dpctl::tensor::usm_ndarray dependent_vals,
+                               const py::ssize_t *coeff_matrix_shape,
+                               const py::ssize_t *dependent_vals_shape,
+                               const int expected_coeff_matrix_ndim,
+                               const int min_dependent_vals_ndim,
+                               const int max_dependent_vals_ndim)
+{
+    const int coeff_matrix_nd = coeff_matrix.get_ndim();
+    const int dependent_vals_nd = dependent_vals.get_ndim();
+
+    if (coeff_matrix_nd != expected_coeff_matrix_ndim) {
+        throw py::value_error("The coefficient matrix has ndim=" +
+                              std::to_string(coeff_matrix_nd) + ", but a " +
+                              std::to_string(expected_coeff_matrix_ndim) +
+                              "-dimensional array is expected.");
+    }
+
+    if (dependent_vals_nd < min_dependent_vals_ndim ||
+        dependent_vals_nd > max_dependent_vals_ndim)
+    {
+        throw py::value_error("The dependent values array has ndim=" +
+                              std::to_string(dependent_vals_nd) + ", but a " +
+                              std::to_string(min_dependent_vals_ndim) +
+                              "-dimensional or a " +
+                              std::to_string(max_dependent_vals_ndim) +
+                              "-dimensional array is expected.");
+    }
+
+    // The coeff_matrix and dependent_vals arrays must be F-contiguous arrays
+    // for gesv
+    // with the shapes (n, n) and (n, nrhs) or (n, ) respectively;
+    // for gesv_batch
+    // with the shapes (n, n, batch_size) and (n, nrhs, batch_size) or
+    // (n, batch_size) respectively
+    if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) {
+        throw py::value_error("The coefficient matrix must be square,"
+                              " but got a shape of (" +
+                              std::to_string(coeff_matrix_shape[0]) + ", " +
+                              std::to_string(coeff_matrix_shape[1]) + ").");
+    }
+    if (coeff_matrix_shape[0] != dependent_vals_shape[0]) {
+        throw py::value_error("The first dimension (n) of coeff_matrix and"
+                              " dependent_vals must be the same, but got " +
+                              std::to_string(coeff_matrix_shape[0]) + " and " +
+                              std::to_string(dependent_vals_shape[0]) + ".");
+    }
+
+    // check compatibility of execution queue and allocation queue
+    if (!dpctl::utils::queues_are_compatible(exec_q,
+                                             {coeff_matrix, dependent_vals}))
+    {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues.");
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(coeff_matrix, dependent_vals)) {
+        throw py::value_error(
+            "The arrays of coefficients and dependent variables "
+            "are overlapping segments of memory.");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(
+        dependent_vals);
+
+    const bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
+    if (!is_coeff_matrix_f_contig) {
+        throw py::value_error("The coefficient matrix "
+                              "must be F-contiguous.");
+    }
+
+    const bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
+    if (!is_dependent_vals_f_contig) {
+        throw py::value_error("The array of dependent variables "
+                              "must be F-contiguous.");
+    }
+
+    auto array_types = dpctl_td_ns::usm_ndarray_types();
+    const int coeff_matrix_type_id =
+        array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
+    const int dependent_vals_type_id =
+        array_types.typenum_to_lookup_id(dependent_vals.get_typenum());
+
+    if (coeff_matrix_type_id != dependent_vals_type_id) {
+        throw py::value_error("The types of the coefficient matrix and "
+                              "dependent variables are mismatched.");
+    }
+}
+
+/**
+ * @brief Turn a oneMKL LAPACK exception raised by `gesv` into either a
+ *        `LinAlgError` or a diagnostic appended to `error_msg`.
+ *
+ * Ownership note: when the matrix is classified as singular this function
+ * frees `scratchpad` and `ipiv` on `exec_q` and throws, so the caller must
+ * make sure no submitted work still uses those buffers before calling it.
+ * In every other case the buffers are left untouched and the caller remains
+ * responsible for releasing them.
+ *
+ * @param exec_q           Queue used to copy the probed element to the host
+ *                         and to free the buffers.
+ * @param lda              Leading dimension of `a`.
+ * @param a                Device pointer to the matrix passed to the failing
+ *                         `gesv` call; `e.info()` indexes into this matrix.
+ * @param scratchpad_size  Scratchpad size that was passed to `gesv`.
+ * @param scratchpad       Device scratchpad buffer (may be null).
+ * @param ipiv             Device pivot buffer (may be null).
+ * @param e                The caught oneMKL exception.
+ * @param error_msg        Receives the diagnostic for non-singular failures.
+ * @throws LinAlgError when `info > 0` and the magnitude of the diagonal
+ *         element at position `info` is below 100 machine epsilons.
+ */
+template <typename T>
+inline void handle_lapack_exc(sycl::queue &exec_q,
+                              const std::int64_t lda,
+                              T *a,
+                              std::int64_t scratchpad_size,
+                              T *scratchpad,
+                              std::int64_t *ipiv,
+                              const oneapi::mkl::lapack::exception &e,
+                              std::stringstream &error_msg)
+{
+    std::int64_t info = e.info();
+    if (info < 0) {
+        error_msg << "Parameter number " << -info << " had an illegal value.";
+    }
+    else if (info == scratchpad_size && e.detail() != 0) {
+        error_msg << "Insufficient scratchpad size. Required size is at least "
+                  << e.detail();
+    }
+    else if (info > 0) {
+        // Fetch the diagonal element U(info, info) (1-based `info`) to
+        // decide whether the failure is due to a singular matrix
+        T host_U;
+        exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T))
+            .wait();
+
+        using ThresholdType = typename helper::value_type_of<T>::type;
+
+        const auto threshold =
+            std::numeric_limits<ThresholdType>::epsilon() * 100;
+        if (std::abs(host_U) < threshold) {
+            // Release the temporaries here: throwing skips the caller's
+            // own clean-up path
+            if (scratchpad != nullptr)
+                sycl::free(scratchpad, exec_q);
+            if (ipiv != nullptr)
+                sycl::free(ipiv, exec_q);
+            throw LinAlgError("The input coefficient matrix is singular.");
+        }
+        else {
+            error_msg << "Unexpected MKL exception caught during gesv() "
+                         "call:\nreason: "
+                      << e.what() << "\ninfo: " << e.info();
+        }
+    }
+    else {
+        error_msg
+            << "Unexpected MKL exception caught during gesv() call:\nreason: "
+            << e.what() << "\ninfo: " << e.info();
+    }
+}
+} // namespace dpnp::extensions::lapack::gesv_utils
diff --git a/dpnp/backend/extensions/lapack/heevd.cpp b/dpnp/backend/extensions/lapack/heevd.cpp
index a623a0dd182c..234dbef18eda 100644
--- a/dpnp/backend/extensions/lapack/heevd.cpp
+++ b/dpnp/backend/extensions/lapack/heevd.cpp
@@ -55,18 +55,7 @@ static sycl::event heevd_impl(sycl::queue &exec_q,
     const std::int64_t
scratchpad_size = mkl_lapack::heevd_scratchpad_size(exec_q, jobz, upper_lower, n, lda); - if (scratchpad_size <= 0) { - throw std::runtime_error( - "Invalid scratchpad size: must be greater than zero." - "Calculated scratchpad size: " + - std::to_string(scratchpad_size)); - } - - T *scratchpad = nullptr; - // Allocate memory for the scratchpad - scratchpad = sycl::malloc_device(scratchpad_size, exec_q); - if (!scratchpad) - throw std::runtime_error("Device allocation for scratchpad failed"); + T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); std::stringstream error_msg; std::int64_t info = 0; diff --git a/dpnp/backend/extensions/lapack/heevd_batch.cpp b/dpnp/backend/extensions/lapack/heevd_batch.cpp index bc21b3da35d5..05d968701dd8 100644 --- a/dpnp/backend/extensions/lapack/heevd_batch.cpp +++ b/dpnp/backend/extensions/lapack/heevd_batch.cpp @@ -65,8 +65,8 @@ static sycl::event heevd_batch_impl(sycl::queue &exec_q, const std::int64_t scratchpad_size = mkl_lapack::heevd_scratchpad_size(exec_q, jobz, upper_lower, n, lda); - T *scratchpad = - evd::alloc_scratchpad(scratchpad_size, n_linear_streams, exec_q); + T *scratchpad = helper::alloc_scratchpad_batch(scratchpad_size, + n_linear_streams, exec_q); // Computation events to manage dependencies for each linear stream std::vector> comp_evs(n_linear_streams, depends); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 7afa67e84eca..b29810893392 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -53,6 +53,7 @@ void init_dispatch_vectors(void) { lapack_ext::init_geqrf_batch_dispatch_vector(); lapack_ext::init_geqrf_dispatch_vector(); + lapack_ext::init_gesv_batch_dispatch_vector(); lapack_ext::init_gesv_dispatch_vector(); lapack_ext::init_getrf_batch_dispatch_vector(); lapack_ext::init_getrf_dispatch_vector(); @@ -109,6 +110,13 @@ PYBIND11_MODULE(_lapack_impl, m) 
py::arg("sycl_queue"), py::arg("coeff_matrix"), py::arg("dependent_vals"), py::arg("depends") = py::list()); + m.def("_gesv_batch", &lapack_ext::gesv_batch, + "Call `gesv` from OneMKL LAPACK library to return " + "the batch solution of a system of linear equations with " + "a square coefficient matrix A and multiple dependent variables", + py::arg("sycl_queue"), py::arg("coeff_matrix"), + py::arg("dependent_vals"), py::arg("depends") = py::list()); + m.def("_gesvd", &lapack_ext::gesvd, "Call `gesvd` from OneMKL LAPACK library to return " "the singular value decomposition of a general rectangular matrix", diff --git a/dpnp/backend/extensions/lapack/syevd.cpp b/dpnp/backend/extensions/lapack/syevd.cpp index bbc975b43d0f..30f7ea1d13c6 100644 --- a/dpnp/backend/extensions/lapack/syevd.cpp +++ b/dpnp/backend/extensions/lapack/syevd.cpp @@ -55,18 +55,7 @@ static sycl::event syevd_impl(sycl::queue &exec_q, const std::int64_t scratchpad_size = mkl_lapack::syevd_scratchpad_size(exec_q, jobz, upper_lower, n, lda); - if (scratchpad_size <= 0) { - throw std::runtime_error( - "Invalid scratchpad size: must be greater than zero." 
- "Calculated scratchpad size: " + - std::to_string(scratchpad_size)); - } - - T *scratchpad = nullptr; - // Allocate memory for the scratchpad - scratchpad = sycl::malloc_device(scratchpad_size, exec_q); - if (!scratchpad) - throw std::runtime_error("Device allocation for scratchpad failed"); + T *scratchpad = helper::alloc_scratchpad(scratchpad_size, exec_q); std::stringstream error_msg; std::int64_t info = 0; diff --git a/dpnp/backend/extensions/lapack/syevd_batch.cpp b/dpnp/backend/extensions/lapack/syevd_batch.cpp index d2b87fc260cc..78bf07264f55 100644 --- a/dpnp/backend/extensions/lapack/syevd_batch.cpp +++ b/dpnp/backend/extensions/lapack/syevd_batch.cpp @@ -65,8 +65,8 @@ static sycl::event syevd_batch_impl(sycl::queue &exec_q, const std::int64_t scratchpad_size = mkl_lapack::syevd_scratchpad_size(exec_q, jobz, upper_lower, n, lda); - T *scratchpad = - evd::alloc_scratchpad(scratchpad_size, n_linear_streams, exec_q); + T *scratchpad = helper::alloc_scratchpad_batch(scratchpad_size, + n_linear_streams, exec_q); // Computation events to manage dependencies for each linear stream std::vector> comp_evs(n_linear_streams, depends); diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 2ceda3740441..c25e0a8c9b73 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -260,99 +260,78 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type): """ - a_usm_arr = dpnp.get_usm_ndarray(a) - b_usm_arr = dpnp.get_usm_ndarray(b) - - b_order = "C" if b.flags.c_contiguous else "F" a_shape = a.shape b_shape = b.shape - is_cpu_device = exec_q.sycl_device.has_aspect_cpu - reshape = False - orig_shape_b = b_shape - if a.ndim > 3: - # get 3d input arrays by reshape - if a.ndim == b.ndim: - b = dpnp.reshape(b, (-1, b_shape[-2], b_shape[-1])) - else: - b = dpnp.reshape(b, (-1, b_shape[-1])) + # gesv_batch expects `a` to be a 3D array and + # `b` to be either a 2D or 3D array. 
+ if a.ndim == b.ndim: + b = dpnp.reshape(b, (-1, b_shape[-2], b_shape[-1])) + else: + b = dpnp.reshape(b, (-1, b_shape[-1])) - a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1])) + a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1])) - a_usm_arr = dpnp.get_usm_ndarray(a) - b_usm_arr = dpnp.get_usm_ndarray(b) - reshape = True + # Reorder the elements by moving the last two axes of `a` to the front + # to match fortran-like array order which is assumed by gesv. + a = dpnp.moveaxis(a, (-2, -1), (0, 1)) + # The same for `b` if it is 3D; + # if it is 2D, transpose it. + if b.ndim > 2: + b = dpnp.moveaxis(b, (-2, -1), (0, 1)) + else: + b = b.T + + a_usm_arr = dpnp.get_usm_ndarray(a) + b_usm_arr = dpnp.get_usm_ndarray(b) _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - batch_size = a.shape[0] - coeff_vecs = [None] * batch_size - val_vecs = [None] * batch_size + # oneMKL LAPACK gesv destroys `a` and assumes fortran-like array + # as input. + a_f = dpnp.empty_like(a, dtype=res_type, order="F", usm_type=res_usm_type) - for i in range(batch_size): - # oneMKL LAPACK assumes fortran-like array as input, so allocate - # a memory with 'F' order for dpnp array of coefficient matrix - coeff_vecs[i] = dpnp.empty_like( - a[i], order="F", dtype=res_type, usm_type=res_usm_type - ) + ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_f.get_array(), + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, a_copy_ev) - # use DPCTL tensor function to fill the coefficient matrix array - # with content from the input array - ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr[i], - dst=coeff_vecs[i].get_array(), - sycl_queue=a.sycl_queue, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, a_copy_ev) + # oneMKL LAPACK gesv overwrites `b` and assumes fortran-like array + # as input. 
+ b_f = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type) - # oneMKL LAPACK assumes fortran-like array as input, so - # allocate a memory with 'F' order for dpnp array of multiple - # dependent variables array - val_vecs[i] = dpnp.empty_like( - b[i], order="F", dtype=res_type, usm_type=res_usm_type - ) + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, + dst=b_f.get_array(), + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, b_copy_ev) - # use DPCTL tensor function to fill the array of multiple dependent - # variables with content from the input arrays - ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=b_usm_arr[i], - dst=val_vecs[i].get_array(), - sycl_queue=b.sycl_queue, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, b_copy_ev) + ht_ev, gesv_batch_ev = li._gesv_batch( + exec_q, + a_f.get_array(), + b_f.get_array(), + depends=[a_copy_ev, b_copy_ev], + ) - # Call the LAPACK extension function _gesv to solve the system of - # linear equations using a portion of the coefficient square matrix - # and a corresponding portion of the dependent variables array. - ht_ev, gesv_ev = li._gesv( - exec_q, - coeff_vecs[i].get_array(), - val_vecs[i].get_array(), - depends=[a_copy_ev, b_copy_ev], - ) - _manager.add_event_pair(ht_ev, gesv_ev) + _manager.add_event_pair(ht_ev, gesv_batch_ev) - # TODO: Remove this w/a when MKLD-17201 is solved. - # Waiting for a host task executing an OneMKL LAPACK gesv call - # on CPU causes deadlock due to serialization of all host tasks - # in the queue. - # We need to wait for each host tasks before calling _gesv to avoid - # deadlock. - if is_cpu_device: - dpnp.synchronize_array_data(a) + # Gesv call overwrites `b` in Fortran order, reorder the axes + # to match C order by moving the last axis to the front and + # reshape it back to the original shape of `b`. 
+ v = dpnp.moveaxis(b_f, -1, 0).reshape(b_shape) - # combine the list of solutions into a single array - out_v = dpnp.array( - val_vecs, order=b_order, dtype=res_type, usm_type=res_usm_type - ) - if reshape: - # shape of the out_v must be equal to the shape of the array of - # dependent variables - out_v = out_v.reshape(orig_shape_b) - return out_v + # dpnp.moveaxis can make the array non-contiguous if it is not 2D + # Convert to contiguous to align with NumPy + if b.ndim > 2: + v = dpnp.ascontiguousarray(v) + + return v def _batched_qr(a, mode="reduced"):