Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cachecontrol/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BaseCache(object):
def get(self, key):
raise NotImplementedError()

def set(self, key, value):
def set(self, key, value, expires=None):
raise NotImplementedError()

def delete(self, key):
Expand All @@ -29,7 +29,7 @@ def __init__(self, init_dict=None):
def get(self, key):
return self.data.get(key, None)

def set(self, key, value):
def set(self, key, value, expires=None):
with self.lock:
self.data.update({key: value})

Expand Down
3 changes: 2 additions & 1 deletion cachecontrol/caches/file_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def get(self, key):
except FileNotFoundError:
return None

def set(self, key, value):
def set(self, key, value, expires=None):
# NOTE: `expires` is not used by this cache backend.
name = self._fn(key)

# Make sure the directory exists
Expand Down
6 changes: 1 addition & 5 deletions cachecontrol/caches/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ def get(self, key):
return self.conn.get(key)

def set(self, key, value, expires=None):
if not expires:
self.conn.set(key, value)
else:
expires = expires - datetime.utcnow()
self.conn.setex(key, int(expires.total_seconds()), value)
self.conn.set(key, value, ex=expires or None)

def delete(self, key):
self.conn.delete(key)
Expand Down
30 changes: 27 additions & 3 deletions cachecontrol/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class CacheController(object):
"""An interface to see if request should cached or not.
"""

# 2 weeks max cache for backends that support TTL. Set to falsey for no maximum.
CACHE_TTL_MAX = 14*86400

def __init__(
self, cache=None, cache_etags=True, serializer=None, status_codes=None
):
Expand Down Expand Up @@ -306,7 +309,9 @@ def cache_response(self, request, response, body=None, status_codes=None):
if self.cache_etags and "etag" in response_headers:
logger.debug("Caching due to etag")
self.cache.set(
cache_url, self.serializer.dumps(request, response, body=body)
cache_url,
self.serializer.dumps(request, response, body=body),
expires=self.get_cache_expiry(response_headers, cc),
)

# Add to the cache any 301s. We do this before looking that
Expand All @@ -323,7 +328,9 @@ def cache_response(self, request, response, body=None, status_codes=None):
if "max-age" in cc and cc["max-age"] > 0:
logger.debug("Caching b/c date exists and max-age > 0")
self.cache.set(
cache_url, self.serializer.dumps(request, response, body=body)
cache_url,
self.serializer.dumps(request, response, body=body),
expires=self.get_cache_expiry(response_headers, cc),
)

# If the request can expire, it means we should cache it
Expand All @@ -332,9 +339,26 @@ def cache_response(self, request, response, body=None, status_codes=None):
if response_headers["expires"]:
logger.debug("Caching b/c of expires header")
self.cache.set(
cache_url, self.serializer.dumps(request, response, body=body)
cache_url,
self.serializer.dumps(request, response, body=body),
expires=self.get_cache_expiry(response_headers, cc),
)

def get_cache_expiry(self, headers, cc):
"""Derives a TTL from the response headers to pass to the caching backends that
support it.
"""
ex = cc.get("max-age", 0)
if not ex and "expires" in headers:
try:
date = calendar.timegm(parsedate_tz(headers["date"]))
ex = calendar.timegm(parsedate_tz(headers["expires"])) - date
except (TypeError, KeyError):
pass
if self.CACHE_TTL_MAX:
return min(ex, self.CACHE_TTL_MAX)
return ex

def update_cached_response(self, request, response):
"""On a 304 we will get a new set of headers that we want to
update our cached value with, assuming we have one.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cache_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_cache_response_cache_max_age(self, cc):
req = self.req()
cc.cache_response(req, resp)
cc.serializer.dumps.assert_called_with(req, resp, body=None)
cc.cache.set.assert_called_with(self.url, ANY)
cc.cache.set.assert_called_with(self.url, ANY, expires=3600)

def test_cache_response_cache_max_age_with_invalid_value_not_cached(self, cc):
now = time.strftime(TIME_FMT, time.gmtime())
Expand Down
12 changes: 10 additions & 2 deletions tests/test_storage_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,13 @@ def setup(self):
self.cache = RedisCache(self.conn)

def test_set_expiration(self):
self.cache.set("foo", "bar", expires=datetime(2014, 2, 2))
assert self.conn.setex.called
self.cache.set("foo", "bar", expires=3600)
self.conn.set.assert_called_with("foo", "bar", ex=3600)

def test_set_invalid_age(self):
"""
Verify that expires=0 will not cause Redis to throw an error. It must be passed
as None or we receive: ResponseError: invalid expire time in set)
"""
self.cache.set("foo", "bar", expires=0)
self.conn.set.assert_called_with("foo", "bar", ex=None)