diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index f4d7b0dd..162d38c1 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -180,6 +180,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, self._keys = [] self.remote = False + self.local = False self.cache_time = cache_time self.time_out = 0 self.etag = "" @@ -189,6 +190,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, self.keyusage = keyusage self.imp_jwks = None self.last_updated = 0 + self.last_remote = None # HTTP Date of last remote update + self.last_local = None # UNIX timestamp of last local update if httpc: self.httpc = httpc @@ -208,13 +211,13 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, self.do_keys(keys) else: self._set_source(source, fileformat) - - if not self.remote and self.source: # local file + if self.local: self._do_local(kid) def _set_source(self, source, fileformat): if source.startswith("file://"): self.source = source[7:] + self.local = True elif source.startswith("http://") or source.startswith("https://"): self.source = source self.remote = True @@ -224,6 +227,7 @@ def _set_source(self, source, fileformat): if fileformat.lower() in ['rsa', 'der', 'jwks']: if os.path.isfile(source): self.source = source + self.local = True else: raise ImportError('No such file') else: @@ -235,6 +239,16 @@ def _do_local(self, kid): elif self.fileformat == "der": self.do_local_der(self.source, self.keytype, self.keyusage, kid) + def _local_update_required(self) -> bool: + stat = os.stat(self.source) + if self.last_local and stat.st_mtime < self.last_local: + LOGGER.debug("%s not modfied", self.source) + return False + else: + LOGGER.debug("%s modfied", self.source) + self.last_local = stat.st_mtime + return True + def do_keys(self, keys): """ Go from JWK description to binary keys @@ -290,12 +304,15 @@ def do_local_jwk(self, filename): :param filename: Name of the file from which the JWKS should be loaded """ + LOGGER.debug("Reading JWKS from %s", filename) with open(filename) as input_file: _info = json.load(input_file) if 'keys' in _info: self.do_keys(_info["keys"]) else: self.do_keys([_info]) + self.last_local = time.time() + self.time_out = self.last_local + self.cache_time def do_local_der(self, filename, keytype, keyusage=None, kid=''): """ @@ -305,6 +322,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''): :param keytype: Presently 'rsa' and 'ec' supported :param keyusage: encryption ('enc') or signing ('sig') or both """ + LOGGER.debug("Reading DER from %s", filename) key_args = {} _kty = keytype.lower() if _kty in ['rsa', 'ec']: @@ -324,6 +342,8 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''): key_args['kid'] = kid self.do_keys([key_args]) + self.last_local = time.time() + self.time_out = self.last_local + self.cache_time def do_remote(self): """ @@ -336,6 +356,10 @@ def do_remote(self): try: LOGGER.debug('KeyBundle fetch keys from: %s', self.source) + if self.last_remote is not None: + if "headers" not in self.httpc_params: + self.httpc_params["headers"] = {} + self.httpc_params["headers"]["If-Modified-Since"] = self.last_remote _http_resp = self.httpc('GET', self.source, **self.httpc_params) except Exception as err: LOGGER.error(err) @@ -357,6 +381,14 @@ def do_remote(self): LOGGER.error("No 'keys' keyword in JWKS") raise UpdateFailed(MALFORMED.format(self.source)) + if hasattr(_http_resp, "headers"): + headers = getattr(_http_resp, "headers") + self.last_remote = headers.get("last-modified") or headers.get("date") + + elif _http_resp.status_code == 304: # Not modified + LOGGER.debug("%s not modified since %s", self.source, self.last_remote) + pass + else: raise UpdateFailed( REMOTE_FAILED.format(self.source, _http_resp.status_code)) @@ -387,14 +419,12 @@ def _parse_remote_response(self, response): def _uptodate(self): res = False - if not self._keys: - if self.remote: # verify that it's not to old - if time.time() > self.time_out: - if self.update(): - res = True - elif self.remote: - if self.update(): - res = True + if self.remote or self.local: + if time.time() > self.time_out: + if self.local and not self._local_update_required(): + res = True + elif self.update(): + res = True return res def update(self): @@ -412,13 +442,13 @@ def update(self): self._keys = [] try: - if self.remote is False: + if self.local: if self.fileformat in ["jwks", "jwk"]: self.do_local_jwk(self.source) elif self.fileformat == "der": self.do_local_der(self.source, self.keytype, self.keyusage) - else: + elif self.remote: res = self.do_remote() except Exception as err: LOGGER.error('Key bundle update failed: %s', err) @@ -661,8 +691,11 @@ def dump(self): "keys": _keys, "fileformat": self.fileformat, "last_updated": self.last_updated, + "last_remote": self.last_remote, + "last_local": self.last_local, "httpc_params": self.httpc_params, "remote": self.remote, + "local": self.local, "imp_jwks": self.imp_jwks, "time_out": self.time_out, "cache_time": self.cache_time @@ -680,7 +713,10 @@ def load(self, spec): self.source = spec.get("source", None) self.fileformat = spec.get("fileformat", "jwks") self.last_updated = spec.get("last_updated", 0) + self.last_remote = spec.get("last_remote", None) + self.last_local = spec.get("last_local", None) self.remote = spec.get("remote", False) + self.local = spec.get("local", False) self.imp_jwks = spec.get('imp_jwks', None) self.time_out = spec.get('time_out', 0) self.cache_time = spec.get('cache_time', 0) diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index b10e70bf..8a6733d4 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -938,7 +938,10 @@ def test_export_inactive(): 'imp_jwks', 'keys', 'last_updated', + 'last_remote', + 'last_local', 'remote', + 'local', 'time_out'} kb2 = KeyBundle().load(res)