Source code for pyeudiw.jwt.jws_helper

import binascii
import logging
import os
from copy import deepcopy
from typing import Any, Literal, Union

from cryptojwt import JWS
from cryptojwt.jwk.jwk import key_from_jwk_dict

from pyeudiw.jwk import JWK
from pyeudiw.jwk.exceptions import KidError
from pyeudiw.jwk.jwks import find_jwk_by_kid, find_jwk_by_thumbprint
from pyeudiw.jwk.parse import parse_b64der
from pyeudiw.jwt.exceptions import (
    JWSSigningError,
    JWSVerificationError,
    LifetimeException,
)
from pyeudiw.jwt.helper import (
    JWHelperInterface,
    find_self_contained_key,
    serialize_payload,
    validate_jwt_timestamps_claims,
)
from pyeudiw.jwt.utils import decode_jwt_header

# Values accepted by JWSHelper.sign(serialization_format=...).
SerializationFormat = Literal["compact", "json"]

logger = logging.getLogger(__name__)

# Default hash function name (not referenced in the visible part of this module).
DEFAULT_HASH_FUNC = "SHA-256"

# Key type ("kty") -> JWS signature algorithm used when signing with such a key.
DEFAULT_SIG_KTY_MAP = dict(RSA="RS256", EC="ES256")

# Key type ("kty") -> default JWS signature algorithm.
DEFAULT_SIG_ALG_MAP = dict(RSA="RS256", EC="ES256")

# Key type ("kty") -> default JWE key-management algorithm ("alg").
DEFAULT_ENC_ALG_MAP = dict(RSA="RSA-OAEP", EC="ECDH-ES+A256KW")

# Key type ("kty") -> default JWE content-encryption algorithm ("enc").
DEFAULT_ENC_ENC_MAP = dict(RSA="A256CBC-HS512", EC="A256GCM")

# Clock-skew tolerance (seconds) used when validating token lifetime claims;
# overridable through the PYEUDIW_TOKEN_TIME_TOLERANCE environment variable.
DEFAULT_TOKEN_TIME_TOLERANCE = int(os.environ.get("PYEUDIW_TOKEN_TIME_TOLERANCE", "60"), 10)


