diff --git a/README.md b/README.md index 4c5d873..e3f0f59 100644 --- a/README.md +++ b/README.md @@ -35,15 +35,18 @@ Credits and links can be found in AUTHORS.md. # instantiate BloomFilter with custom settings, # max_elements is how many elements you expect the filter to hold. # error_rate defines accuracy; You can use defaults with + # key is a function mapping input objects to something the filter can take as input # `BloomFilter()` without any arguments. Following example # is same as defaults: bloom = BloomFilter(max_elements=10000, error_rate=0.1) - # Test whether the bloom-filter has seen a key: - assert "test-key" in bloom is False + # Test whether the bloom-filter has seen an object: + assert "test-obj" in bloom is False - # Mark the key as seen - bloom.add("test-key") + # Mark the object as seen + bloom.add("test-obj") # Now check again - assert "test-key" in bloom is True + assert "test-obj" in bloom is True + + # If you are using a key, make sure it's a one-to-one mapping diff --git a/src/bloom_filter/bloom_filter.py b/src/bloom_filter/bloom_filter.py index e16c952..7056c59 100644 --- a/src/bloom_filter/bloom_filter.py +++ b/src/bloom_filter/bloom_filter.py @@ -368,7 +368,6 @@ def __ior__(self, other): self.set(bitno) else: self.clear(bitno) - return self def close(self): @@ -510,7 +509,8 @@ def __init__(self, error_rate=0.1, probe_bitnoer=get_bitno_lin_comb, filename=None, - start_fresh=False): + start_fresh=False, + key=None): # pylint: disable=R0913 # R0913: We want a few arguments if max_elements <= 0: @@ -547,16 +547,20 @@ def __init__(self, real_num_probes_k = (self.num_bits_m / self.ideal_num_elements_n) * math.log(2) self.num_probes_k = int(math.ceil(real_num_probes_k)) self.probe_bitnoer = probe_bitnoer + self.key = key def __repr__(self): - return 'BloomFilter(ideal_num_elements_n=%d, error_rate_p=%f, num_bits_m=%d)' % ( + return 'BloomFilter(ideal_num_elements_n=%d, error_rate_p=%f, num_bits_m=%d, key=%s)' % ( self.ideal_num_elements_n, 
self.error_rate_p, self.num_bits_m, + str(self.key), ) def add(self, key): """Add an element to the filter""" + if self.key: + key = self.key(key) for bitno in self.probe_bitnoer(self, key): self.backend.set(bitno) @@ -587,6 +591,8 @@ def __iand__(self, bloom_filter): return self def __contains__(self, key): + if self.key: + key=self.key(key) for bitno in self.probe_bitnoer(self, key): if not self.backend.is_set(bitno): return False diff --git a/tests/test_bloom_filter.py b/tests/test_bloom_filter.py index a94508c..f526cec 100755 --- a/tests/test_bloom_filter.py +++ b/tests/test_bloom_filter.py @@ -18,6 +18,7 @@ import random +# if you are doing tests this imports from your pip, so it might not be the updated version import bloom_filter CHARACTERS = 'abcdefghijklmnopqrstuvwxyz1234567890' @@ -105,7 +106,6 @@ def _test(description, values, trials, error_rate, probe_bitnoer=None, filename= return all_good - class States(object): """Generate the USA's state names""" @@ -200,6 +200,42 @@ def length(self): """How many members?""" return int(math.ceil(self.maximum / 2.0)) +def key_test(): + """ Test that we can do everything with a key too """ + class Dummy(): + def __init__(self, biscuit): + self.biscuit = biscuit + def get(self): + return self.biscuit + def dummy_key(dummy): + return dummy.get() + + dummy_filter = bloom_filter.BloomFilter(key=dummy_key) + + dummies = [Dummy(3), Dummy('a'), Dummy('b')] + for dummy in dummies: + dummy_filter += dummy + for dummy in dummies: + if dummy not in dummy_filter: + return False # failed to insert + + other_dummy_filter = bloom_filter.BloomFilter(key=dummy_key) + other_dummy_filter += Dummy('hello world') + + dummy_filter |= other_dummy_filter + for dummy in dummies + [Dummy('hello world')]: + if dummy not in dummy_filter: + return False + + dummy_filter &= other_dummy_filter + for dummy in dummies: + if dummy in dummy_filter: + return False + if Dummy('hello world') not in dummy_filter: + return False + + return True + def 
and_test(): """Test the & operator""" @@ -304,6 +340,8 @@ def test_bloom_filter(): all_good &= or_test() + all_good &= key_test() + if performance_test: sqrt_of_10 = math.sqrt(10) # for exponent in range(5): # this is a lot, but probably not unreasonable @@ -346,7 +384,5 @@ def test_bloom_filter(): if not all_good: sys.stderr.write('%s: One or more tests failed\n' % sys.argv[0]) sys.exit(1) - - if __name__ == '__main__': test_bloom_filter()