diff --git a/tuf/api/metadata.py b/tuf/api/metadata.py index 9902051985..80b27eebfe 100644 --- a/tuf/api/metadata.py +++ b/tuf/api/metadata.py @@ -324,19 +324,23 @@ class Signed: """ + # Signed implementations are expected to override this + _signed_type = None + + @property + def _type(self): + return self._signed_type + # NOTE: Signed is a stupid name, because this might not be signed yet, but # we keep it to match spec terminology (I often refer to this as "payload", # or "inner metadata") def __init__( self, - _type: str, version: int, spec_version: str, expires: datetime, unrecognized_fields: Optional[Mapping[str, Any]] = None, ) -> None: - - self._type = _type self.spec_version = spec_version self.expires = expires @@ -346,8 +350,8 @@ def __init__( self.version = version self.unrecognized_fields = unrecognized_fields or {} - @staticmethod - def _common_fields_from_dict(signed_dict: Mapping[str, Any]) -> list: + @classmethod + def _common_fields_from_dict(cls, signed_dict: Mapping[str, Any]) -> list: """Returns common fields of 'Signed' instances from the passed dict representation, and returns an ordered list to be passed as leading positional arguments to a subclass constructor. @@ -356,6 +360,9 @@ def _common_fields_from_dict(signed_dict: Mapping[str, Any]) -> list: """ _type = signed_dict.pop("_type") + if _type != cls._signed_type: + raise ValueError(f"Expected type {cls._signed_type}, got {_type}") + version = signed_dict.pop("version") spec_version = signed_dict.pop("spec_version") expires_str = signed_dict.pop("expires") @@ -363,7 +370,7 @@ def _common_fields_from_dict(signed_dict: Mapping[str, Any]) -> list: # what the constructor expects and what we store. The inverse operation # is implemented in '_common_fields_to_dict'. expires = formats.expiry_string_to_datetime(expires_str) - return [_type, version, spec_version, expires] + return [version, spec_version, expires] def _common_fields_to_dict(self) -> Dict[str, Any]: """Returns dict representation of common fields of 'Signed' instances. @@ -517,13 +524,14 @@ class Root(Signed): """ + _signed_type = "root" + # TODO: determine an appropriate value for max-args and fix places where # we violate that. This __init__ function takes 7 arguments, whereas the # default max-args value for pylint is 5 # pylint: disable=too-many-arguments def __init__( self, - _type: str, version: int, spec_version: str, expires: datetime, @@ -532,9 +540,7 @@ def __init__( roles: Mapping[str, Role], unrecognized_fields: Optional[Mapping[str, Any]] = None, ) -> None: - super().__init__( - _type, version, spec_version, expires, unrecognized_fields - ) + super().__init__(version, spec_version, expires, unrecognized_fields) self.consistent_snapshot = consistent_snapshot self.keys = keys self.roles = roles @@ -612,18 +618,17 @@ class Timestamp(Signed): """ + _signed_type = "timestamp" + def __init__( self, - _type: str, version: int, spec_version: str, expires: datetime, meta: Mapping[str, Any], unrecognized_fields: Optional[Mapping[str, Any]] = None, ) -> None: - super().__init__( - _type, version, spec_version, expires, unrecognized_fields - ) + super().__init__(version, spec_version, expires, unrecognized_fields) # TODO: Add class for meta self.meta = meta @@ -680,18 +685,17 @@ class Snapshot(Signed): """ + _signed_type = "snapshot" + def __init__( self, - _type: str, version: int, spec_version: str, expires: datetime, meta: Mapping[str, Any], unrecognized_fields: Optional[Mapping[str, Any]] = None, ) -> None: - super().__init__( - _type, version, spec_version, expires, unrecognized_fields - ) + super().__init__(version, spec_version, expires, unrecognized_fields) # TODO: Add class for meta self.meta = meta @@ -782,13 +786,14 @@ class Targets(Signed): """ + _signed_type = "targets" + # TODO: determine an appropriate value for max-args and fix places where # we violate that. This __init__ function takes 7 arguments, whereas the # default max-args value for pylint is 5 # pylint: disable=too-many-arguments def __init__( self, - _type: str, version: int, spec_version: str, expires: datetime, @@ -796,9 +801,7 @@ def __init__( delegations: Mapping[str, Any], unrecognized_fields: Optional[Mapping[str, Any]] = None, ) -> None: - super().__init__( - _type, version, spec_version, expires, unrecognized_fields - ) + super().__init__(version, spec_version, expires, unrecognized_fields) # TODO: Add class for meta self.targets = targets self.delegations = delegations