class JWSHelper(JWHelperInterface):
    """
    Helper class for working with JWS, extended to support SD-JWT.
    """

    def sign(
        self,
        plain_dict: Union[dict, str, int, None],
        protected: dict | None = None,
        unprotected: dict | None = None,
        serialization_format: SerializationFormat = "compact",
        signing_kid: str = "",
        kid_in_header: bool = True,
        signing_algs: list[str] | None = None,
        **kwargs,
    ) -> str:
        """Generate a signed JWS with the given payload and header.

        This method provides no guarantee that the input header is fully
        preserved, nor does it guarantee that some optional but usually found
        headers such as 'typ' and 'kid' are present.
        If the signing jwk has a kid claim, and the JWS header does not have a
        kid claim, a kid matching the signing key 'kid' can be injected in the
        protected header by setting kid_in_header=True.
        Header claim 'alg' is always added as it is mandated by RFC7515 and, if
        present, will be overridden with the actual 'alg' used for signing.
        This is done to make sure that untrusted alg values, such as none,
        cannot be used.
        The caller's ``protected`` and ``unprotected`` dictionaries are not
        modified: the method works on shallow copies of them.

        The signing key is selected among the constructor jwks based on
        internal heuristics. The user can force which key to use by setting
        signing_kid (or restrict the candidates with signing_algs); the key is
        then looked up in the internal set of available keys.
        If the header already contains an indication of a key, such as 'kid'
        or 'x5c', the method will attempt to match the signing key among the
        available keys based on such claims, but there is no guarantee that
        the correct key will be selected. We assume that it is the
        responsibility of the class initiator to make those checks.
        To avoid any possible ambiguity, it is suggested to initialize the
        class with one (signing) key only.

        :param plain_dict: The payload to be signed.
        :param protected: Protected header for the JWS.
        :param unprotected: Unprotected header for the JWS (only for JSON serialization).
        :param serialization_format: The format of the signature ("compact" or "json");
            any value other than "compact" produces a flattened JSON serialization.
        :param signing_kid: The key ID of the key to force for signing.
        :param kid_in_header: If True, include the key ID in the token header.
        :param signing_algs: Optional list of algorithms, in order of preference,
            used to restrict the signing key selection.
        :param kwargs: Additional parameters forwarded to the compact signer
            (ignored for JSON serialization).
        :returns: The signed JWS token.
        :rtype: str
        :raises JWSSigningError: if there is any signing error, such as the signing
            key not being suitable for such a cryptographic operation
        """
        # Work on shallow copies: 'alg', 'typ' and 'kid' are injected below and
        # must not leak into the dictionaries owned by the caller.
        protected = dict(protected) if protected else {}
        unprotected = dict(unprotected) if unprotected else {}
        algs = list(signing_algs) if signing_algs else []

        # Select the signing key
        signing_key = self._select_signing_key((protected, unprotected), signing_kid, algs)

        if signing_key.get("kty") == "oct":
            raise JWSSigningError(f"Key {signing_key.get('kid')} is a symmetric key")

        try:
            _validate_key_with_jws_header(signing_key, protected, unprotected)
        except Exception as e:
            raise JWSSigningError(
                f"failed to validate signing key: it's content it not valid for current header claims: {e}", e
            ) from e

        payload = serialize_payload(plain_dict)

        # Select a trusted algorithm from the key type and override any
        # caller-provided 'alg' so that untrusted values (e.g. 'none') are never used.
        kty = JWK(signing_key).key.kty
        signing_alg = DEFAULT_SIG_KTY_MAP.get(kty)
        if signing_alg is None:
            raise JWSSigningError(
                f"signing error: unsupported key type {kty!r}; supported: {sorted(DEFAULT_SIG_KTY_MAP)}"
            )
        protected["alg"] = signing_alg

        # Add "typ" header if not present. is_sd_jwt() decodes a compact token
        # header, so only a string payload can possibly be an SD-JWT; calling it
        # on other payload types would only log a spurious warning.
        if "typ" not in protected:
            is_sd = isinstance(plain_dict, str) and self.is_sd_jwt(plain_dict)
            protected["typ"] = "sd-jwt" if is_sd else "JWT"

        # Include the signing key's kid in the header if required
        header_kid = protected.get("kid")
        signer_kid = signing_key.get("kid")
        if kid_in_header and signer_kid:
            # note that is actually redundant as the underlying library auto-update the header with the kid
            protected["kid"] = signer_kid

        # this is a hack: if the header to be signed does NOT have kid and we do
        # not want to include it, then we must remove it from the signing kid
        # otherwise the signing library will auto insert it
        if not kid_in_header and not header_kid:
            signing_key = deepcopy(signing_key)
            signing_key.pop("kid", None)

        signing_key_jwk = key_from_jwk_dict(signing_key)
        if not signing_key_jwk.priv_key:
            raise JWSSigningError(f"Key {signing_key_jwk.kid} is not a private key")

        signer = JWS(payload, alg=signing_alg)
        keys = [signing_key_jwk]

        # Both serializations are wrapped so that any library failure surfaces
        # as JWSSigningError, as documented above.
        try:
            if serialization_format == "compact":
                return signer.sign_compact(keys, protected=protected, **kwargs)
            return signer.sign_json(
                keys=keys,
                headers=[(protected, unprotected)],
                flatten=True,
            )
        except Exception as e:
            raise JWSSigningError("Signing error: error in step", e) from e

    def _select_signing_key(
        self,
        headers: tuple[dict, dict],
        signing_kid: str = "",
        signing_algs: list[str] | None = None,
    ) -> dict:
        """
        Select a signing key based on the provided headers and optional parameters.

        The candidates are taken from the initialized JWKS and the first
        strategy that yields a key wins: forced kid, forced algs, single key,
        single key with use 'sig', header 'kid' match, header 'x5c' match.

        :param headers: A tuple containing the protected and unprotected headers.
        :param signing_kid: Optional key ID to force the selection of a specific signing key.
        :param signing_algs: Optional list of algorithms to force the selection of a signing key.
        :returns: A dictionary representing the selected signing key.
        :raises JWSSigningError: If no suitable signing key is found or if the key cannot be used for signing.
        """
        if len(self.jwks) == 0:
            raise JWSSigningError("signing error: no key available for signature; note that {'alg':'none'} is not supported")

        # Case 1: key forced by the user
        if signing_kid:
            signing_key = self.get_jwk_by_kid(signing_kid)
            if not signing_key:
                raise JWSSigningError(f"signing forced by using key with {signing_kid=}, but no such key is available")
            return signing_key.to_dict()

        # Case 2: key forced by the user by a list of algs (first matching alg wins)
        if signing_algs:
            for alg in signing_algs:
                if signing_key := self._select_key_by_sig_alg(alg):
                    return signing_key
            raise JWSSigningError(f"signing forced by using algs {signing_algs}, but no such key is available")

        # Case 3: only one key
        if signing_key := self._select_signing_key_by_uniqueness():
            return signing_key

        # Case 4: only one *signing* key
        if signing_key := self._select_key_by_use(use="sig"):
            return signing_key

        # Case 5: match key by kid
        if signing_key := self._select_key_by_kid(headers):
            return signing_key

        # Case 6: match key by x5c
        if signing_key := self._select_key_by_x5c(headers):
            return signing_key

        raise JWSSigningError("signing error: not possible to uniquely determine the signing key")

    def _select_signing_key_by_uniqueness(self) -> dict | None:
        """Return the only configured key as a dict, or None if there is not exactly one."""
        if len(self.jwks) == 1:
            return self.jwks[0].to_dict()
        return None

    def _select_key_by_use(self, use: str) -> dict | None:
        """Return the only key whose 'use' equals ``use``, or None if zero or several match."""
        candidates: list[dict] = [
            key_d for key_d in (key.to_dict() for key in self.jwks) if key_d.get("use", "") == use
        ]
        if len(candidates) == 1:
            return candidates[0]
        return None

    def _select_key_by_sig_alg(self, alg: str) -> dict | None:
        """
        Select a key based on the signature algorithm.

        Returns the first key whose 'kty' maps (through DEFAULT_SIG_KTY_MAP)
        to ``alg``, or None when no key matches.
        """
        for key in self.jwks:
            key_d: dict[str, Any] = key.to_dict()
            if alg == DEFAULT_SIG_KTY_MAP.get(key_d.get("kty", ""), ""):
                return key_d
        return None

    def _select_key_by_kid(self, headers: tuple[dict[str, Any], dict[str, Any]]) -> dict | None:
        """Return the key matching the 'kid' of the protected (or else unprotected) header, if any."""
        if not headers:
            return None
        if "kid" in headers[0]:
            kid = headers[0]["kid"]
        elif "kid" in headers[1]:
            kid = headers[1]["kid"]
        else:
            return None
        return find_jwk_by_kid([key.to_dict() for key in self.jwks], kid)

    def _select_key_by_x5c(self, headers: tuple[dict[str, Any], dict[str, Any]]) -> dict | None:
        """Return the key whose thumbprint matches the leaf certificate of the header 'x5c', if any."""
        if not headers:
            return None
        x5c: list[str] | None = headers[0].get("x5c") or headers[1].get("x5c")
        if not x5c:
            return None
        header_jwk = parse_b64der(x5c[0])
        for key in self.jwks:
            key_d = key.to_dict()
            if JWK(key_d).thumbprint == header_jwk.thumbprint:
                return key_d
        return None

    def verify(self, jwt: str, tolerance_s: int = DEFAULT_TOKEN_TIME_TOLERANCE) -> dict | bytes | str | Any:
        """Verify a JWS with one of the initialized keys and validate standard
        claims if possible, such as 'iat' and 'exp'.

        Verification of tokens in JSON serialization format is not supported.

        :param jwt: The JWS token to be verified.
        :type jwt: str
        :param tolerance_s: optional tolerance window, in seconds, which can be \
            used to account for some clock skew between the token issuer and the \
            token verifier when validating lifetime claims.
        :type tolerance_s: int

        :raises JWSVerificationError: if jwt is not in compact jws format, if its header
            lacks 'alg', or if the signature or lifetime claims are invalid
        :raises ValueError: propagated from key selection when a self-contained
            header key has no thumbprint

        :returns: the decoded payload of the verified tokens.
        :rtype: dict | bytes | str | Any
        """
        try:
            header = decode_jwt_header(jwt)
        except (binascii.Error, Exception) as e:
            raise JWSVerificationError(f"Not a valid JWS format for the following reason: {e}") from e

        verifying_key = self._select_verifying_key(header)
        if not verifying_key:
            raise JWSVerificationError(f"Verification error: unable to find matching public key for header {header}")

        # sanity check: kid must match if present
        if expected_kid := header.get("kid"):
            obtained_kid = verifying_key.get("kid", None)
            if obtained_kid and (obtained_kid != expected_kid):
                raise JWSVerificationError(
                    KidError(
                        "unexpected verification state: found a valid verifying key,"
                        f"but its kid {obtained_kid} does not match token header kid {expected_kid}"
                    )
                )

        # 'alg' is mandatory (RFC 7515 §4.1.1); a missing value must surface as
        # a verification error rather than a bare KeyError.
        alg = header.get("alg")
        if not alg:
            raise JWSVerificationError(f"Verification error: missing 'alg' in token header {header}")

        # Verify the JWS compact signature
        verifier = JWS(alg=alg)

        # Validate JWT claims
        try:
            msg = verifier.verify_compact(jwt, [key_from_jwk_dict(verifying_key)])
            if isinstance(msg, dict):
                validate_jwt_timestamps_claims(msg, tolerance_s)
            return msg
        except LifetimeException as e:
            raise JWSVerificationError(f"Invalid JWT claims: {e}") from e
        except Exception as e:
            raise JWSVerificationError(f"Error during signature verification: {e}") from e

    def _select_verifying_key(self, header: dict) -> dict | None:
        """
        Select the key to verify a token with, based on its header.

        :param header: the decoded JWS header.
        :returns: the matching key as a dict, or None if none can be determined.
        :raises ValueError: if the header carries a self-contained key without a thumbprint.
        """
        available_keys = [key.to_dict() for key in self.jwks]

        # case 1: can be found by header
        if "kid" in header:
            if verifying_key := find_jwk_by_kid(available_keys, header["kid"]):
                return verifying_key

        # case 2: the token is self contained, and the verification key matches one of the key in the whitelist
        if self_contained_claims_key_pair := find_self_contained_key(header):
            # check if the self contained key matches a trusted jwk
            _, candidate_key = self_contained_claims_key_pair

            if hasattr(candidate_key, "thumbprint"):
                if verifying_key := find_jwk_by_thumbprint(available_keys, candidate_key.thumbprint):
                    return verifying_key
            else:
                logger.error(f"Candidate key {candidate_key} does not have a thumbprint attribute.")
                raise ValueError("Invalid key: missing thumbprint.")

        # case 3: if only one key and there is no header claim that can identify any key, then that MUST
        # be the only valid CANDIDATE key for signature verification
        if len(self.jwks) == 1:
            return self.jwks[0].to_dict()

        return None

    def is_sd_jwt(self, token: str) -> bool:
        """
        Determines if the provided JWT is an SD-JWT.

        :param token: The JWT token to inspect.
        :type token: str
        :returns: True if the token header has 'typ' equal to 'sd-jwt', False otherwise
            (including when the header cannot be decoded).
        :rtype: bool
        """
        if not token:
            return False

        try:
            # Decode the JWT header to inspect the 'typ' field
            header = decode_jwt_header(token)

            # Check if 'typ' field exists and is equal to 'sd-jwt'
            return header.get("typ") == "sd-jwt"
        except Exception as e:
            # Undecodable input is simply "not an SD-JWT"
            logger.warning(f"Unable to determine if token is SD-JWT: {e}")
            return False
