diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 9b2f200a..3f8205d7 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -4,6 +4,7 @@ import logging import os import time +from datetime import datetime from functools import cmp_to_key import requests @@ -156,6 +157,7 @@ def __init__( keys=None, source="", cache_time=300, + ignore_errors_period=0, fileformat="jwks", keytype="RSA", keyusage=None, @@ -188,6 +190,8 @@ def __init__( self.remote = False self.local = False self.cache_time = cache_time + self.ignore_errors_period = ignore_errors_period + self.ignore_errors_until = None # UNIX timestamp of last error self.time_out = 0 self.etag = "" self.source = None @@ -365,6 +369,14 @@ def do_remote(self): # if self.verify_ssl is not None: # self.httpc_params["verify"] = self.verify_ssl + if self.ignore_errors_until and time.time() < self.ignore_errors_until: + LOGGER.warning( + "Not reading remote JWKS from %s (in error holddown until %s)", + self.source, + datetime.fromtimestamp(self.ignore_errors_until), + ) + return False + LOGGER.info("Reading remote JWKS from %s", self.source) try: LOGGER.debug("KeyBundle fetch keys from: %s", self.source) @@ -390,6 +402,7 @@ def do_remote(self): self.do_keys(self.imp_jwks["keys"]) except KeyError: LOGGER.error("No 'keys' keyword in JWKS") + self.ignore_errors_until = time.time() + self.ignore_errors_period raise UpdateFailed(MALFORMED.format(self.source)) if hasattr(_http_resp, "headers"): @@ -406,8 +419,11 @@ def do_remote(self): _http_resp.status_code, self.source, ) + self.ignore_errors_until = time.time() + self.ignore_errors_period raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code)) + self.last_updated = time.time() + self.ignore_errors_until = None return True def _parse_remote_response(self, response): diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index 7d25b392..35fbdff2 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -17,6 +17,7 @@ from cryptojwt.jwk.rsa import import_rsa_key_from_cert_file from cryptojwt.jwk.rsa import new_rsa_key from cryptojwt.key_bundle import KeyBundle +from cryptojwt.key_bundle import UpdateFailed from cryptojwt.key_bundle import build_key_bundle from cryptojwt.key_bundle import dump_jwks from cryptojwt.key_bundle import init_key @@ -1024,3 +1025,43 @@ def test_remote_not_modified(): assert kb2.httpc_params == {"timeout": (2, 2)} assert kb2.imp_jwks assert kb2.last_updated + + +def test_ignore_errors_period(): + source_good = "https://example.com/keys.json" + source_bad = "https://example.com/keys-bad.json" + ignore_errors_period = 1 + # Mock response + with responses.RequestsMock() as rsps: + rsps.add(method="GET", url=source_good, json=JWKS_DICT, status=200) + rsps.add(method="GET", url=source_bad, json=JWKS_DICT, status=500) + httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds + kb = KeyBundle( + source=source_good, + httpc=requests.request, + httpc_params=httpc_params, + ignore_errors_period=ignore_errors_period, + ) + res = kb.do_remote() + assert res == True + assert kb.ignore_errors_until is None + + # refetch, but fail by using a bad source + kb.source = source_bad + try: + res = kb.do_remote() + except UpdateFailed: + pass + + # retry should fail silently as we're in holddown + res = kb.do_remote() + assert kb.ignore_errors_until is not None + assert res == False + + # wait until holddown + time.sleep(ignore_errors_period + 1) + + # try again + kb.source = source_good + res = kb.do_remote() + assert res == True