import json
import logging
from hashlib import sha256
from typing import Any, Callable, TypeVar
from cryptojwt.jwk.ec import ECKey
from cryptojwt.jwk.rsa import RSAKey
from pyeudiw.jwt.jws_helper import JWSHelper
from pyeudiw.jwt.parse import DecodedJwt
from pyeudiw.jwt.utils import base64_urldecode, base64_urlencode
from pyeudiw.jwt.verification import verify_jws_with_key
from pyeudiw.sd_jwt.common import SDJWTCommon
from pyeudiw.sd_jwt.exceptions import InvalidKeyBinding, UnsupportedSdAlg, MissingConfirmationKey
from pyeudiw.sd_jwt.schema import (
VerifierChallenge,
is_sd_jwt_format,
)
from pyeudiw.sd_jwt import DEFAULT_SD_ALG, DIGEST_ALG_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX
# Union of the JSON-serializable python types a decoded claim value may take.
_JsonTypes = dict | list | str | int | float | bool | None
# Bound type variable: lets _unpack_claims declare that it returns the same
# JSON type it receives.
_JsonTypes_T = TypeVar("_JsonTypes_T", bound=_JsonTypes)
# Separator between issuer jwt, disclosures and key binding jwt in the
# sd-jwt combined serialization format (taken from the sd-jwt commons).
FORMAT_SEPARATOR = SDJWTCommon.COMBINED_SERIALIZATION_FORMAT_SEPARATOR
# Supported selective-disclosure hash algorithms: maps the algorithm name to a
# function returning the base64url-encoded digest of its input's ascii bytes.
SUPPORTED_SD_ALG_FN: dict[str, Callable[[str], str]] = {"sha-256": lambda s: base64_urlencode(sha256(s.encode("ascii")).digest())}
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class SdJwt:
    """
    SdJwt is an utility class to easily parse and verify sd jwt.
    All class attributes are intended to be read only
    """

    def __init__(self, token: str):
        """
        :param token: an sd-jwt in combined serialization format, i.e.
            <issuer jwt>~<disclosure 1>~...~<disclosure N>~[<key binding jwt>]
        :type token: str
        :raises ValueError: if the token is not in sd-jwt format
        """
        if not is_sd_jwt_format(token):
            raise ValueError(f"input [token]={token} is not an sd-jwt with: maybe it is a regular jwt?")
        self.token = token
        # precomputed values, populated by _post_init_precomputed_values
        self.token_without_kb: str = ""
        self.issuer_jwt: DecodedJwt = DecodedJwt("", "", "", "")
        self.disclosures: list[str] = []
        self.holder_kb: DecodedJwt | None = None
        self._post_init_precomputed_values()

    def _post_init_precomputed_values(self):
        """Split the combined serialization into issuer jwt, disclosures and
        (optional) holder key binding jwt, and cache the parsed parts."""
        iss_jwt, *disclosures, kb_jwt = self.token.split(FORMAT_SEPARATOR)
        # NOTE: the kb-less token intentionally keeps a trailing separator after
        # the last disclosure; this exact string is what the key binding digest
        # (sd_hash) is computed over.
        self.token_without_kb = iss_jwt + FORMAT_SEPARATOR + "".join(disc + FORMAT_SEPARATOR for disc in disclosures)
        self.issuer_jwt = DecodedJwt.parse(iss_jwt)
        self.disclosures = disclosures
        if kb_jwt:
            self.holder_kb = DecodedJwt.parse(kb_jwt)
        # TODO: schema validations(?)

    def get_confirmation_key(self) -> dict:
        """
        Get the confirmation key from the issuer payload claims.
        :raises MissingConfirmationKey: if the confirmation key is missing
        :return: the confirmation key
        :rtype: dict
        """
        cnf: dict = self.issuer_jwt.payload.get("cnf", {}).get("jwk", {})
        if not cnf:
            raise MissingConfirmationKey("missing confirmation (cnf) key from issuer payload claims")
        return cnf

    def get_disclosed_claims(self) -> dict:
        """
        Get the disclosed claims from the issuer payload
        :raises UnsupportedSdAlg: if the sd_alg is not supported
        :raises ValueError: if there are duplicate digests
        :return: the disclosed claims
        :rtype: dict
        """
        return _extract_claims_from_payload(
            self.issuer_jwt.payload,
            self.disclosures,
            SUPPORTED_SD_ALG_FN[self.get_sd_alg()],
        )

    def get_issuer_jwt(self) -> DecodedJwt:
        """
        Get the issuer jwt
        :return: the issuer jwt
        :rtype: DecodedJwt
        """
        return self.issuer_jwt

    def get_holder_key_binding_jwt(self) -> str:
        """
        Get the holder key binding jwt.
        NOTE(review): assumes a key binding is present (see has_key_binding);
        raises AttributeError otherwise.
        :return: the holder key binding jwt
        :rtype: str
        """
        return self.holder_kb.jwt

    def get_sd_alg(self) -> str:
        """
        Get the sd_alg from the issuer jwt, falling back to the default
        algorithm when the _sd_alg claim is absent.
        :return: the sd_alg
        :rtype: str
        """
        return self.issuer_jwt.payload.get("_sd_alg", DEFAULT_SD_ALG)

    def has_key_binding(self) -> bool:
        """
        Check if the token has a key binding
        :return: True if the token has a key binding, False otherwise
        :rtype: bool
        """
        return self.holder_kb is not None

    def verify_issuer_jwt_signature(self, keys: list[ECKey | RSAKey | dict] | ECKey | RSAKey | dict) -> None:
        """
        Verify the issuer jwt signature
        :param keys: the public key(s) to use to verify the issuer jwt signature
        :type keys: list[ECKey | RSAKey | dict] | ECKey | RSAKey | dict
        :raises JWSVerificationError: if the verification fails
        """
        jws_verifier = JWSHelper(keys)
        jws_verifier.verify(self.issuer_jwt.jwt)

    def verify_holder_kb_jwt(self, challenge: VerifierChallenge) -> None:
        """
        Checks validity of holder key binding.
        This procedure always passes when no key binding is used
        :raises UnsupportedSdAlg: if verification fails due to an unkown _sd_alg
        :raises InvalidKeyBinding: if the verification fails for an invalid key binding
        :raises ValueError: if the iat claim is missing or invalid
        :raises JWSVerificationError: if the verification fails
        """
        if not self.has_key_binding():
            return
        _verify_key_binding(self.token_without_kb, self.get_sd_alg(), self.holder_kb, challenge)
        self.verify_holder_kb_jwt_signature()

    def verify_holder_kb_jwt_signature(self) -> None:
        """
        Verify the holder key binding signature against the issuer-attested
        confirmation (cnf) key. No-op when no key binding is present.
        :raises JWSVerificationError: if the verification fails
        """
        if not self.has_key_binding():
            return
        cnf: dict = self.get_confirmation_key()
        verify_jws_with_key(self.holder_kb.jwt, cnf)