def _validate_key_with_header_kid(key: dict, header: dict) -> None:
    """
    Check that a key is compatible with the 'kid' claim of a token header.

    The check is skipped when either the key or the header lacks a (truthy) kid.

    :param key: the key, as a JWK dictionary.
    :param header: a token header.
    :raises Exception: if the key is not compatible with the header content kid (if any)
    """
    key_kid = key.get("kid")
    header_kid = header.get("kid")
    if key_kid and header_kid and key_kid != header_kid:
        raise Exception(f"token header contains a kid {header_kid} that does not match the signing key kid {key_kid}")
    return


def _validate_key_with_header_x5c(key: dict, header: dict) -> None:
    """
    Validate that a key has a public component that matches what is defined
    in the x5c leaf certificate in the header (if any).

    Note that this method DOES NOT validate the chain. Instead, it only checks
    that the leaf of the chain has the same cryptographic material as the
    argument key: if the key carries its own 'x5c', the two leaf certificates
    are compared as strings; otherwise the key thumbprint is compared with the
    thumbprint of the public key parsed from the header leaf certificate.

    :param key: the key, as a JWK dictionary.
    :param header: a token header.
    :raises Exception: if the key is not compatible with the header content x5c (if any)
    """
    x5c: list[str] | None = header.get("x5c")
    if not x5c:
        return
    leaf_cert: str = x5c[0]

    # if the key has a certificate, check the cert, otherwise check the public material
    key_x5c: list[str] | None = key.get("x5c")
    if key_x5c:
        if leaf_cert != (leaf_x5c_cert := key_x5c[0]):
            raise Exception(
                f"token header containes a chain whose leaf certificate {leaf_cert} does not match the signing key leaf certificate {leaf_x5c_cert}"
            )
        return

    header_key = parse_b64der(leaf_cert)
    if header_key.thumbprint != JWK(key).thumbprint:
        raise Exception(f"public material of the key does not matches the key in the leaf certificate {leaf_cert}")
    return


def _validate_key_with_jws_header(key: dict, protected_jws_header: dict, unprotected_jws_header: dict) -> None:
    """
    Validate that a key used for some operations (sign, verify) on a token is
    compatible with the token header itself.

    The protected and unprotected headers are checked independently: RFC 7515
    requires their parameter names to be disjoint, and validating a merged view
    (where unprotected values override protected ones) would let a matching
    unprotected claim mask a mismatching claim in the signed protected header.
    A ``None`` header is treated as empty.

    :param key: the key, as a JWK dictionary.
    :param protected_jws_header: the JWS protected header.
    :param unprotected_jws_header: the JWS unprotected header.
    :raises Exception: if the key is not compatible with the token header
    """
    headers = (protected_jws_header or {}, unprotected_jws_header or {})
    # NOTE: consistency with usage claims such as 'alg', 'kty' and 'use'
    # are done by the signer library and are not required here
    for header in headers:
        _validate_key_with_header_kid(key, header)
    for header in headers:
        _validate_key_with_header_x5c(key, header)