diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 5e0c4a54..bf429e28 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -34,7 +34,10 @@ jobs:
         python -VV
         python -m site
         python -m pip install --upgrade pip setuptools wheel
-        python -m pip install --upgrade virtualenv tox tox-gh-actions
+        python -m pip install --upgrade virtualenv tox tox-gh-actions

     - name: "Run tox targets for ${{ matrix.python-version }}"
       run: "python -m tox"
+
+    - name: "Run mypy for ${{ matrix.python-version }}"
+      run: "python -m tox -e mypy"
diff --git a/cachecontrol/__init__.py b/cachecontrol/__init__.py
index f631ae6d..9a13eed8 100644
--- a/cachecontrol/__init__.py
+++ b/cachecontrol/__init__.py
@@ -10,9 +10,19 @@
 __email__ = "eric@ionrock.org"
 __version__ = "0.12.11"

-from .wrapper import CacheControl
 from .adapter import CacheControlAdapter
 from .controller import CacheController
+from .wrapper import CacheControl
+
+__all__ = [
+    "__author__",
+    "__email__",
+    "__version__",
+    "CacheControlAdapter",
+    "CacheController",
+    "CacheControl",
+]

 import logging
+
 logging.getLogger(__name__).addHandler(logging.NullHandler())
diff --git a/cachecontrol/_cmd.py b/cachecontrol/_cmd.py
index ccee0079..bf3bbac8 100644
--- a/cachecontrol/_cmd.py
+++ b/cachecontrol/_cmd.py
@@ -3,6 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
+from argparse import ArgumentParser
+from typing import TYPE_CHECKING

 import requests

@@ -10,16 +12,19 @@
 from cachecontrol.cache import DictCache
 from cachecontrol.controller import logger

-from argparse import ArgumentParser
+if TYPE_CHECKING:
+    from argparse import Namespace
+    from .controller import CacheController

-def setup_logging():
+
+def setup_logging() -> None:
     logger.setLevel(logging.DEBUG)
     handler = logging.StreamHandler()
     logger.addHandler(handler)


-def get_session():
+def get_session() -> requests.Session:
     adapter = CacheControlAdapter(
         DictCache(), cache_etags=True, serializer=None, heuristic=None
     )
@@ -27,17 +32,17 @@
     sess.mount("http://", adapter)
     sess.mount("https://", adapter)

-    sess.cache_controller = adapter.controller
+    sess.cache_controller = adapter.controller  # type: ignore[attr-defined]

     return sess


-def get_args():
+def get_args() -> "Namespace":
     parser = ArgumentParser()
     parser.add_argument("url", help="The URL to try and cache")
     return parser.parse_args()


-def main(args=None):
+def main() -> None:
     args = get_args()
     sess = get_session()

@@ -48,10 +53,13 @@
     setup_logging()

     # try setting the cache
-    sess.cache_controller.cache_response(resp.request, resp.raw)
+    cache_controller: "CacheController" = (
+        sess.cache_controller  # type: ignore[attr-defined]
+    )
+    cache_controller.cache_response(resp.request, resp.raw)

     # Now try to get it
-    if sess.cache_controller.cached_request(resp.request):
+    if cache_controller.cached_request(resp.request):
         print("Cached!")
     else:
         print("Not cached :(")
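Note (not part of the patch): the `get_session()` helper above doubles as a usage recipe. The same pattern in downstream code would look roughly like this minimal sketch, where the URL is a placeholder:

    import requests

    from cachecontrol.adapter import CacheControlAdapter
    from cachecontrol.cache import DictCache

    # Route all HTTP(S) traffic through the caching adapter, as get_session() does.
    adapter = CacheControlAdapter(DictCache(), cache_etags=True)
    sess = requests.Session()
    sess.mount("http://", adapter)
    sess.mount("https://", adapter)

    resp = sess.get("https://example.com/")  # placeholder URL
    print(resp.from_cache)  # attribute the adapter attaches; False on a first fetch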
diff --git a/cachecontrol/adapter.py b/cachecontrol/adapter.py
index 22b49638..7ecec2ef 100644
--- a/cachecontrol/adapter.py
+++ b/cachecontrol/adapter.py
@@ -2,31 +2,40 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-import types
 import functools
+import types
 import zlib
+from typing import TYPE_CHECKING, Any, Collection, Mapping, Optional, Tuple, Type, Union

 from requests.adapters import HTTPAdapter

-from .controller import CacheController, PERMANENT_REDIRECT_STATUSES
 from .cache import DictCache
+from .controller import PERMANENT_REDIRECT_STATUSES, CacheController
 from .filewrapper import CallbackFileWrapper

+if TYPE_CHECKING:
+    from requests import PreparedRequest, Response
+
+    from .cache import BaseCache
+    from .compat import HTTPResponse
+    from .heuristics import BaseHeuristic
+    from .serialize import Serializer
+

 class CacheControlAdapter(HTTPAdapter):
     invalidating_methods = {"PUT", "PATCH", "DELETE"}

     def __init__(
         self,
-        cache=None,
-        cache_etags=True,
-        controller_class=None,
-        serializer=None,
-        heuristic=None,
-        cacheable_methods=None,
-        *args,
-        **kw
-    ):
+        cache: Optional["BaseCache"] = None,
+        cache_etags: bool = True,
+        controller_class: Optional[Type[CacheController]] = None,
+        serializer: Optional["Serializer"] = None,
+        heuristic: Optional["BaseHeuristic"] = None,
+        cacheable_methods: Optional[Collection[str]] = None,
+        *args: Any,
+        **kw: Any,
+    ) -> None:
         super(CacheControlAdapter, self).__init__(*args, **kw)
         self.cache = DictCache() if cache is None else cache
         self.heuristic = heuristic
@@ -37,7 +46,18 @@ def __init__(
             self.cache, cache_etags=cache_etags, serializer=serializer
         )

-    def send(self, request, cacheable_methods=None, **kw):
+    def send(
+        self,
+        request: "PreparedRequest",
+        stream: bool = False,
+        timeout: Union[None, float, Tuple[float, float], Tuple[float, None]] = None,
+        verify: Union[bool, str] = True,
+        cert: Union[
+            None, bytes, str, Tuple[Union[bytes, str], Union[bytes, str]]
+        ] = None,
+        proxies: Optional[Mapping[str, str]] = None,
+        cacheable_methods: Optional[Collection[str]] = None,
+    ) -> "Response":
         """
         Send a request. Use the request information to see if it
         exists in the cache and cache the response if we need to and can.
@@ -54,13 +74,19 @@
             # check for etags and add headers if appropriate
             request.headers.update(self.controller.conditional_headers(request))

-        resp = super(CacheControlAdapter, self).send(request, **kw)
+        resp = super(CacheControlAdapter, self).send(
+            request, stream, timeout, verify, cert, proxies
+        )

         return resp

     def build_response(
-        self, request, response, from_cache=False, cacheable_methods=None
-    ):
+        self,
+        request: "PreparedRequest",
+        response: "HTTPResponse",
+        from_cache: bool = False,
+        cacheable_methods: Optional[Collection[str]] = None,
+    ) -> "Response":
         """
         Build a response by making a request or using the cache.

@@ -111,7 +137,7 @@
             if response.chunked:
                 super_update_chunk_length = response._update_chunk_length

-                def _update_chunk_length(self):
+                def _update_chunk_length(self: "HTTPResponse") -> None:
                     super_update_chunk_length()
                     if self.chunk_left == 0:
                         self._fp._close()
@@ -120,18 +146,21 @@
                     _update_chunk_length, response
                 )

-        resp = super(CacheControlAdapter, self).build_response(request, response)
+        resp: "Response" = super(  # type: ignore[no-untyped-call]
+            CacheControlAdapter, self
+        ).build_response(request, response)

         # See if we should invalidate the cache.
         if request.method in self.invalidating_methods and resp.ok:
+            assert request.url is not None
             cache_url = self.controller.cache_url(request.url)
             self.cache.delete(cache_url)

         # Give the request a from_cache attr to let people use it
-        resp.from_cache = from_cache
+        resp.from_cache = from_cache  # type: ignore[attr-defined]

         return resp

-    def close(self):
+    def close(self) -> None:
         self.cache.close()
-        super(CacheControlAdapter, self).close()
+        super(CacheControlAdapter, self).close()  # type: ignore[no-untyped-call]
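Aside: the `if TYPE_CHECKING:` block used here (and in most files below) is the standard idiom for annotation-only imports — the names exist for mypy but are never imported at runtime, avoiding both import cycles between `adapter`, `controller`, and `serialize` and any hard runtime dependency. A minimal sketch of the idiom:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by the type checker, never at runtime.
        from requests import Response

    def describe(resp: "Response") -> str:
        # The quoted annotation keeps this valid even though Response
        # is not actually imported when the module runs.
        return f"{resp.status_code} {resp.reason}"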
diff --git a/cachecontrol/cache.py b/cachecontrol/cache.py
index 2a965f59..61031d23 100644
--- a/cachecontrol/cache.py
+++ b/cachecontrol/cache.py
@@ -7,37 +7,43 @@
 safe in-memory dictionary.
 """
 from threading import Lock
+from typing import IO, TYPE_CHECKING, MutableMapping, Optional, Union

+if TYPE_CHECKING:
+    from datetime import datetime


-class BaseCache(object):
-
-    def get(self, key):
+class BaseCache(object):
+    def get(self, key: str) -> Optional[bytes]:
         raise NotImplementedError()

-    def set(self, key, value, expires=None):
+    def set(
+        self, key: str, value: bytes, expires: Optional[Union[int, "datetime"]] = None
+    ) -> None:
         raise NotImplementedError()

-    def delete(self, key):
+    def delete(self, key: str) -> None:
         raise NotImplementedError()

-    def close(self):
+    def close(self) -> None:
         pass


 class DictCache(BaseCache):
-
-    def __init__(self, init_dict=None):
+    def __init__(self, init_dict: Optional[MutableMapping[str, bytes]] = None) -> None:
         self.lock = Lock()
         self.data = init_dict or {}

-    def get(self, key):
+    def get(self, key: str) -> Optional[bytes]:
         return self.data.get(key, None)

-    def set(self, key, value, expires=None):
+    def set(
+        self, key: str, value: bytes, expires: Optional[Union[int, "datetime"]] = None
+    ) -> None:
         with self.lock:
             self.data.update({key: value})

-    def delete(self, key):
+    def delete(self, key: str) -> None:
         with self.lock:
             if key in self.data:
                 self.data.pop(key)
@@ -55,10 +61,11 @@
     Similarly, the body should be loaded separately via ``get_body()``.
     """

-    def set_body(self, key, body):
+
+    def set_body(self, key: str, body: bytes) -> None:
         raise NotImplementedError()

-    def get_body(self, key):
+    def get_body(self, key: str) -> Optional["IO[bytes]"]:
         """
         Return the body as file-like object.
         """
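With `BaseCache` now fully typed, third-party backends can be checked against the interface. A hypothetical in-memory backend that honors the `expires` hint (interpreting it the way `RedisCache` below does: seconds as an int, or an absolute datetime) might look like this — illustrative only, not part of the library:

    import time
    from datetime import datetime
    from typing import Dict, Optional, Tuple, Union

    from cachecontrol.cache import BaseCache

    class ExpiringDictCache(BaseCache):
        """Hypothetical backend storing (deadline, value) pairs."""

        def __init__(self) -> None:
            self._data: Dict[str, Tuple[Optional[float], bytes]] = {}

        def get(self, key: str) -> Optional[bytes]:
            entry = self._data.get(key)
            if entry is None:
                return None
            deadline, value = entry
            if deadline is not None and time.time() > deadline:
                # Entry has expired; drop it and report a miss.
                del self._data[key]
                return None
            return value

        def set(
            self, key: str, value: bytes, expires: Optional[Union[int, datetime]] = None
        ) -> None:
            deadline: Optional[float] = None
            if isinstance(expires, datetime):
                deadline = time.time() + (expires - datetime.utcnow()).total_seconds()
            elif expires is not None:
                deadline = time.time() + expires
            self._data[key] = (deadline, value)

        def delete(self, key: str) -> None:
            self._data.pop(key, None)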
""" diff --git a/cachecontrol/caches/file_cache.py b/cachecontrol/caches/file_cache.py index f82e7b85..4a6c4d7c 100644 --- a/cachecontrol/caches/file_cache.py +++ b/cachecontrol/caches/file_cache.py @@ -5,10 +5,17 @@ import hashlib import os from textwrap import dedent +from typing import IO, TYPE_CHECKING, Optional, Type, Union from ..cache import BaseCache, SeparateBodyBaseCache from ..controller import CacheController +if TYPE_CHECKING: + from datetime import datetime + + from filelock import BaseFileLock + + try: FileNotFoundError except NameError: @@ -16,7 +23,7 @@ FileNotFoundError = (IOError, OSError) -def _secure_open_write(filename, fmode): +def _secure_open_write(filename: str, fmode: int) -> "IO[bytes]": # We only want to write to this file, so open it in write only mode flags = os.O_WRONLY @@ -62,16 +69,16 @@ class _FileCacheMixin: def __init__( self, - directory, - forever=False, - filemode=0o0600, - dirmode=0o0700, - lock_class=None, - ): - + directory: str, + forever: bool = False, + filemode: int = 0o0600, + dirmode: int = 0o0700, + lock_class: Optional[Type["BaseFileLock"]] = None, + ) -> None: try: if lock_class is None: from filelock import FileLock + lock_class = FileLock except ImportError: notice = dedent( @@ -90,17 +97,17 @@ def __init__( self.lock_class = lock_class @staticmethod - def encode(x): + def encode(x: str) -> str: return hashlib.sha224(x.encode()).hexdigest() - def _fn(self, name): + def _fn(self, name: str) -> str: # NOTE: This method should not change as some may depend on it. # See: https://github.com/ionrock/cachecontrol/issues/63 hashed = self.encode(name) parts = list(hashed[:5]) + [hashed] return os.path.join(self.directory, *parts) - def get(self, key): + def get(self, key: str) -> Optional[bytes]: name = self._fn(key) try: with open(name, "rb") as fh: @@ -109,11 +116,13 @@ def get(self, key): except FileNotFoundError: return None - def set(self, key, value, expires=None): + def set( + self, key: str, value: bytes, expires: Optional[Union[int, "datetime"]] = None + ) -> None: name = self._fn(key) self._write(name, value) - def _write(self, path, data: bytes): + def _write(self, path: str, data: bytes) -> None: """ Safely write the data to the given path. """ @@ -128,7 +137,7 @@ def _write(self, path, data: bytes): with _secure_open_write(path, self.filemode) as fh: fh.write(data) - def _delete(self, key, suffix): + def _delete(self, key: str, suffix: str) -> None: name = self._fn(key) + suffix if not self.forever: try: @@ -143,7 +152,7 @@ class FileCache(_FileCacheMixin, BaseCache): downloads. """ - def delete(self, key): + def delete(self, key: str) -> None: self._delete(key, "") @@ -153,23 +162,23 @@ class SeparateBodyFileCache(_FileCacheMixin, SeparateBodyBaseCache): peak memory usage. """ - def get_body(self, key): + def get_body(self, key: str) -> Optional["IO[bytes]"]: name = self._fn(key) + ".body" try: return open(name, "rb") except FileNotFoundError: return None - def set_body(self, key, body): + def set_body(self, key: str, body: bytes) -> None: name = self._fn(key) + ".body" self._write(name, body) - def delete(self, key): + def delete(self, key: str) -> None: self._delete(key, "") self._delete(key, ".body") -def url_to_file_path(url, filecache): +def url_to_file_path(url: str, filecache: FileCache) -> str: """Return the file cache path based on the URL. This does not ensure the file exists! 
diff --git a/cachecontrol/caches/redis_cache.py b/cachecontrol/caches/redis_cache.py
index 7bcb38a2..98e9f3e3 100644
--- a/cachecontrol/caches/redis_cache.py
+++ b/cachecontrol/caches/redis_cache.py
@@ -5,35 +5,41 @@
 from __future__ import division

 from datetime import datetime
+from typing import TYPE_CHECKING, Optional, Union
+
 from cachecontrol.cache import BaseCache

+if TYPE_CHECKING:
+    from redis import Redis


-class RedisCache(BaseCache):
-
-    def __init__(self, conn):
+class RedisCache(BaseCache):
+    def __init__(self, conn: "Redis[bytes]") -> None:
         self.conn = conn

-    def get(self, key):
+    def get(self, key: str) -> Optional[bytes]:
         return self.conn.get(key)

-    def set(self, key, value, expires=None):
+    def set(
+        self, key: str, value: bytes, expires: Optional[Union[int, datetime]] = None
+    ) -> None:
         if not expires:
             self.conn.set(key, value)
         elif isinstance(expires, datetime):
-            expires = expires - datetime.utcnow()
-            self.conn.setex(key, int(expires.total_seconds()), value)
+            delta = expires - datetime.utcnow()
+            self.conn.setex(key, int(delta.total_seconds()), value)
         else:
             self.conn.setex(key, expires, value)

-    def delete(self, key):
+    def delete(self, key: str) -> None:
         self.conn.delete(key)

-    def clear(self):
+    def clear(self) -> None:
         """Helper for clearing all the keys in a database. Use with caution!"""
         for key in self.conn.keys():
             self.conn.delete(key)

-    def close(self):
+    def close(self) -> None:
        """Redis uses connection pooling, no need to close the connection."""
        pass
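And the equivalent wiring for the Redis backend (host and port are placeholders for a local Redis instance; `redis` and `types-redis` provide the runtime client and its stubs):

    import redis
    import requests

    from cachecontrol import CacheControl
    from cachecontrol.caches.redis_cache import RedisCache

    pool = redis.ConnectionPool(host="localhost", port=6379)
    sess = CacheControl(
        requests.Session(), cache=RedisCache(redis.Redis(connection_pool=pool))
    )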
diff --git a/cachecontrol/compat.py b/cachecontrol/compat.py
index 72c456cf..e5be1563 100644
--- a/cachecontrol/compat.py
+++ b/cachecontrol/compat.py
@@ -30,3 +30,5 @@
     text_type = unicode
 except NameError:
     text_type = str
+
+__all__ = ["urljoin", "pickle", "HTTPResponse", "is_fp_closed", "text_type"]
diff --git a/cachecontrol/controller.py b/cachecontrol/controller.py
index 184fe667..dc71fccb 100644
--- a/cachecontrol/controller.py
+++ b/cachecontrol/controller.py
@@ -5,17 +5,23 @@
 """
 The httplib2 algorithms ported for use with requests.
 """
+import calendar
 import logging
 import re
-import calendar
 import time
 from email.utils import parsedate_tz
+from typing import TYPE_CHECKING, Collection, Dict, Mapping, Optional, Tuple, Union

 from requests.structures import CaseInsensitiveDict

 from .cache import DictCache, SeparateBodyBaseCache
 from .serialize import Serializer

+if TYPE_CHECKING:
+    from requests import PreparedRequest
+
+    from .cache import BaseCache
+    from .compat import HTTPResponse

 logger = logging.getLogger(__name__)

@@ -24,12 +30,14 @@
 PERMANENT_REDIRECT_STATUSES = (301, 308)


-def parse_uri(uri):
+def parse_uri(uri: str) -> Tuple[str, str, str, str, str]:
     """Parses a URI using the regex given in Appendix B of RFC 3986.

     (scheme, authority, path, query, fragment) = parse_uri(uri)
     """
-    groups = URI.match(uri).groups()
+    match = URI.match(uri)
+    assert match is not None
+    groups = match.groups()
     return (groups[1], groups[3], groups[4], groups[6], groups[8])


@@ -37,7 +45,11 @@
 class CacheController(object):
     """An interface to see if request should cached or not."""

     def __init__(
-        self, cache=None, cache_etags=True, serializer=None, status_codes=None
+        self,
+        cache: Optional["BaseCache"] = None,
+        cache_etags: bool = True,
+        serializer: Optional[Serializer] = None,
+        status_codes: Optional[Collection[int]] = None,
     ):
         self.cache = DictCache() if cache is None else cache
         self.cache_etags = cache_etags
@@ -45,7 +57,7 @@
         self.cacheable_status_codes = status_codes or (200, 203, 300, 301, 308)

     @classmethod
-    def _urlnorm(cls, uri):
+    def _urlnorm(cls, uri: str) -> str:
         """Normalize the URL to create a safe key for the cache"""
         (scheme, authority, path, query, fragment) = parse_uri(uri)
         if not scheme or not authority:
@@ -65,10 +77,12 @@
         return defrag_uri

     @classmethod
-    def cache_url(cls, uri):
+    def cache_url(cls, uri: str) -> str:
         return cls._urlnorm(uri)

-    def parse_cache_control(self, headers):
+    def parse_cache_control(
+        self, headers: Mapping[str, str]
+    ) -> Dict[str, Optional[int]]:
         known_directives = {
             # https://tools.ietf.org/html/rfc7234#section-5.2
             "max-age": (int, True),
@@ -87,7 +101,7 @@
         cc_headers = headers.get("cache-control", headers.get("Cache-Control", ""))

-        retval = {}
+        retval: Dict[str, Optional[int]] = {}

         for cc_directive in cc_headers.split(","):
             if not cc_directive.strip():
@@ -122,11 +136,12 @@
         return retval

-    def _load_from_cache(self, request):
+    def _load_from_cache(self, request: "PreparedRequest") -> Optional["HTTPResponse"]:
         """
         Load a cached response, or return None if it's not available.
         """
         cache_url = request.url
+        assert cache_url is not None
         cache_data = self.cache.get(cache_url)
         if cache_data is None:
             logger.debug("No cache entry available")
""" + assert request.url is not None cache_url = self.cache_url(request.url) logger.debug('Looking up "%s" in the cache', cache_url) cc = self.parse_cache_control(request.headers) @@ -182,7 +198,7 @@ def cached_request(self, request): logger.debug(msg) return resp - headers = CaseInsensitiveDict(resp.headers) + headers: CaseInsensitiveDict[str] = CaseInsensitiveDict(resp.headers) if not headers or "date" not in headers: if "etag" not in headers: # Without date or etag, the cached response can never be used @@ -193,7 +209,9 @@ def cached_request(self, request): return False now = time.time() - date = calendar.timegm(parsedate_tz(headers["date"])) + time_tuple = parsedate_tz(headers["date"]) + assert time_tuple is not None + date = calendar.timegm(time_tuple[:6]) current_age = max(0, now - date) logger.debug("Current age based on date: %i", current_age) @@ -207,28 +225,30 @@ def cached_request(self, request): freshness_lifetime = 0 # Check the max-age pragma in the cache control header - if "max-age" in resp_cc: - freshness_lifetime = resp_cc["max-age"] + max_age = resp_cc.get("max-age") + if max_age is not None: + freshness_lifetime = max_age logger.debug("Freshness lifetime from max-age: %i", freshness_lifetime) # If there isn't a max-age, check for an expires header elif "expires" in headers: expires = parsedate_tz(headers["expires"]) if expires is not None: - expire_time = calendar.timegm(expires) - date + expire_time = calendar.timegm(expires[:6]) - date freshness_lifetime = max(0, expire_time) logger.debug("Freshness lifetime from expires: %i", freshness_lifetime) # Determine if we are setting freshness limit in the # request. Note, this overrides what was in the response. - if "max-age" in cc: - freshness_lifetime = cc["max-age"] + max_age = cc.get("max-age") + if max_age is not None: + freshness_lifetime = max_age logger.debug( "Freshness lifetime from request max-age: %i", freshness_lifetime ) - if "min-fresh" in cc: - min_fresh = cc["min-fresh"] + min_fresh = cc.get("min-fresh") + if min_fresh is not None: # adjust our current age by our min fresh current_age += min_fresh logger.debug("Adjusted current age from min-fresh: %i", current_age) @@ -247,12 +267,12 @@ def cached_request(self, request): # return the original handler return False - def conditional_headers(self, request): + def conditional_headers(self, request: "PreparedRequest") -> Dict[str, str]: resp = self._load_from_cache(request) new_headers = {} if resp: - headers = CaseInsensitiveDict(resp.headers) + headers: CaseInsensitiveDict[str] = CaseInsensitiveDict(resp.headers) if "etag" in headers: new_headers["If-None-Match"] = headers["ETag"] @@ -262,7 +282,14 @@ def conditional_headers(self, request): return new_headers - def _cache_set(self, cache_url, request, response, body=None, expires_time=None): + def _cache_set( + self, + cache_url: str, + request: "PreparedRequest", + response: "HTTPResponse", + body: Optional[bytes] = None, + expires_time: Optional[int] = None, + ) -> None: """ Store the data in the cache. """ @@ -285,7 +312,13 @@ def _cache_set(self, cache_url, request, response, body=None, expires_time=None) expires=expires_time, ) - def cache_response(self, request, response, body=None, status_codes=None): + def cache_response( + self, + request: "PreparedRequest", + response: "HTTPResponse", + body: Optional[bytes] = None, + status_codes: Optional[Collection[int]] = None, + ) -> None: """ Algorithm for caching requests. 
diff --git a/cachecontrol/filewrapper.py b/cachecontrol/filewrapper.py
index f5ed5f6f..472ba600 100644
--- a/cachecontrol/filewrapper.py
+++ b/cachecontrol/filewrapper.py
@@ -2,8 +2,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from tempfile import NamedTemporaryFile
 import mmap
+from tempfile import NamedTemporaryFile
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+if TYPE_CHECKING:
+    from http.client import HTTPResponse


 class CallbackFileWrapper(object):
@@ -25,12 +29,14 @@
     performance impact.
     """

-    def __init__(self, fp, callback):
+    def __init__(
+        self, fp: "HTTPResponse", callback: Optional[Callable[[bytes], None]]
+    ) -> None:
         self.__buf = NamedTemporaryFile("rb+", delete=True)
         self.__fp = fp
         self.__callback = callback

-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         # The vaguaries of garbage collection means that self.__fp is
         # not always set. By using __getattribute__ and the private
         # name[0] allows looking up the attribute value and raising an
@@ -42,7 +48,7 @@
         fp = self.__getattribute__("_CallbackFileWrapper__fp")
         return getattr(fp, name)

-    def __is_fp_closed(self):
+    def __is_fp_closed(self) -> bool:
         try:
             return self.__fp.fp is None

@@ -50,7 +56,8 @@
             pass

         try:
-            return self.__fp.closed
+            closed: bool = self.__fp.closed
+            return closed

         except AttributeError:
             pass
@@ -59,7 +66,7 @@
             # TODO: Add some logging here...
             return False

-    def _close(self):
+    def _close(self) -> None:
         if self.__callback:
             if self.__buf.tell() == 0:
                 # Empty file:
@@ -86,8 +93,8 @@
             # Important when caching big files.
             self.__buf.close()

-    def read(self, amt=None):
-        data = self.__fp.read(amt)
+    def read(self, amt: Optional[int] = None) -> bytes:
+        data: bytes = self.__fp.read(amt)
         if data:
             # We may be dealing with b'', a sign that things are over:
             # it's passed e.g. after we've already closed self.__buf.
@@ -97,8 +104,8 @@
         return data

-    def _safe_read(self, amt):
-        data = self.__fp._safe_read(amt)
+    def _safe_read(self, amt: int) -> bytes:
+        data: bytes = self.__fp._safe_read(amt)  # type: ignore[attr-defined]
         if amt == 2 and data == b"\r\n":
             # urllib executes this read to toss the CRLF at the end
             # of the chunk.
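A note on how `CallbackFileWrapper` is exercised: the adapter wraps the raw body in this class and supplies the cache write as the callback, so a streamed response is only stored after the client drains it. Roughly, from the user's side (a sketch; the URL is a placeholder):

    import requests
    from cachecontrol import CacheControl

    sess = CacheControl(requests.Session())
    resp = sess.get("https://example.com/", stream=True)
    # Nothing is cached yet; the callback fires once the wrapped fp is
    # exhausted, i.e. after the body has been read to completion.
    body = resp.raw.read()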
""" - def __init__(self, fp, callback): + def __init__( + self, fp: "HTTPResponse", callback: Optional[Callable[[bytes], None]] + ) -> None: self.__buf = NamedTemporaryFile("rb+", delete=True) self.__fp = fp self.__callback = callback - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # The vaguaries of garbage collection means that self.__fp is # not always set. By using __getattribute__ and the private # name[0] allows looking up the attribute value and raising an @@ -42,7 +48,7 @@ def __getattr__(self, name): fp = self.__getattribute__("_CallbackFileWrapper__fp") return getattr(fp, name) - def __is_fp_closed(self): + def __is_fp_closed(self) -> bool: try: return self.__fp.fp is None @@ -50,7 +56,8 @@ def __is_fp_closed(self): pass try: - return self.__fp.closed + closed: bool = self.__fp.closed + return closed except AttributeError: pass @@ -59,7 +66,7 @@ def __is_fp_closed(self): # TODO: Add some logging here... return False - def _close(self): + def _close(self) -> None: if self.__callback: if self.__buf.tell() == 0: # Empty file: @@ -86,8 +93,8 @@ def _close(self): # Important when caching big files. self.__buf.close() - def read(self, amt=None): - data = self.__fp.read(amt) + def read(self, amt: Optional[int] = None) -> bytes: + data: bytes = self.__fp.read(amt) if data: # We may be dealing with b'', a sign that things are over: # it's passed e.g. after we've already closed self.__buf. @@ -97,8 +104,8 @@ def read(self, amt=None): return data - def _safe_read(self, amt): - data = self.__fp._safe_read(amt) + def _safe_read(self, amt: int) -> bytes: + data: bytes = self.__fp._safe_read(amt) # type: ignore[attr-defined] if amt == 2 and data == b"\r\n": # urllib executes this read to toss the CRLF at the end # of the chunk. diff --git a/cachecontrol/heuristics.py b/cachecontrol/heuristics.py index ebe4a96f..12aded82 100644 --- a/cachecontrol/heuristics.py +++ b/cachecontrol/heuristics.py @@ -4,26 +4,27 @@ import calendar import time - +from datetime import datetime, timedelta from email.utils import formatdate, parsedate, parsedate_tz +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional -from datetime import datetime, timedelta +if TYPE_CHECKING: + from .compat import HTTPResponse TIME_FMT = "%a, %d %b %Y %H:%M:%S GMT" -def expire_after(delta, date=None): +def expire_after(delta: timedelta, date: Optional[datetime] = None) -> datetime: date = date or datetime.utcnow() return date + delta -def datetime_to_header(dt): +def datetime_to_header(dt: datetime) -> str: return formatdate(calendar.timegm(dt.timetuple())) class BaseHeuristic(object): - - def warning(self, response): + def warning(self, response: "HTTPResponse") -> Optional[str]: """ Return a valid 1xx warning header value describing the cache adjustments. @@ -34,7 +35,7 @@ def warning(self, response): """ return '110 - "Response is Stale"' - def update_headers(self, response): + def update_headers(self, response: "HTTPResponse") -> Dict[str, str]: """Update the response headers with any new headers. NOTE: This SHOULD always include some Warning header to @@ -43,7 +44,7 @@ def update_headers(self, response): """ return {} - def apply(self, response): + def apply(self, response: "HTTPResponse") -> "HTTPResponse": updated_headers = self.update_headers(response) if updated_headers: @@ -61,7 +62,7 @@ class OneDayCache(BaseHeuristic): future. 
""" - def update_headers(self, response): + def update_headers(self, response: "HTTPResponse") -> Dict[str, str]: headers = {} if "expires" not in response.headers: @@ -77,14 +78,14 @@ class ExpiresAfter(BaseHeuristic): Cache **all** requests for a defined time period. """ - def __init__(self, **kw): + def __init__(self, **kw: Any) -> None: self.delta = timedelta(**kw) - def update_headers(self, response): + def update_headers(self, response: "HTTPResponse") -> Dict[str, str]: expires = expire_after(self.delta) return {"expires": datetime_to_header(expires), "cache-control": "public"} - def warning(self, response): + def warning(self, response: "HTTPResponse") -> Optional[str]: tmpl = "110 - Automatically cached for %s. Response might be stale" return tmpl % self.delta @@ -101,12 +102,23 @@ class LastModified(BaseHeuristic): http://lxr.mozilla.org/mozilla-release/source/netwerk/protocol/http/nsHttpResponseHead.cpp#397 Unlike mozilla we limit this to 24-hr. """ + cacheable_by_default_statuses = { - 200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501 + 200, + 203, + 204, + 206, + 300, + 301, + 404, + 405, + 410, + 414, + 501, } - def update_headers(self, resp): - headers = resp.headers + def update_headers(self, resp: "HTTPResponse") -> Dict[str, str]: + headers: Mapping[str, str] = resp.headers if "expires" in headers: return {} @@ -120,9 +132,11 @@ def update_headers(self, resp): if "date" not in headers or "last-modified" not in headers: return {} - date = calendar.timegm(parsedate_tz(headers["date"])) + time_tuple = parsedate_tz(headers["date"]) + assert time_tuple is not None + date = calendar.timegm(time_tuple[:6]) last_modified = parsedate(headers["last-modified"]) - if date is None or last_modified is None: + if last_modified is None: return {} now = time.time() @@ -135,5 +149,5 @@ def update_headers(self, resp): expires = date + freshness_lifetime return {"expires": time.strftime(TIME_FMT, time.gmtime(expires))} - def warning(self, resp): + def warning(self, resp: "HTTPResponse") -> Optional[str]: return None diff --git a/cachecontrol/py.typed b/cachecontrol/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/cachecontrol/serialize.py b/cachecontrol/serialize.py index 664d7c70..446d7d7e 100644 --- a/cachecontrol/serialize.py +++ b/cachecontrol/serialize.py @@ -6,18 +6,22 @@ import io import json import zlib +from typing import IO, TYPE_CHECKING, Any, Mapping, Optional import msgpack from requests.structures import CaseInsensitiveDict from .compat import HTTPResponse, pickle, text_type +if TYPE_CHECKING: + from requests import PreparedRequest, Request -def _b64_decode_bytes(b): + +def _b64_decode_bytes(b: str) -> bytes: return base64.b64decode(b.encode("ascii")) -def _b64_decode_str(s): +def _b64_decode_str(s: str) -> str: return _b64_decode_bytes(s).decode("utf8") @@ -25,8 +29,15 @@ def _b64_decode_str(s): class Serializer(object): - def dumps(self, request, response, body=None): - response_headers = CaseInsensitiveDict(response.headers) + def dumps( + self, + request: "PreparedRequest", + response: HTTPResponse, + body: Optional[bytes] = None, + ) -> bytes: + response_headers: CaseInsensitiveDict[str] = CaseInsensitiveDict( + response.headers + ) if body is None: # When a body isn't passed in, we'll read the response. 
@@ -36,40 +47,38 @@
             response._fp = io.BytesIO(body)
             response.length_remaining = len(body)

-        # NOTE: This is all a bit weird, but it's really important that on
-        #       Python 2.x these objects are unicode and not str, even when
-        #       they contain only ascii. The problem here is that msgpack
-        #       understands the difference between unicode and bytes and we
-        #       have it set to differentiate between them, however Python 2
-        #       doesn't know the difference. Forcing these to unicode will be
-        #       enough to have msgpack know the difference.
         data = {
-            u"response": {
-                u"body": body,  # Empty bytestring if body is stored separately
-                u"headers": dict(
+            "response": {
+                "body": body,  # Empty bytestring if body is stored separately
+                "headers": dict(
                     (text_type(k), text_type(v)) for k, v in response.headers.items()
                 ),
-                u"status": response.status,
-                u"version": response.version,
-                u"reason": text_type(response.reason),
-                u"decode_content": response.decode_content,
+                "status": response.status,
+                "version": response.version,
+                "reason": text_type(response.reason),
+                "decode_content": response.decode_content,
             }
         }

         # Construct our vary headers
-        data[u"vary"] = {}
-        if u"vary" in response_headers:
-            varied_headers = response_headers[u"vary"].split(",")
+        data["vary"] = {}
+        if "vary" in response_headers:
+            varied_headers = response_headers["vary"].split(",")
             for header in varied_headers:
                 header = text_type(header).strip()
                 header_value = request.headers.get(header, None)
                 if header_value is not None:
                     header_value = text_type(header_value)
-                data[u"vary"][header] = header_value
+                data["vary"][header] = header_value

         return b",".join([b"cc=4", msgpack.dumps(data, use_bin_type=True)])

-    def loads(self, request, data, body_file=None):
+    def loads(
+        self,
+        request: "PreparedRequest",
+        data: bytes,
+        body_file: Optional["IO[bytes]"] = None,
+    ) -> HTTPResponse:
         # Short circuit if we've been given an empty set of data
         if not data:
             return
@@ -88,18 +97,23 @@
             ver = b"cc=0"

         # Get the version number out of the cc=N
-        ver = ver.split(b"=", 1)[-1].decode("ascii")
+        verstr = ver.split(b"=", 1)[-1].decode("ascii")

         # Dispatch to the actual load method for the given version
         try:
-            return getattr(self, "_loads_v{}".format(ver))(request, data, body_file)
+            return getattr(self, "_loads_v{}".format(verstr))(request, data, body_file)
         except AttributeError:
             # This is a version we don't have a loads function for, so we'll
             # just treat it as a miss and return None
             return

-    def prepare_response(self, request, cached, body_file=None):
+    def prepare_response(
+        self,
+        request: "Request",
+        cached: Mapping[str, Any],
+        body_file: Optional["IO[bytes]"] = None,
+    ) -> Optional[HTTPResponse]:
         """Verify our vary headers match and construct a real urllib3
         HTTPResponse object.
         """
if "*" in cached.get("vary", {}): - return + return None # Ensure that the Vary headers for the cached response match our # request for header, value in cached.get("vary", {}).items(): if request.headers.get(header, None) != value: - return + return None body_raw = cached["response"].pop("body") - headers = CaseInsensitiveDict(data=cached["response"]["headers"]) + headers: CaseInsensitiveDict[str] = CaseInsensitiveDict( + data=cached["response"]["headers"] + ) if headers.get("transfer-encoding", "") == "chunked": headers.pop("transfer-encoding") cached["response"]["headers"] = headers try: + body: "IO[bytes]" if body_file is None: body = io.BytesIO(body_raw) else: @@ -143,26 +160,41 @@ def prepare_response(self, request, cached, body_file=None): return HTTPResponse(body=body, preload_content=False, **cached["response"]) - def _loads_v0(self, request, data, body_file=None): + def _loads_v0( + self, + request: "Request", + data: bytes, + body_file: Optional["IO[bytes]"] = None, + ) -> None: # The original legacy cache data. This doesn't contain enough # information to construct everything we need, so we'll treat this as # a miss. return - def _loads_v1(self, request, data, body_file=None): + def _loads_v1( + self, + request: "Request", + data: bytes, + body_file: Optional["IO[bytes]"] = None, + ) -> Optional[HTTPResponse]: try: cached = pickle.loads(data) except ValueError: - return + return None return self.prepare_response(request, cached, body_file) - def _loads_v2(self, request, data, body_file=None): + def _loads_v2( + self, + request: "Request", + data: bytes, + body_file: Optional["IO[bytes]"] = None, + ) -> Optional[HTTPResponse]: assert body_file is None try: cached = json.loads(zlib.decompress(data).decode("utf8")) except (ValueError, zlib.error): - return + return None # We need to decode the items that we've base64 encoded cached["response"]["body"] = _b64_decode_bytes(cached["response"]["body"]) @@ -178,16 +210,26 @@ def _loads_v2(self, request, data, body_file=None): return self.prepare_response(request, cached, body_file) - def _loads_v3(self, request, data, body_file): + def _loads_v3( + self, + request: "Request", + data: bytes, + body_file: Optional["IO[bytes]"] = None, + ) -> None: # Due to Python 2 encoding issues, it's impossible to know for sure # exactly how to load v3 entries, thus we'll treat these as a miss so # that they get rewritten out as v4 entries. 
diff --git a/cachecontrol/wrapper.py b/cachecontrol/wrapper.py
index b6ee7f20..a1064418 100644
--- a/cachecontrol/wrapper.py
+++ b/cachecontrol/wrapper.py
@@ -2,20 +2,30 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import TYPE_CHECKING, Collection, Optional, Type
+
 from .adapter import CacheControlAdapter
 from .cache import DictCache

+if TYPE_CHECKING:
+    import requests
+
+    from .cache import BaseCache
+    from .controller import CacheController
+    from .heuristics import BaseHeuristic
+    from .serialize import Serializer
+

 def CacheControl(
-    sess,
-    cache=None,
-    cache_etags=True,
-    serializer=None,
-    heuristic=None,
-    controller_class=None,
-    adapter_class=None,
-    cacheable_methods=None,
-):
+    sess: "requests.Session",
+    cache: Optional["BaseCache"] = None,
+    cache_etags: bool = True,
+    serializer: Optional["Serializer"] = None,
+    heuristic: Optional["BaseHeuristic"] = None,
+    controller_class: Optional[Type["CacheController"]] = None,
+    adapter_class: Optional[Type[CacheControlAdapter]] = None,
+    cacheable_methods: Optional[Collection[str]] = None,
+) -> "requests.Session":

     cache = DictCache() if cache is None else cache
     adapter_class = adapter_class or CacheControlAdapter
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 46d00b04..557342c2 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -16,3 +16,6 @@
 bumpversion
 twine
 black
 wheel
+mypy
+types-redis
+types-requests
diff --git a/setup.cfg b/setup.cfg
index 53862d7c..18359265 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,3 +11,14 @@
 norecursedirs = bin lib include build

 [bdist_wheel]
 universal = 1
+
+[mypy]
+show_error_codes = true
+strict = true
+enable_error_code = ignore-without-code,redundant-expr,truthy-bool
+
+[mypy-cachecontrol.compat]
+ignore_errors = true
+
+[mypy-msgpack]
+ignore_missing_imports = true
diff --git a/setup.py b/setup.py
index efda30e4..cf1f0ab8 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
     url="https://github.com/ionrock/cachecontrol",
     keywords="requests http caching web",
     packages=setuptools.find_packages(exclude=["tests", "tests.*"]),
-    package_data={"": ["LICENSE.txt"]},
+    package_data={"": ["LICENSE.txt"], "cachecontrol": ["py.typed"]},
     package_dir={"cachecontrol": "cachecontrol"},
     include_package_data=True,
     description="httplib2 caching for requests",
diff --git a/tox.ini b/tox.ini
index d3768132..47fb6e0d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -20,3 +20,11 @@
 deps = pytest
        redis
        filelock
 commands = py.test {posargs:tests/}
+
+[testenv:mypy]
+deps =
+    {[testenv]deps}
+    mypy
+    types-redis
+    types-requests
+commands = mypy {posargs:cachecontrol}
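Finally, because the `py.typed` marker now ships in the wheel, downstream projects get these annotations from mypy directly. A hypothetical client module that should pass `mypy --strict` against this build:

    import requests

    from cachecontrol import CacheControl
    from cachecontrol.caches.file_cache import FileCache

    def fetch(url: str) -> bytes:
        # CacheControl() is annotated to return requests.Session, so resp
        # is a requests.Response and resp.content is bytes.
        sess = CacheControl(requests.Session(), cache=FileCache(".web_cache"))
        resp = sess.get(url)
        return resp.content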