def _verify_challenge(hkb: DecodedJwt, challenge: VerifierChallenge) -> None:
"""
Verify the challenge in the key binding
:param hkb: the holder key binding
:type hkb: DecodedJwt
:param challenge: the challenge to verify
:type challenge: VerifierChallenge
:raises InvalidKeyBinding: if the challenge is invalid
"""
if (obt := hkb.payload.get("aud", None)) != (exp := challenge["aud"]):
raise InvalidKeyBinding(f"challenge audience {exp} does not match obtained audience {obt}")
if (obt := hkb.payload.get("nonce", None)) != (exp := challenge["nonce"]):
raise InvalidKeyBinding(f"challenge nonce {exp} does not match obtained nonce {obt}")
def _verify_sd_hash(token_without_hkb: str, sd_hash_alg: str, expected_digest: str) -> None:
    """
    Verify that the digest of the kb-less token matches the expected digest.

    :param token_without_hkb: the token without the holder key binding
    :type token_without_hkb: str
    :param sd_hash_alg: the algorithm to use to hash the token without the holder key binding
    :type sd_hash_alg: str
    :param expected_digest: the expected digest
    :type expected_digest: str
    :raises UnsupportedSdAlg: if the sd_alg is not supported
    :raises InvalidKeyBinding: if the key binding is invalid
    """
    if sd_hash_alg not in SUPPORTED_SD_ALG_FN:
        raise UnsupportedSdAlg(f"unsupported sd_alg: {sd_hash_alg}")
    obt_digest = SUPPORTED_SD_ALG_FN[sd_hash_alg](token_without_hkb)
    if obt_digest != expected_digest:
        raise InvalidKeyBinding(f"sd-jwt digest {obt_digest} does not match expected digest {expected_digest}")
