From 00b434fe8c5647d822c908621e3d83dae716df9b Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Sun, 18 Sep 2022 12:22:45 +0200 Subject: [PATCH 1/2] Remove some code duplication in bindings --- python_bindings/bindings.cpp | 126 ++++++++++++++--------------------- 1 file changed, 51 insertions(+), 75 deletions(-) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index fcb444da..d2444729 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -42,7 +42,7 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn while (true) { size_t id = current.fetch_add(1); - if ((id >= end)) { + if (id >= end) { break; } @@ -79,6 +79,45 @@ inline void assert_true(bool expr, const std::string & msg) { } +inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) { + if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); + if (buffer.ndim == 2) { + *rows = buffer.shape[0]; + *features = buffer.shape[1]; + } else { + *rows = 1; + *features = buffer.shape[0]; + } +} + + +inline std::vector get_input_ids_and_check_shapes(const py::object& ids_, size_t rows) { + std::vector ids; + if (!ids_.is_none()) { + py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); + auto ids_numpy = items.request(); + // check shapes + bool valid = false; + if ((ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) || (ids_numpy.ndim == 0 && rows == 1)) { + valid = true; + } + if (!valid) throw std::runtime_error("wrong dimensionality of the labels"); + // extract data + if (ids_numpy.ndim == 1) { + std::vector ids1(ids_numpy.shape[0]); + for (size_t i = 0; i < ids1.size(); i++) { + ids1[i] = items.data()[i]; + } + ids.swap(ids1); + } else if (ids_numpy.ndim == 0) { + ids.push_back(*items.data()); + } + } + + return ids; +} + + template class Index { public: @@ -146,7 +185,7 @@ class Index { void set_ef(size_t ef) { default_ef = ef; if (appr_alg) - appr_alg->ef_ = ef; + appr_alg->ef_ = ef; } @@ -188,15 +227,7 @@ class Index { num_threads = num_threads_default; size_t rows, features; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); if (features != dim) throw std::runtime_error("wrong dimensionality of the vectors"); @@ -206,23 +237,7 @@ class Index { num_threads = 1; } - std::vector ids; - - if (!ids_.is_none()) { - py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); - auto ids_numpy = items.request(); - if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); - } else if (ids_numpy.ndim == 0 && rows == 1) { - ids.push_back(*items.data()); - } else { - throw std::runtime_error("wrong dimensionality of the labels"); - } - } + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); { int start = 0; @@ -561,15 +576,7 @@ class Index { { py::gil_scoped_release l; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); // avoid using threads when the number of searches is small: if (rows <= num_threads * 4) { @@ -725,36 +732,12 @@ class BFIndex { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); size_t rows, features; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); if (features != dim) throw std::runtime_error("wrong dimensionality of the vectors"); - std::vector ids; - - if (!ids_.is_none()) { - py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); - auto ids_numpy = items.request(); - if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); - } else if (ids_numpy.ndim == 0 && rows == 1) { - ids.push_back(*items.data()); - } else { - throw std::runtime_error("wrong dimensionality of the labels"); - } - } + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); { for (size_t row = 0; row < rows; row++) { @@ -802,14 +785,7 @@ class BFIndex { { py::gil_scoped_release l; - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; @@ -836,14 +812,14 @@ class BFIndex { return py::make_tuple( py::array_t( - {rows, k}, // shape - {k * sizeof(hnswlib::labeltype), + { rows, k }, // shape + { k * sizeof(hnswlib::labeltype), sizeof(hnswlib::labeltype)}, // C-style contiguous strides for each index data_numpy_l, // the data pointer free_when_done_l), py::array_t( - {rows, k}, // shape - {k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for each index + { rows, k }, // shape + { k * sizeof(dist_t), sizeof(dist_t) }, // C-style contiguous strides for each index data_numpy_d, // the data pointer free_when_done_d)); } From 97e70098bc99ba2ff3ff5d44b11539dbb6b88257 Mon Sep 17 00:00:00 2001 From: Dmitry Yashunin Date: Mon, 19 Sep 2022 19:26:12 +0200 Subject: [PATCH 2/2] Refactoring --- python_bindings/bindings.cpp | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index d2444729..85751c0b 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -80,7 +80,13 @@ inline void assert_true(bool expr, const std::string & msg) { inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) { - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); + if (buffer.ndim != 2 && buffer.ndim != 1) { + char msg[256]; + snprintf(msg, sizeof(msg), + "Input vector data wrong shape. Number of dimensions %d. Data must be a 1D or 2D array.", + buffer.ndim); + throw std::runtime_error(msg); + } if (buffer.ndim == 2) { *rows = buffer.shape[0]; *features = buffer.shape[1]; @@ -91,17 +97,20 @@ inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, } -inline std::vector get_input_ids_and_check_shapes(const py::object& ids_, size_t rows) { +inline std::vector get_input_ids_and_check_shapes(const py::object& ids_, size_t feature_rows) { std::vector ids; if (!ids_.is_none()) { py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); auto ids_numpy = items.request(); // check shapes - bool valid = false; - if ((ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) || (ids_numpy.ndim == 0 && rows == 1)) { - valid = true; + if (!((ids_numpy.ndim == 1 && ids_numpy.shape[0] == feature_rows) || + (ids_numpy.ndim == 0 && feature_rows == 1))) { + char msg[256]; + snprintf(msg, sizeof(msg), + "The input label shape %d does not match the input data vector shape %d", + ids_numpy.ndim, feature_rows); + throw std::runtime_error(msg); } - if (!valid) throw std::runtime_error("wrong dimensionality of the labels"); // extract data if (ids_numpy.ndim == 1) { std::vector ids1(ids_numpy.shape[0]); @@ -230,7 +239,7 @@ class Index { get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("wrong dimensionality of the vectors"); + throw std::runtime_error("Wrong dimensionality of the vectors"); // avoid using threads when the number of additions is small: if (rows <= num_threads * 4) { @@ -518,7 +527,7 @@ class Index { for (size_t i = 0; i < appr_alg->cur_element_count; i++) { if (label_lookup_val_npy.data()[i] < 0) { - throw std::runtime_error("internal id cannot be negative!"); + throw std::runtime_error("Internal id cannot be negative!"); } else { appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i])); } @@ -735,7 +744,7 @@ class BFIndex { get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("wrong dimensionality of the vectors"); + throw std::runtime_error("Wrong dimensionality of the vectors"); std::vector ids = get_input_ids_and_check_shapes(ids_, rows);