11#include " ../hnswlib/hnswlib.h"
22#include < thread>
3+
4+
35class StopW
46{
57 std::chrono::steady_clock::time_point time_begin;
@@ -22,6 +24,7 @@ class StopW
2224 }
2325};
2426
27+
2528/*
2629 * replacement for the openmp '#pragma omp parallel for' directive
2730 * only handles a subset of functionality (no reductions etc)
@@ -81,8 +84,6 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
8184 std::rethrow_exception (lastException);
8285 }
8386 }
84-
85-
8687}
8788
8889
@@ -94,7 +95,7 @@ std::vector<datatype> load_batch(std::string path, int size)
9495 assert (sizeof (datatype) == 4 );
9596
9697 std::ifstream file;
97- file.open (path);
98+ file.open (path, std::ios::binary );
9899 if (!file.is_open ())
99100 {
100101 std::cout << " Cannot open " << path << " \n " ;
@@ -107,6 +108,7 @@ std::vector<datatype> load_batch(std::string path, int size)
107108 return batch;
108109}
109110
111+
110112template <typename d_type>
111113static float
112114test_approx (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
@@ -137,6 +139,7 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
137139 return 1 .0f * correct / total;
138140}
139141
142+
140143static void
141144test_vs_recall (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<float > &appr_alg, size_t vecdim,
142145 std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
@@ -155,6 +158,8 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
155158 efs.push_back (i);
156159 }
157160 std::cout << " ef\t recall\t time\t hops\t distcomp\n " ;
161+
162+ bool test_passed = false ;
158163 for (size_t ef : efs)
159164 {
160165 appr_alg.setEf (ef);
@@ -171,20 +176,24 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
171176 std::cout << ef << " \t " << recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
172177 if (recall > 0.99 )
173178 {
179+ test_passed = true ;
174180 std::cout << " Recall is over 0.99! " <<recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
175181 break ;
176182 }
177183 }
184+ if (!test_passed)
185+ {
186+ std::cerr << " Test failed\n " ;
187+ exit (1 );
188+ }
178189}
179190
191+
180192int main (int argc, char **argv)
181193{
182-
183194 int M = 16 ;
184195 int efConstruction = 200 ;
185196 int num_threads = std::thread::hardware_concurrency ();
186-
187-
188197
189198 bool update = false ;
190199
@@ -207,7 +216,6 @@ int main(int argc, char **argv)
207216
208217 std::string path = " ../examples/data/" ;
209218
210-
211219 int N;
212220 int dummy_data_multiplier;
213221 int N_queries;
@@ -240,7 +248,6 @@ int main(int argc, char **argv)
240248 if (update)
241249 {
242250 std::cout << " Update iteration 0\n " ;
243-
244251
245252 ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
246253 appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
@@ -295,4 +302,4 @@ int main(int argc, char **argv)
295302 }
296303
297304 return 0 ;
298- };
305+ };
0 commit comments