def _verify_iat(payload: dict) -> None:
"""
Verify the iat claim in the payload
:param payload: the payload of the issuer jwt
:type payload: dict
:raises ValueError: if the iat claim is missing or invalid
"""
# we check that 'iat' claim exists, according to sd-jwt specs, but since its a standard claim,
# its value is validated by the general purpose token verification tool JWSHelper accordidng to
# its own rules
iat: int | None = payload.get("iat", None)
if not isinstance(iat, int):
raise ValueError("missing or invalid parameter [iat] in kbjwt")
def _verify_key_binding(
    token_without_hkb: str,
    sd_hash_alg: str,
    hkb: DecodedJwt,
    challenge: VerifierChallenge,
) -> None:
    """
    Verify the key binding in the sd-jwt

    :param token_without_hkb: the token without the holder key binding
    :type token_without_hkb: str
    :param sd_hash_alg: the algorithm to use to hash the token without the holder key binding
    :type sd_hash_alg: str
    :param hkb: the holder key binding
    :type hkb: DecodedJwt
    :param challenge: the challenge to verify
    :type challenge: VerifierChallenge
    :raises InvalidKeyBinding: if the key binding is invalid
    :raises UnsupportedSdAlg: if the sd_alg is not supported
    :raises ValueError: if the iat claim is missing or invalid
    """
    _verify_challenge(hkb, challenge)
    # The kb-jwt must carry the digest of the kb-less token in its 'sd_hash'
    # claim. The previous implementation defaulted a missing 'sd_hash' to the
    # literal string "sha-256" (an algorithm *name*, not a digest), which made
    # the failure surface as a confusing "digest does not match" error; fail
    # explicitly instead (still InvalidKeyBinding, so callers are unaffected).
    expected_digest = hkb.payload.get("sd_hash", None)
    if expected_digest is None:
        raise InvalidKeyBinding("missing claim [sd_hash] in holder key binding jwt")
    _verify_sd_hash(token_without_hkb, sd_hash_alg, expected_digest)
    _verify_iat(hkb.payload)
def _disclosures_to_hash_mappings(disclosures: list[str], sd_alg: Callable[[str], str]) -> tuple[dict[str, str], dict[str, Any]]:
"""
Convert a list of disclosures to a map of digests to disclosures
:param disclosures: a list of base64 encoded disclosures
:type disclosures: list[str]
:param sd_alg: the function to use to hash the disclosures
:type sd_alg: Callable[[str], str]
:raises ValueError: if there are duplicate digests
:returns: in order
(i) hash_to_disclosure, a map: digest -> raw disclosure (base64 encoded)
(ii) hash_to_dec_disclosure, a map: digest -> decoded disclosure
:rtype: tuple[dict[str, str], dict[str, Any]]
"""
hash_to_disclosure: dict[str, str] = {}
hash_to_dec_disclosure: dict[str, Any] = {}
for disclosure in disclosures:
decoded_disclosure = json.loads(base64_urldecode(disclosure).decode("utf-8"))
digest = sd_alg(disclosure)
if digest in hash_to_dec_disclosure:
raise ValueError(f"duplicate disclosure for digest {digest}")
hash_to_dec_disclosure[digest] = decoded_disclosure
hash_to_disclosure[digest] = disclosure
return hash_to_disclosure, hash_to_dec_disclosure
def _extract_claims_from_payload(payload: dict, disclosures: list[str], sd_alg: Callable[[str], str]) -> dict:
    """
    Resolve the selectively disclosed claims of an issuer payload.

    :param payload: the payload of the issuer jwt
    :type payload: dict
    :param disclosures: a list of base64 encoded disclosures
    :type disclosures: list[str]
    :param sd_alg: the function to use to hash the disclosures
    :type sd_alg: Callable[[str], str]
    :raises ValueError: if there are duplicate digests
    :returns: the disclosed claims
    :rtype: dict
    """
    # only the digest -> decoded disclosure mapping is needed for unpacking
    digest_to_decoded = _disclosures_to_hash_mappings(disclosures, sd_alg)[1]
    return _unpack_claims(payload, digest_to_decoded, sd_alg, [])
