@@ -42,7 +42,7 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
4242 while (true ) {
4343 size_t id = current.fetch_add (1 );
4444
45- if (( id >= end) ) {
45+ if (id >= end) {
4646 break ;
4747 }
4848
@@ -79,6 +79,45 @@ inline void assert_true(bool expr, const std::string & msg) {
7979}
8080
8181
82+ inline void get_input_array_shapes (const py::buffer_info& buffer, size_t * rows, size_t * features) {
83+ if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
84+ if (buffer.ndim == 2 ) {
85+ *rows = buffer.shape [0 ];
86+ *features = buffer.shape [1 ];
87+ } else {
88+ *rows = 1 ;
89+ *features = buffer.shape [0 ];
90+ }
91+ }
92+
93+
94+ inline std::vector<size_t > get_input_ids_and_check_shapes (const py::object& ids_, size_t rows) {
95+ std::vector<size_t > ids;
96+ if (!ids_.is_none ()) {
97+ py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
98+ auto ids_numpy = items.request ();
99+ // check shapes
100+ bool valid = false ;
101+ if ((ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) || (ids_numpy.ndim == 0 && rows == 1 )) {
102+ valid = true ;
103+ }
104+ if (!valid) throw std::runtime_error (" wrong dimensionality of the labels" );
105+ // extract data
106+ if (ids_numpy.ndim == 1 ) {
107+ std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
108+ for (size_t i = 0 ; i < ids1.size (); i++) {
109+ ids1[i] = items.data ()[i];
110+ }
111+ ids.swap (ids1);
112+ } else if (ids_numpy.ndim == 0 ) {
113+ ids.push_back (*items.data ());
114+ }
115+ }
116+
117+ return ids;
118+ }
119+
120+
82121template <typename dist_t , typename data_t = float >
83122class Index {
84123 public:
@@ -146,7 +185,7 @@ class Index {
146185 void set_ef (size_t ef) {
147186 default_ef = ef;
148187 if (appr_alg)
149- appr_alg->ef_ = ef;
188+ appr_alg->ef_ = ef;
150189 }
151190
152191
@@ -188,15 +227,7 @@ class Index {
188227 num_threads = num_threads_default;
189228
190229 size_t rows, features;
191-
192- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
193- if (buffer.ndim == 2 ) {
194- rows = buffer.shape [0 ];
195- features = buffer.shape [1 ];
196- } else {
197- rows = 1 ;
198- features = buffer.shape [0 ];
199- }
230+ get_input_array_shapes (buffer, &rows, &features);
200231
201232 if (features != dim)
202233 throw std::runtime_error (" wrong dimensionality of the vectors" );
@@ -206,23 +237,7 @@ class Index {
206237 num_threads = 1 ;
207238 }
208239
209- std::vector<size_t > ids;
210-
211- if (!ids_.is_none ()) {
212- py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
213- auto ids_numpy = items.request ();
214- if (ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) {
215- std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
216- for (size_t i = 0 ; i < ids1.size (); i++) {
217- ids1[i] = items.data ()[i];
218- }
219- ids.swap (ids1);
220- } else if (ids_numpy.ndim == 0 && rows == 1 ) {
221- ids.push_back (*items.data ());
222- } else {
223- throw std::runtime_error (" wrong dimensionality of the labels" );
224- }
225- }
240+ std::vector<size_t > ids = get_input_ids_and_check_shapes (ids_, rows);
226241
227242 {
228243 int start = 0 ;
@@ -561,15 +576,7 @@ class Index {
561576
562577 {
563578 py::gil_scoped_release l;
564-
565- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
566- if (buffer.ndim == 2 ) {
567- rows = buffer.shape [0 ];
568- features = buffer.shape [1 ];
569- } else {
570- rows = 1 ;
571- features = buffer.shape [0 ];
572- }
579+ get_input_array_shapes (buffer, &rows, &features);
573580
574581 // avoid using threads when the number of searches is small:
575582 if (rows <= num_threads * 4 ) {
@@ -725,36 +732,12 @@ class BFIndex {
725732 py::array_t < dist_t , py::array::c_style | py::array::forcecast > items (input);
726733 auto buffer = items.request ();
727734 size_t rows, features;
728-
729- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
730- if (buffer.ndim == 2 ) {
731- rows = buffer.shape [0 ];
732- features = buffer.shape [1 ];
733- } else {
734- rows = 1 ;
735- features = buffer.shape [0 ];
736- }
735+ get_input_array_shapes (buffer, &rows, &features);
737736
738737 if (features != dim)
739738 throw std::runtime_error (" wrong dimensionality of the vectors" );
740739
741- std::vector<size_t > ids;
742-
743- if (!ids_.is_none ()) {
744- py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
745- auto ids_numpy = items.request ();
746- if (ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) {
747- std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
748- for (size_t i = 0 ; i < ids1.size (); i++) {
749- ids1[i] = items.data ()[i];
750- }
751- ids.swap (ids1);
752- } else if (ids_numpy.ndim == 0 && rows == 1 ) {
753- ids.push_back (*items.data ());
754- } else {
755- throw std::runtime_error (" wrong dimensionality of the labels" );
756- }
757- }
740+ std::vector<size_t > ids = get_input_ids_and_check_shapes (ids_, rows);
758741
759742 {
760743 for (size_t row = 0 ; row < rows; row++) {
@@ -802,14 +785,7 @@ class BFIndex {
802785 {
803786 py::gil_scoped_release l;
804787
805- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
806- if (buffer.ndim == 2 ) {
807- rows = buffer.shape [0 ];
808- features = buffer.shape [1 ];
809- } else {
810- rows = 1 ;
811- features = buffer.shape [0 ];
812- }
788+ get_input_array_shapes (buffer, &rows, &features);
813789
814790 data_numpy_l = new hnswlib::labeltype[rows * k];
815791 data_numpy_d = new dist_t [rows * k];
@@ -836,14 +812,14 @@ class BFIndex {
836812
837813 return py::make_tuple (
838814 py::array_t <hnswlib::labeltype>(
839- {rows, k}, // shape
840- {k * sizeof (hnswlib::labeltype),
815+ { rows, k }, // shape
816+ { k * sizeof (hnswlib::labeltype),
841817 sizeof (hnswlib::labeltype)}, // C-style contiguous strides for each index
842818 data_numpy_l, // the data pointer
843819 free_when_done_l),
844820 py::array_t <dist_t >(
845- {rows, k}, // shape
846- {k * sizeof (dist_t ), sizeof (dist_t )}, // C-style contiguous strides for each index
821+ { rows, k }, // shape
822+ { k * sizeof (dist_t ), sizeof (dist_t ) }, // C-style contiguous strides for each index
847823 data_numpy_d, // the data pointer
848824 free_when_done_d));
849825 }
0 commit comments