dpnp/backend/extensions/blas/blas_py.cpp (24 changes: 11 additions & 13 deletions)
@@ -64,13 +64,13 @@ PYBIND11_MODULE(_blas_impl, m)
             blas_ext::DotContigFactory>(
             dot_dispatch_vector);

-        auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                            arrayT dst, const event_vecT &depends = {}) {
+        auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
+                             arrayT dst, const event_vecT &depends = {}) {
             return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
                                      dot_dispatch_vector);
         };

-        m.def("_dot", dot_pypi,
+        m.def("_dot", dot_pyapi,
               "Call `dot` from OneMKL BLAS library to return "
               "the dot product of two real-valued vectors.",
               py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -82,13 +82,13 @@ PYBIND11_MODULE(_blas_impl, m)
             blas_ext::DotcContigFactory>(
             dotc_dispatch_vector);

-        auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                             arrayT dst, const event_vecT &depends = {}) {
+        auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
+                              arrayT dst, const event_vecT &depends = {}) {
             return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
                                      dotc_dispatch_vector);
         };

-        m.def("_dotc", dotc_pypi,
+        m.def("_dotc", dotc_pyapi,
               "Call `dotc` from OneMKL BLAS library to return "
               "the dot product of two complex vectors, "
               "conjugating the first vector.",
@@ -101,13 +101,13 @@ PYBIND11_MODULE(_blas_impl, m)
             blas_ext::DotuContigFactory>(
             dotu_dispatch_vector);

-        auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
-                             arrayT dst, const event_vecT &depends = {}) {
+        auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
+                              arrayT dst, const event_vecT &depends = {}) {
             return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
                                      dotu_dispatch_vector);
         };

-        m.def("_dotu", dotu_pypi,
+        m.def("_dotu", dotu_pyapi,
               "Call `dotu` from OneMKL BLAS library to return "
               "the dot product of two complex vectors.",
               py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -119,16 +119,14 @@ PYBIND11_MODULE(_blas_impl, m)
               "Call `gemm` from OneMKL BLAS library to return "
               "the matrix-matrix product with 2-D matrices.",
               py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
-              py::arg("result"), py::arg("depends") = py::list());
+              py::arg("resultC"), py::arg("depends") = py::list());
     }

     {
         m.def("_gemm_batch", &blas_ext::gemm_batch,
               "Call `gemm_batch` from OneMKL BLAS library to return "
               "the matrix-matrix product for a batch of 2-D matrices.",
               py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
-              py::arg("result"), py::arg("batch_size"), py::arg("stridea"),
-              py::arg("strideb"), py::arg("stridec"),
-              py::arg("depends") = py::list());
+              py::arg("resultC"), py::arg("depends") = py::list());
     }
 }
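The hunks above do two things: they fix the `*_pypi` typo in the lambda names to `*_pyapi`, and they shrink the Python-facing `_gemm_batch` signature, since `batch_size` and the batch strides are now derived inside the C++ layer (see gemm_batch.cpp below). For readers unfamiliar with the binding idiom, here is a minimal, self-contained sketch of the same lambda-with-default-`depends` pattern; the module and function names (`_example_impl`, `_axpy`) are hypothetical and not part of this PR:

```cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <vector>

namespace py = pybind11;

// Hypothetical stand-in for a dispatch-table-backed BLAS wrapper.
static int axpy_func(int x, int y, const std::vector<int> &depends)
{
    return x + y + static_cast<int>(depends.size());
}

PYBIND11_MODULE(_example_impl, m)
{
    // Same shape as the dot/dotc/dotu bindings: wrap the worker in a
    // lambda and expose the event list as an optional keyword argument.
    auto axpy_pyapi = [](int x, int y, const std::vector<int> &depends) {
        return axpy_func(x, y, depends);
    };
    m.def("_axpy", axpy_pyapi,
          "Toy binding mirroring the dpnp lambda-with-defaults pattern.",
          py::arg("x"), py::arg("y"), py::arg("depends") = py::list());
}
```

pybind11 stores the `py::list()` default as a Python object and converts it to the C++ parameter type at each call, which is why the dpnp bindings can expose `depends` as an optional keyword without a C++-side overload.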
dpnp/backend/extensions/blas/gemm.hpp (4 changes: 0 additions & 4 deletions)
@@ -50,10 +50,6 @@ extern std::pair<sycl::event, sycl::event>
                dpctl::tensor::usm_ndarray matrixA,
                dpctl::tensor::usm_ndarray matrixB,
                dpctl::tensor::usm_ndarray resultC,
-               const std::int64_t batch_size,
-               size_t stridea,
-               size_t strideb,
-               size_t stridec,
                const std::vector<sycl::event> &depends);

 extern void init_gemm_dispatch_table(void);
dpnp/backend/extensions/blas/gemm_batch.cpp (75 changes: 41 additions & 34 deletions)
@@ -150,10 +150,6 @@ std::pair<sycl::event, sycl::event>
                dpctl::tensor::usm_ndarray matrixA,
                dpctl::tensor::usm_ndarray matrixB,
                dpctl::tensor::usm_ndarray resultC,
-               const std::int64_t batch_size,
-               size_t stridea,
-               size_t strideb,
-               size_t stridec,
                const std::vector<sycl::event> &depends = {})
 {
     const int matrixA_nd = matrixA.get_ndim();
@@ -185,49 +181,60 @@ std::pair<sycl::event, sycl::event>
     const py::ssize_t *a_shape = matrixA.get_shape_raw();
     const py::ssize_t *b_shape = matrixB.get_shape_raw();
     const py::ssize_t *c_shape = resultC.get_shape_raw();
-    const std::int64_t m = a_shape[matrixA_nd - 2];
-    const std::int64_t n = b_shape[matrixB_nd - 1];
-    const std::int64_t k = a_shape[matrixA_nd - 1];
-    if (a_shape[matrixA_nd - 1] != b_shape[matrixB_nd - 2]) {
+    const std::int64_t m = a_shape[1];
+    const std::int64_t n = b_shape[2];
+    const std::int64_t k = a_shape[2];
+    const std::int64_t batch_size = c_shape[0];
+    if (a_shape[2] != b_shape[1]) {
         throw py::value_error("The number of columns in A must be equal to "
                               "the number of rows in B.");
     }
-    if (a_shape[matrixA_nd - 2] != c_shape[resultC_nd - 2]) {
+    if (a_shape[1] != c_shape[1]) {
         throw py::value_error("The number of rows in A must be equal to "
                               "the number of rows in result array.");
     }
-    if (b_shape[matrixB_nd - 1] != c_shape[resultC_nd - 1]) {
+    if (b_shape[2] != c_shape[2]) {
         throw py::value_error("The number of columns in B must be equal to "
                               "the number of columns in result array.");
     }

-    bool shapes_equal = true;
-    size_t src_nelems = 1;
-    py::ssize_t lead_dim;
-    for (int i = 0; i < matrixA_nd - 2; ++i) {
-        if (a_shape[i] == b_shape[i]) {
-            lead_dim = a_shape[i];
-        }
-        else if (a_shape[i] == 1 || b_shape[i] == 1) {
-            lead_dim = std::max(a_shape[i], b_shape[i]);
-        }
-        else {
-            throw py::value_error("Array shapes do not match.");
-        }
-        src_nelems *= static_cast<size_t>(lead_dim);
-        shapes_equal = shapes_equal && (lead_dim == c_shape[i]);
-    }
-    src_nelems *= (m * n);
-    if (!shapes_equal) {
-        throw py::value_error("Array shapes do not match.");
-    }
+    std::int64_t first_dim;
+    if (a_shape[0] == b_shape[0]) {
+        first_dim = a_shape[0];
+    }
+    else if (a_shape[0] == 1 || b_shape[0] == 1) {
+        first_dim = std::max(a_shape[0], b_shape[0]);
+    }
+    else {
+        throw py::value_error("Array shapes do not match.");
+    }
+    if (first_dim != c_shape[0]) {
+        throw py::value_error("Array shapes do not match.");
+    }
+    std::int64_t src_nelems = first_dim * m * n;
     dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC);
     dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC,
                                                                src_nelems);

-    // transA and transB are always False
-    oneapi::mkl::transpose transA = oneapi::mkl::transpose::N;
-    oneapi::mkl::transpose transB = oneapi::mkl::transpose::N;
+    std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
+    std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
+    std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
+    const std::int64_t stridea = a_stride[0];
+    const std::int64_t strideb = b_stride[0];
+    const std::int64_t stridec = c_stride[0];
+    bool A_base_is_f_contig = a_stride[1] == 1 && a_stride[2] == a_shape[1];
+    bool B_base_is_f_contig = b_stride[1] == 1 && b_stride[2] == b_shape[1];
+
+    oneapi::mkl::transpose transA = A_base_is_f_contig
+                                        ? oneapi::mkl::transpose::T
+                                        : oneapi::mkl::transpose::N;
+    oneapi::mkl::transpose transB = B_base_is_f_contig
+                                        ? oneapi::mkl::transpose::T
+                                        : oneapi::mkl::transpose::N;
+
+    const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
+    const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
+    const std::int64_t ldc = n; // always n for row_major

     int matrixA_typenum = matrixA.get_typenum();
     int matrixB_typenum = matrixB.get_typenum();
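This hunk replaces the generic N-D broadcast loop with explicit 3-D handling: m, n, k and batch_size come from the shapes, the batch strides from the outermost stride of each operand, and the transpose flags from whether each 2-D base is F-contiguous. A compact sketch of the same derivation under the row-major oneMKL convention; `Meta3D`, `GemmBatchParams` and `derive_params` are illustrative names, not code from this PR:

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical container for a 3-D array's shape and element strides.
struct Meta3D {
    std::int64_t shape[3];
    std::int64_t strides[3];
};

struct GemmBatchParams {
    std::int64_t m, n, k, batch_size;
    std::int64_t lda, ldb, ldc;
    std::int64_t stridea, strideb, stridec;
    bool transA, transB; // true maps to oneapi::mkl::transpose::T
};

// Mirrors the derivation in gemm_batch.cpp: sizes from shapes, batch
// strides from the outermost stride, transpose flags from whether the
// 2-D base of each operand is F-contiguous.
GemmBatchParams derive_params(const Meta3D &a, const Meta3D &b, const Meta3D &c)
{
    GemmBatchParams p{};
    p.m = a.shape[1];
    p.n = b.shape[2];
    p.k = a.shape[2];
    p.batch_size = c.shape[0];
    if (a.shape[2] != b.shape[1])
        throw std::invalid_argument("columns(A) != rows(B)");

    p.stridea = a.strides[0];
    p.strideb = b.strides[0];
    p.stridec = c.strides[0];

    // F-contiguous base: unit stride along rows, column stride == shape[1].
    p.transA = (a.strides[1] == 1 && a.strides[2] == a.shape[1]);
    p.transB = (b.strides[1] == 1 && b.strides[2] == b.shape[1]);

    // Leading dimensions for row-major gemm.
    p.lda = p.transA ? p.m : p.k;
    p.ldb = p.transB ? p.k : p.n;
    p.ldc = p.n;
    return p;
}
```

For instance, an F-contiguous A base of shape (batch, m, k) has element strides (m*k, 1, m), so transA flips to T and lda becomes m, matching the `(transA == oneapi::mkl::transpose::N) ? k : m` ternary in the diff.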
@@ -252,10 +259,10 @@ std::pair<sycl::event, sycl::event>
     char *b_typeless_ptr = matrixB.get_data();
     char *r_typeless_ptr = resultC.get_data();

-    // Note that lda = k, ldb = n, and ld_result = n
-    sycl::event gemm_batch_ev = gemm_batch_fn(
-        exec_q, m, n, k, batch_size, k, n, n, stridea, strideb, stridec, transA,
-        transB, a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, depends);
+    sycl::event gemm_batch_ev =
+        gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
+                      strideb, stridec, transA, transB, a_typeless_ptr,
+                      b_typeless_ptr, r_typeless_ptr, depends);

     sycl::event args_batch_ev = dpctl::utils::keep_args_alive(
         exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
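As a concrete sanity check of the parameters fed to `gemm_batch_fn` above, assuming hypothetical C-contiguous inputs A(2, 3, 5), B(2, 5, 4) and C(2, 3, 4) (shapes chosen for illustration, not taken from the PR):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
    const std::int64_t m = 3, n = 4, k = 5;
    // Element strides of C-contiguous A(2, 3, 5) and B(2, 5, 4).
    const std::int64_t a_strides[3] = {15, 5, 1};
    const std::int64_t b_strides[3] = {20, 4, 1};
    // F-contiguity test from gemm_batch.cpp: unit stride along rows and
    // column stride equal to the row count of the 2-D base.
    const bool A_f_contig = (a_strides[1] == 1 && a_strides[2] == 3);
    const bool B_f_contig = (b_strides[1] == 1 && b_strides[2] == 5);
    assert(!A_f_contig && !B_f_contig); // both stay transpose::N
    const std::int64_t lda = A_f_contig ? m : k; // 5
    const std::int64_t ldb = B_f_contig ? k : n; // 4
    const std::int64_t ldc = n;                  // always n for row_major
    assert(lda == 5 && ldb == 4 && ldc == 4);
    // Batch strides are the outermost strides: 15, 20, and 12 for C(2, 3, 4).
    return 0;
}
```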