def _is_element_leaf(element: Any) -> bool:
"""
Check if an element is a leaf in the json tree
:param element: the element to check
:type element: Any
:returns: True if the element is a leaf, False otherwise
:rtype: bool
"""
return type(element) is dict and len(element) == 1 and SD_LIST_PREFIX in element and type(element[SD_LIST_PREFIX]) is str
def _unpack_json_array(
claims: list,
decoded_disclosures_by_digest: dict[str, Any],
sd_alg: Callable[[str], str],
processed_digests: list[str],
) -> list:
"""
Unpack the disclosed claims in the payload
:param claims: the claims to unpack
:type claims: list
:param decoded_disclosures_by_digest: a map of digests to decoded disclosures
:type decoded_disclosures_by_digest: dict[str, Any]
:param sd_alg: the function to use to hash the disclosures
:type sd_alg: Callable[[str], str]
:param processed_digests: a list of processed digests
:type processed_digests: list[str]
:raises ValueError: if there are duplicate digests
:returns: the unpacked claims
:rtype: list
"""
result = []
for element in claims:
if _is_element_leaf(element):
digest: str = element[SD_LIST_PREFIX]
if digest in decoded_disclosures_by_digest:
_, value = decoded_disclosures_by_digest[digest]
result.append(_unpack_claims(value, decoded_disclosures_by_digest, sd_alg, processed_digests))
else:
result.append(_unpack_claims(element, decoded_disclosures_by_digest, sd_alg, processed_digests))
return result
def _unpack_json_dict(
    claims: dict,
    decoded_disclosures_by_digest: dict[str, Any],
    sd_alg: Callable[[str], str],
    processed_digests: list[str],
) -> dict:
    """
    Unpack the disclosed claims in the payload.

    :param claims: the claims to unpack
    :type claims: dict
    :param decoded_disclosures_by_digest: a map of digests to decoded disclosures
    :type decoded_disclosures_by_digest: dict[str, Any]
    :param sd_alg: the function to use to hash the disclosures
    :type sd_alg: Callable[[str], str]
    :param processed_digests: a list of processed digests
        (renamed from misspelled 'proceessed_digests' for consistency with
        _unpack_json_array; all call sites pass it positionally)
    :type processed_digests: list[str]
    :raises ValueError: if there are duplicate digests
    :returns: the unpacked claims
    :rtype: dict
    """
    # First, recursively unpack the always-visible claims, skipping the
    # selective-disclosure bookkeeping keys (SD_DIGESTS_KEY, DIGEST_ALG_KEY).
    filtered_unpacked_claims = {}
    for k, v in claims.items():
        if k != SD_DIGESTS_KEY and k != DIGEST_ALG_KEY:
            filtered_unpacked_claims[k] = _unpack_claims(v, decoded_disclosures_by_digest, sd_alg, processed_digests)
    # Then resolve the digests listed in this object's digest array: each one
    # that has a matching disclosure contributes a (key, value) claim.
    for disclosed_digest in claims.get(SD_DIGESTS_KEY, []):
        if disclosed_digest in processed_digests:
            raise ValueError(f"duplicate hash found in SD-JWT: {disclosed_digest}")
        processed_digests.append(disclosed_digest)
        if disclosed_digest in decoded_disclosures_by_digest:
            # object disclosures decode to (salt, key, value) triples
            _, key, value = decoded_disclosures_by_digest[disclosed_digest]
            if key in filtered_unpacked_claims:
                raise ValueError(f"duplicate key found when unpacking disclosed claim: '{key}' in {filtered_unpacked_claims}; this is not allowed.")
            unpacked_value = _unpack_claims(value, decoded_disclosures_by_digest, sd_alg, processed_digests)
            filtered_unpacked_claims[key] = unpacked_value
    return filtered_unpacked_claims
def _unpack_claims(
claims: _JsonTypes_T,
decoded_disclosures_by_digest: dict[str, Any],
sd_alg: Callable[[str], str],
proceessed_digests: list[str],
) -> _JsonTypes_T:
"""
Unpack the disclosed claims in the payload
:param claims: the claims to unpack
:type claims: _JsonTypes_T
:param decoded_disclosures_by_digest: a map of digests to decoded disclosures
:type decoded_disclosures_by_digest: dict[str, Any]
:param sd_alg: the function to use to hash the disclosures
:type sd_alg: Callable[[str], str]
:param proceessed_digests: a list of processed digests
:type proceessed_digests: list[str]
:raises ValueError: if there are duplicate digests
:returns: the unpacked claims
:rtype: _JsonTypes_T
"""
if type(claims) is list:
return _unpack_json_array(claims, decoded_disclosures_by_digest, sd_alg, proceessed_digests)
elif type(claims) is dict:
return _unpack_json_dict(claims, decoded_disclosures_by_digest, sd_alg, proceessed_digests)
else:
return claims