diff --git a/cachecontrol/cache.py b/cachecontrol/cache.py
index 94e07732..8d86bcba 100644
--- a/cachecontrol/cache.py
+++ b/cachecontrol/cache.py
@@ -10,7 +10,7 @@ class BaseCache(object):
     def get(self, key):
         raise NotImplementedError()
 
-    def set(self, key, value):
+    def set(self, key, value, expires=None):
         raise NotImplementedError()
 
     def delete(self, key):
@@ -29,7 +29,7 @@ def __init__(self, init_dict=None):
     def get(self, key):
         return self.data.get(key, None)
 
-    def set(self, key, value):
+    def set(self, key, value, expires=None):
         with self.lock:
             self.data.update({key: value})
 
diff --git a/cachecontrol/caches/file_cache.py b/cachecontrol/caches/file_cache.py
index 607b9452..69ea2362 100644
--- a/cachecontrol/caches/file_cache.py
+++ b/cachecontrol/caches/file_cache.py
@@ -114,7 +114,8 @@ def get(self, key):
         except FileNotFoundError:
             return None
 
-    def set(self, key, value):
+    def set(self, key, value, expires=None):
+        # NOTE: `expires` is not used by this cache backend.
         name = self._fn(key)
 
         # Make sure the directory exists
diff --git a/cachecontrol/caches/redis_cache.py b/cachecontrol/caches/redis_cache.py
index 16da0aed..495a6d17 100644
--- a/cachecontrol/caches/redis_cache.py
+++ b/cachecontrol/caches/redis_cache.py
@@ -13,11 +13,7 @@ def get(self, key):
         return self.conn.get(key)
 
     def set(self, key, value, expires=None):
-        if not expires:
-            self.conn.set(key, value)
-        else:
-            expires = expires - datetime.utcnow()
-            self.conn.setex(key, int(expires.total_seconds()), value)
+        self.conn.set(key, value, ex=expires or None)
 
     def delete(self, key):
         self.conn.delete(key)
diff --git a/cachecontrol/controller.py b/cachecontrol/controller.py
index c5c4a508..1602922e 100644
--- a/cachecontrol/controller.py
+++ b/cachecontrol/controller.py
@@ -31,6 +31,9 @@ class CacheController(object):
     """An interface to see if request should cached or not.
     """
 
+    # 2 weeks max cache for backends that support TTL. Set to falsey for no maximum.
+    CACHE_TTL_MAX = 14*86400
+
     def __init__(
         self, cache=None, cache_etags=True, serializer=None, status_codes=None
     ):
@@ -306,7 +309,9 @@ def cache_response(self, request, response, body=None, status_codes=None):
         if self.cache_etags and "etag" in response_headers:
             logger.debug("Caching due to etag")
             self.cache.set(
-                cache_url, self.serializer.dumps(request, response, body=body)
+                cache_url,
+                self.serializer.dumps(request, response, body=body),
+                expires=self.get_cache_expiry(response_headers, cc),
             )
 
         # Add to the cache any 301s. We do this before looking that
@@ -323,7 +328,9 @@ def cache_response(self, request, response, body=None, status_codes=None):
             if "max-age" in cc and cc["max-age"] > 0:
                 logger.debug("Caching b/c date exists and max-age > 0")
                 self.cache.set(
-                    cache_url, self.serializer.dumps(request, response, body=body)
+                    cache_url,
+                    self.serializer.dumps(request, response, body=body),
+                    expires=self.get_cache_expiry(response_headers, cc),
                 )
 
             # If the request can expire, it means we should cache it
@@ -332,9 +339,34 @@ def cache_response(self, request, response, body=None, status_codes=None):
                 if response_headers["expires"]:
                     logger.debug("Caching b/c of expires header")
                     self.cache.set(
-                        cache_url, self.serializer.dumps(request, response, body=body)
+                        cache_url,
+                        self.serializer.dumps(request, response, body=body),
+                        expires=self.get_cache_expiry(response_headers, cc),
                     )
 
+    def get_cache_expiry(self, headers, cc):
+        """Derive a TTL in seconds for cache backends that support expiry.
+
+        ``max-age`` from the parsed Cache-Control directives ``cc`` takes
+        precedence; otherwise the TTL is ``Expires - Date`` from ``headers``.
+        The result is never negative and is capped at ``CACHE_TTL_MAX`` unless
+        that is falsey. ``0`` means no usable TTL could be derived.
+        """
+        ex = cc.get("max-age", 0)
+        if not ex and "expires" in headers:
+            try:
+                date = calendar.timegm(parsedate_tz(headers["date"]))
+                ex = calendar.timegm(parsedate_tz(headers["expires"])) - date
+            except (TypeError, KeyError):
+                # A missing Date header or an unparseable date yields no TTL.
+                pass
+        # An Expires earlier than Date (or a negative max-age) would produce a
+        # negative TTL, which Redis rejects as an invalid expire time; clamp it.
+        ex = max(ex, 0)
+        if self.CACHE_TTL_MAX:
+            return min(ex, self.CACHE_TTL_MAX)
+        return ex
+
     def update_cached_response(self, request, response):
         """On a 304 we will get a new set of headers that we want to
         update our cached value with, assuming we have one.
diff --git a/tests/test_cache_control.py b/tests/test_cache_control.py
index 7ede3713..8cf6e901 100644
--- a/tests/test_cache_control.py
+++ b/tests/test_cache_control.py
@@ -83,7 +83,7 @@ def test_cache_response_cache_max_age(self, cc):
         req = self.req()
         cc.cache_response(req, resp)
         cc.serializer.dumps.assert_called_with(req, resp, body=None)
-        cc.cache.set.assert_called_with(self.url, ANY)
+        cc.cache.set.assert_called_with(self.url, ANY, expires=3600)
 
     def test_cache_response_cache_max_age_with_invalid_value_not_cached(self, cc):
         now = time.strftime(TIME_FMT, time.gmtime())
diff --git a/tests/test_storage_redis.py b/tests/test_storage_redis.py
index d7b3afc1..7d327b66 100644
--- a/tests/test_storage_redis.py
+++ b/tests/test_storage_redis.py
@@ -11,5 +11,13 @@ def setup(self):
         self.cache = RedisCache(self.conn)
 
     def test_set_expiration(self):
-        self.cache.set("foo", "bar", expires=datetime(2014, 2, 2))
-        assert self.conn.setex.called
+        self.cache.set("foo", "bar", expires=3600)
+        self.conn.set.assert_called_with("foo", "bar", ex=3600)
+
+    def test_set_invalid_age(self):
+        """
+        Verify that expires=0 will not cause Redis to throw an error. It must be passed
+        as None or we receive: (ResponseError: invalid expire time in set)
+        """
+        self.cache.set("foo", "bar", expires=0)
+        self.conn.set.assert_called_with("foo", "bar", ex=None)