diff --git a/pyproject.toml b/pyproject.toml index 4527577c..01b995c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ exclude_lines = [ [tool.poetry] name = "cryptojwt" -version = "1.4.1" +version = "1.5.0" description = "Python implementation of JWT, JWE, JWS and JWK" authors = ["Roland Hedberg "] license = "Apache-2.0" diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index c01211c9..34b44205 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -6,6 +6,8 @@ import time from datetime import datetime from functools import cmp_to_key +from typing import List +from typing import Optional import requests @@ -24,7 +26,6 @@ from .jwk.jwk import dump_jwk from .jwk.jwk import import_jwk from .jwk.rsa import RSAKey -from .jwk.rsa import import_private_rsa_key_from_file from .jwk.rsa import new_rsa_key from .utils import as_unicode @@ -152,6 +153,26 @@ def ec_init(spec): class KeyBundle: """The Key Bundle""" + params = { + "cache_time": 0, + "etag": "", + "fileformat": "jwks", + "httpc_params": {}, + "ignore_errors_period": 0, + "ignore_errors_until": None, + "ignore_invalid_keys": True, + "imp_jwks": None, + "keytype": "RSA", + "keyusage": None, + "last_local": None, + "last_remote": None, + "last_updated": 0, + "local": False, + "remote": False, + "source": None, + "time_out": 0, + } + def __init__( self, keys=None, @@ -189,22 +210,22 @@ def __init__( """ self._keys = [] - self.remote = False - self.local = False self.cache_time = cache_time - self.ignore_errors_period = ignore_errors_period - self.ignore_errors_until = None # UNIX timestamp of last error - self.time_out = 0 self.etag = "" - self.source = None self.fileformat = fileformat.lower() + self.ignore_errors_period = ignore_errors_period + self.ignore_errors_until = None # UNIX timestamp of last error + self.ignore_invalid_keys = ignore_invalid_keys + self.imp_jwks = None self.keytype = keytype self.keyusage = keyusage - self.imp_jwks = None - self.last_updated = 0 - self.last_remote = None # HTTP Date of last remote update self.last_local = None # UNIX timestamp of last local update - self.ignore_invalid_keys = ignore_invalid_keys + self.last_remote = None # HTTP Date of last remote update + self.last_updated = 0 + self.local = False + self.remote = False + self.source = None + self.time_out = 0 if httpc: self.httpc = httpc @@ -490,6 +511,7 @@ def update(self): # reread everything self._keys = [] + updated = None try: if self.local: @@ -751,48 +773,68 @@ def difference(self, bundle): return [k for k in self._keys if k not in bundle] - def dump(self): - _keys = [] - for _k in self._keys: - _ser = _k.to_dict() - if _k.inactive_since: - _ser["inactive_since"] = _k.inactive_since - _keys.append(_ser) - - res = { - "keys": _keys, - "fileformat": self.fileformat, - "last_updated": self.last_updated, - "last_remote": self.last_remote, - "last_local": self.last_local, - "httpc_params": self.httpc_params, - "remote": self.remote, - "local": self.local, - "imp_jwks": self.imp_jwks, - "time_out": self.time_out, - "cache_time": self.cache_time, - } + def dump(self, exclude_attributes: Optional[List[str]] = None): + if exclude_attributes is None: + exclude_attributes = [] - if self.source: - res["source"] = self.source + res = {} + + if "keys" not in exclude_attributes: + _keys = [] + for _k in self._keys: + _ser = _k.to_dict() + if _k.inactive_since: + _ser["inactive_since"] = _k.inactive_since + _keys.append(_ser) + res["keys"] = _keys + + for attr, default in self.params.items(): + if attr in exclude_attributes: + continue + val = getattr(self, attr) + res[attr] = val return res def load(self, spec): + """ + Sets attributes according to a specification. + Does not overwrite an existing attributes value with a default value. + + :param spec: Dictionary with attributes and value to populate the instance with + :return: The instance itself + """ _keys = spec.get("keys", []) if _keys: self.do_keys(_keys) - self.source = spec.get("source", None) - self.fileformat = spec.get("fileformat", "jwks") - self.last_updated = spec.get("last_updated", 0) - self.last_remote = spec.get("last_remote", None) - self.last_local = spec.get("last_local", None) - self.remote = spec.get("remote", False) - self.local = spec.get("local", False) - self.imp_jwks = spec.get("imp_jwks", None) - self.time_out = spec.get("time_out", 0) - self.cache_time = spec.get("cache_time", 0) - self.httpc_params = spec.get("httpc_params", {}) + + for attr, default in self.params.items(): + val = spec.get(attr) + if val: + setattr(self, attr, val) + + return self + + def flush(self): + self._keys = [] + self.cache_time = (300,) + self.etag = "" + self.fileformat = "jwks" + # self.httpc=None, + self.httpc_params = (None,) + self.ignore_errors_period = 0 + self.ignore_errors_until = None + self.ignore_invalid_keys = True + self.imp_jwks = None + self.keytype = ("RSA",) + self.keyusage = (None,) + self.last_local = None # UNIX timestamp of last local update + self.last_remote = None # HTTP Date of last remote update + self.last_updated = 0 + self.local = False + self.remote = False + self.source = None + self.time_out = 0 return self @@ -1246,3 +1288,19 @@ def init_key(filename, type, kid="", **kwargs): _new_key = key_gen(type, kid=kid, **kwargs) dump_jwk(filename, _new_key) return _new_key + + +def key_by_alg(alg: str): + if alg.startswith("RS"): + return key_gen("RSA", alg="RS256") + elif alg.startswith("ES"): + if alg == "ES256": + return key_gen("EC", crv="P-256") + elif alg == "ES384": + return key_gen("EC", crv="P-384") + elif alg == "ES512": + return key_gen("EC", crv="P-521") + elif alg.startswith("HS"): + return key_gen("sym") + + raise ValueError("Don't know who to create a key to use with '{}'".format(alg)) diff --git a/src/cryptojwt/key_issuer.py b/src/cryptojwt/key_issuer.py index 0e02e1ac..42a262ee 100755 --- a/src/cryptojwt/key_issuer.py +++ b/src/cryptojwt/key_issuer.py @@ -1,6 +1,8 @@ import json import logging import os +from typing import List +from typing import Optional from requests import request @@ -15,13 +17,21 @@ __author__ = "Roland Hedberg" - logger = logging.getLogger(__name__) class KeyIssuer(object): """ A key issuer instance contains a number of KeyBundles. """ + params = { + "ca_certs": None, + "httpc_params": None, + "keybundle_cls": KeyBundle, + "name": "", + "remove_after": 3600, + "spec2key": None, + } + def __init__( self, ca_certs=None, @@ -45,14 +55,13 @@ def __init__( self._bundles = [] - self.keybundle_cls = keybundle_cls - self.name = name - - self.spec2key = {} self.ca_certs = ca_certs - self.remove_after = remove_after self.httpc = httpc or request self.httpc_params = httpc_params or {} + self.keybundle_cls = keybundle_cls + self.name = name + self.remove_after = remove_after + self.spec2key = {} def __repr__(self) -> str: return ''.format(self.name, self.key_summary()) @@ -350,43 +359,57 @@ def __len__(self): nr += len(kb) return nr - def dump(self, exclude=None): + def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict: """ Returns the content as a dictionary. + :param exclude_attributes: List of attribute names for objects that should be ignored. :return: A dictionary """ - _bundles = [] - for kb in self._bundles: - _bundles.append(kb.dump()) - - info = { - "name": self.name, - "bundles": _bundles, - "keybundle_cls": qualified_name(self.keybundle_cls), - "spec2key": self.spec2key, - "ca_certs": self.ca_certs, - "remove_after": self.remove_after, - "httpc_params": self.httpc_params, - } + if exclude_attributes is None: + exclude_attributes = [] + + info = {} + for attr, default in self.params.items(): + if attr in exclude_attributes: + continue + val = getattr(self, attr) + if attr == "keybundle_cls": + val = qualified_name(val) + info[attr] = val + + if "bundles" not in exclude_attributes: + _bundles = [] + for kb in self._bundles: + _bundles.append(kb.dump(exclude_attributes=exclude_attributes)) + info["bundles"] = _bundles + return info def load(self, info): """ - :param items: A list with the information + :param items: A dictionary with the information to load :return: """ - self.name = info["name"] - self.keybundle_cls = importer(info["keybundle_cls"]) - self.spec2key = info["spec2key"] - self.ca_certs = info["ca_certs"] - self.remove_after = info["remove_after"] - self.httpc_params = info["httpc_params"] + for attr, default in self.params.items(): + val = info.get(attr) + if val: + if attr == "keybundle_cls": + val = importer(val) + setattr(self, attr, val) + self._bundles = [KeyBundle().load(val) for val in info["bundles"]] return self + def flush(self): + for attr, default in self.params.items(): + setattr(self, attr, default) + + self._bundles = [] + return self + def update(self): for kb in self._bundles: kb.update() diff --git a/src/cryptojwt/key_jar.py b/src/cryptojwt/key_jar.py index 4b58bfe7..c6fee93b 100755 --- a/src/cryptojwt/key_jar.py +++ b/src/cryptojwt/key_jar.py @@ -32,7 +32,6 @@ def __init__( remove_after=3600, httpc=None, httpc_params=None, - storage=None, ): """ KeyJar init function @@ -43,15 +42,9 @@ def __init__( :param remove_after: How long keys marked as inactive will remain in the key Jar. :param httpc: A HTTP client to use. Default is Requests request. :param httpc_params: HTTP request parameters - :param storage: An instance that can store information. It basically look like dictionary. :return: Keyjar instance """ - - if storage is None: - self._issuers = {} - else: - self._issuers = storage - + self._issuers = {} self.spec2key = {} self.ca_certs = ca_certs self.keybundle_cls = keybundle_cls @@ -617,8 +610,6 @@ def copy(self): """ Make deep copy of the content of this key jar. - Note that if this key jar uses an external storage module the copy will not. - :return: A :py:class:`oidcmsg.key_jar.KeyJar` instance """ @@ -635,44 +626,100 @@ def copy(self): def __len__(self): return len(self._issuers) - def dump(self, exclude=None): + def _dump_issuers( + self, + exclude_issuers: Optional[List[str]] = None, + exclude_attributes: Optional[List[str]] = None, + ): + _issuers = {} + for _id, _issuer in self._issuers.items(): + if exclude_issuers and _issuer.name in exclude_issuers: + continue + _issuers[_id] = _issuer.dump(exclude_attributes=exclude_attributes) + return _issuers + + def dump( + self, + exclude_issuers: Optional[List[str]] = None, + exclude_attributes: Optional[List[str]] = None, + ) -> dict: """ Returns the key jar content as dictionary + :param exclude_issuers: A list of issuers you don't want included. + :param exclude_attributes: list of attribute names that should be ignored when dumping. + :type exclude_attributes: list :return: A dictionary """ info = { - "spec2key": self.spec2key, "ca_certs": self.ca_certs, + "httpc_params": self.httpc_params, "keybundle_cls": qualified_name(self.keybundle_cls), "remove_after": self.remove_after, - "httpc_params": self.httpc_params, + "spec2key": self.spec2key, } - _issuers = {} - for _id, _issuer in self._issuers.items(): - if exclude and _issuer.name in exclude: - continue - _issuers[_id] = _issuer.dump() - info["issuers"] = _issuers + if exclude_attributes: + for attr in exclude_attributes: + try: + del info[attr] + except KeyError: + pass + + if exclude_attributes is None: + info["issuers"] = self._dump_issuers( + exclude_issuers=exclude_issuers, exclude_attributes=exclude_attributes + ) + elif "issuers" not in exclude_attributes: + info["issuers"] = self._dump_issuers( + exclude_issuers=exclude_issuers, exclude_attributes=exclude_attributes + ) return info + def dumps(self, exclude_issuers: Optional[List[str]] = None): + """ + Returns a JSON representation of the key jar + + :param exclude_issuers: Exclude these issuers + :return: A string + """ + _dict = self.dump(exclude_issuers=exclude_issuers) + return json.dumps(_dict) + def load(self, info): """ :param info: A dictionary with the information :return: """ - self.spec2key = info["spec2key"] - self.ca_certs = info["ca_certs"] - self.keybundle_cls = importer(info["keybundle_cls"]) - self.remove_after = info["remove_after"] - self.httpc_params = info["httpc_params"] + self.ca_certs = info.get("ca_certs", None) + self.httpc_params = info.get("httpc_params", None) + self.keybundle_cls = importer(info.get("keybundle_cls", KeyBundle)) + self.remove_after = info.get("remove_after", 3600) + self.spec2key = info.get("spec2key", {}) + + _issuers = info.get("issuers", None) + if _issuers is None: + self._issuers = {} + else: + for _issuer_id, _issuer_desc in _issuers.items(): + self._issuers[_issuer_id] = KeyIssuer().load(_issuer_desc) + return self + + def loads(self, string): + return self.load(json.loads(string)) + + def flush(self): + self.ca_certs = None + self.httpc_params = None + self._issuers = {} + self.keybundle_cls = KeyBundle + self.remove_after = 3600 + self.spec2key = {} + # self.httpc=None, - for _issuer_id, _issuer_desc in info["issuers"].items(): - self._issuers[_issuer_id] = KeyIssuer().load(_issuer_desc) return self @deprecated_alias(issuer="issuer_id", owner="issuer_id") @@ -705,7 +752,7 @@ def rotate_keys(self, key_conf, kid_template="", issuer_id=""): # ============================================================================= -def build_keyjar(key_conf, kid_template="", keyjar=None, issuer_id="", storage=None): +def build_keyjar(key_conf, kid_template="", keyjar=None, issuer_id=""): """ Builds a :py:class:`oidcmsg.key_jar.KeyJar` instance or adds keys to an existing KeyJar based on a key specification. @@ -744,7 +791,6 @@ def build_keyjar(key_conf, kid_template="", keyjar=None, issuer_id="", storage=N kid_template is given then the built-in function add_kid() will be used. :param keyjar: If an KeyJar instance the new keys are added to this key jar. :param issuer_id: The default owner of the keys in the key jar. - :param storage: A Storage instance. :return: A KeyJar instance """ @@ -753,7 +799,7 @@ def build_keyjar(key_conf, kid_template="", keyjar=None, issuer_id="", storage=N return None if keyjar is None: - keyjar = KeyJar(storage=storage) + keyjar = KeyJar() keyjar[issuer_id] = _issuer @@ -767,7 +813,6 @@ def init_key_jar( key_defs="", issuer_id="", read_only=True, - storage=None, ): """ A number of cases here: @@ -805,7 +850,6 @@ def init_key_jar( :param key_defs: A definition of what keys should be created if they are not already available :param issuer_id: The owner of the keys :param read_only: This function should not attempt to write anything to a file system. - :param storage: A Storage instance. :return: An instantiated :py:class;`oidcmsg.key_jar.KeyJar` instance """ @@ -819,7 +863,7 @@ def init_key_jar( if _issuer is None: raise ValueError("Could not find any keys") - keyjar = KeyJar(storage=storage) + keyjar = KeyJar() keyjar[issuer_id] = _issuer return keyjar diff --git a/src/cryptojwt/serialize/item.py b/src/cryptojwt/serialize/item.py index a087c8dc..206eb7de 100644 --- a/src/cryptojwt/serialize/item.py +++ b/src/cryptojwt/serialize/item.py @@ -7,7 +7,7 @@ class KeyIssuer: @staticmethod def serialize(item: key_issuer.KeyIssuer) -> str: """ Convert from KeyIssuer to JSON """ - return json.dumps(item.dump()) + return json.dumps(item.dump(exclude_attributes=["keybundle_cls"])) def deserialize(self, spec: str) -> key_issuer.KeyIssuer: """ Convert from JSON to KeyIssuer """ diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index bfd4f958..d32af9a9 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -953,15 +953,22 @@ def test_export_inactive(): res = kb.dump() assert set(res.keys()) == { "cache_time", + "etag", "fileformat", "httpc_params", + "ignore_errors_until", + "ignore_errors_period", + "ignore_invalid_keys", "imp_jwks", "keys", + "keytype", + "keyusage", "last_updated", "last_remote", "last_local", "remote", "local", + "source", "time_out", } @@ -1079,3 +1086,31 @@ def test_ignore_invalid_keys(): with pytest.raises(UnknownKeyType): KeyBundle(keys={"keys": [rsa_key_dict]}, ignore_invalid_keys=False) + + +def test_exclude_attributes(): + source = "https://example.com/keys.json" + # Mock response + with responses.RequestsMock() as rsps: + rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) + httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds + kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) + kb.do_remote() + + exp = kb.dump(exclude_attributes=["cache_time", "ignore_invalid_keys"]) + kb2 = KeyBundle(cache_time=600, ignore_invalid_keys=False).load(exp) + assert kb2.cache_time == 600 + assert kb2.ignore_invalid_keys is False + + +def test_remote_dump_json(): + source = "https://example.com/keys.json" + # Mock response + with responses.RequestsMock() as rsps: + rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) + httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds + kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) + kb.do_remote() + + exp = kb.dump() + assert json.dumps(exp) diff --git a/tests/test_04_key_issuer.py b/tests/test_04_key_issuer.py index fedad978..7a4ca372 100755 --- a/tests/test_04_key_issuer.py +++ b/tests/test_04_key_issuer.py @@ -1,3 +1,4 @@ +import json import os import shutil import time @@ -743,6 +744,14 @@ def test_dump(): assert nkj.get("sig", "rsa", kid="MnC_VZcATfM5pOYiJHMba9goEKY") +def test_dump_json(): + issuer = KeyIssuer() + issuer.add_kb(KeyBundle(JWK2["keys"])) + + res = issuer.dump() + assert json.dumps(res) + + def test_contains(): issuer = KeyIssuer() issuer.add_kb(KeyBundle(JWK1["keys"])) diff --git a/tests/test_04_key_jar.py b/tests/test_04_key_jar.py index 53cb5ef5..1a472b6a 100755 --- a/tests/test_04_key_jar.py +++ b/tests/test_04_key_jar.py @@ -470,7 +470,7 @@ def test_provider(self): _msg = "{} is not available at this moment!".format(_url) warnings.warn(_msg) else: - assert iss_kes[0].keys() + assert iss_keys[0].keys() def test_import_jwks(): @@ -1033,6 +1033,16 @@ def test_dump(): assert nkj.get_signing_key("rsa", "C", kid="MnC_VZcATfM5pOYiJHMba9goEKY") +def test_dump_json(): + kj = KeyJar() + kj.add_kb("Alice", KeyBundle(JWK0["keys"])) + kj.add_kb("Bob", KeyBundle(JWK1["keys"])) + kj.add_kb("C", KeyBundle(JWK2["keys"])) + + res = kj.dump() + assert json.dumps(res) + + def test_contains(): kj = KeyJar() kj.add_kb("Alice", KeyBundle(JWK0["keys"]))