Source code for pyeudiw.sd_jwt.issuer

import logging
import secrets
from typing import Dict, List, Union

from cryptojwt.jwk.jwk import key_from_jwk_dict
from cryptojwt.jws.jws import JWS

from pyeudiw.jwt.jws_helper import JWSHelper
from pyeudiw.sd_jwt import (
    DEFAULT_SIGNING_ALG,
    DIGEST_ALG_KEY,
    JSON_SER_DISCLOSURE_KEY,
    SD_DIGESTS_KEY,
    SD_LIST_PREFIX,
)
from pyeudiw.sd_jwt.common import SDJWTCommon, SDObj
from pyeudiw.sd_jwt.disclosure import SDJWTDisclosure

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class SDJWTIssuer(SDJWTCommon):
    """Build and sign an SD-JWT from a tree of user claims.

    Claims (object keys or list elements) wrapped in ``SDObj`` are replaced in
    the payload by the hash of a freshly created disclosure; everything else is
    copied into the payload as-is. The whole issuance pipeline runs inside
    ``__init__``, so a constructed instance already exposes the results through
    the attributes below.
    """

    # Inclusive bounds for the number of decoy digests added per object.
    DECOY_MIN_ELEMENTS = 2
    DECOY_MAX_ELEMENTS = 5

    # Payload that gets signed (claims, digests, digest algorithm, optional cnf).
    sd_jwt_payload: Dict
    # NOTE: annotated as JWS, but _create_signed_jws assigns a JWSHelper here.
    sd_jwt: JWS
    # Value returned by the signing helper for the chosen serialization format.
    serialized_sd_jwt: str

    # Disclosures created while walking the user claims, in creation order.
    ii_disclosures: List
    # Final issuance string handed to the holder.
    sd_jwt_issuance: str

    # Digests that were added as decoys (they have no matching disclosure).
    decoy_digests: List

    def __init__(
        self,
        user_claims: Dict,
        issuer_keys: Union[Dict, List[Dict]],
        holder_key=None,
        sign_alg=None,
        add_decoy_claims: bool = False,
        serialization_format: str = "compact",
        extra_header_parameters: Union[Dict, None] = None,
    ):
        """Create the SD-JWT for ``user_claims`` and sign it.

        :param user_claims: claims to issue; selectively disclosable keys or
            list elements are wrapped in ``SDObj``.
        :param issuer_keys: a JWK dict, or a list of JWK dicts, used for signing.
        :param holder_key: optional JWK dict of the holder; when given it is
            placed in the payload under ``cnf.jwk``.
        :param sign_alg: signing algorithm; falls back to ``DEFAULT_SIGNING_ALG``.
        :param add_decoy_claims: when true, random decoy digests are added to
            every object that is processed.
        :param serialization_format: ``"compact"`` or ``"json"``.
        :param extra_header_parameters: extra protected header parameters;
            they override the default ones. ``None`` means no extra parameters.

        :raises ValueError: if a key is not a dict, or if several issuer keys
            are combined with a serialization format other than ``"json"``.
        """
        super().__init__(serialization_format=serialization_format)

        self._user_claims = user_claims

        # Accept a single JWK for convenience, but always store a list.
        if not isinstance(issuer_keys, list):
            issuer_keys = [issuer_keys]
        for key in issuer_keys:
            if not isinstance(key, dict):
                raise ValueError("Not valid jwk dict instance for one or more issuer_keys")
        self._issuer_keys = issuer_keys

        if holder_key and not isinstance(holder_key, dict):
            raise ValueError("Not valid jwk dict instance for holder_key")
        self._holder_key = holder_key

        self._sign_alg = sign_alg or DEFAULT_SIGNING_ALG
        self._add_decoy_claims = add_decoy_claims
        # A None sentinel replaces the former ``{}`` default so that instances
        # never share one module-level dict.
        self._extra_header_parameters = (
            extra_header_parameters if extra_header_parameters is not None else {}
        )

        self.ii_disclosures = []
        self.decoy_digests = []

        if len(self._issuer_keys) > 1 and self._serialization_format != "json":
            # Only key identifiers are reported: the issuer keys are signing
            # keys and must not end up in exception messages or logs.
            raise ValueError(
                f"Multiple issuer keys (here {len(self._issuer_keys)}) are only supported with JSON serialization. "
                f"Key IDs found: {[key.get('kid') for key in self._issuer_keys]}"
            )

        self._check_for_sd_claim(self._user_claims)
        self._assemble_sd_jwt_payload()
        self._create_signed_jws()
        self._create_combined()

    def _assemble_sd_jwt_payload(self):
        """Build ``sd_jwt_payload`` from the user claims.

        Adds the digest algorithm name and, when a holder key was supplied,
        the ``cnf`` claim carrying the serialized holder JWK.
        """
        # Create the JWS payload
        self.sd_jwt_payload = self._create_sd_claims(self._user_claims)
        self.sd_jwt_payload.update(
            {
                DIGEST_ALG_KEY: self.HASH_ALG["name"],
            }
        )
        if self._holder_key:
            self.sd_jwt_payload["cnf"] = {"jwk": key_from_jwk_dict(self._holder_key).serialize()}

    def _create_decoy_claim_entry(self) -> str:
        """Return a decoy digest (hash of a fresh salt) and record it."""
        digest = self._b64hash(self._generate_salt().encode("ascii"))
        self.decoy_digests.append(digest)
        return digest

    def _create_sd_claims(self, user_claims):
        """Recursively convert ``user_claims`` into their payload form.

        Lists and dicts are processed element by element; any other value is
        returned unchanged.

        :raises ValueError: if an ``SDObj`` is used as a plain value, i.e.
            neither as an object key nor as a list element.
        """
        # This function can be called recursively.
        if isinstance(user_claims, list):
            return self._create_sd_claims_list(user_claims)
        if isinstance(user_claims, dict):
            return self._create_sd_claims_object(user_claims)
        if isinstance(user_claims, SDObj):
            raise ValueError(f"SDObj found in illegal place. The claim value '{user_claims}' should not be wrapped by SDObj.")
        # For other types, assume that the value can be disclosed.
        return user_claims

    def _create_sd_claims_list(self, user_claims: List):
        """Process a list of claims.

        Elements wrapped in ``SDObj`` are replaced by ``{SD_LIST_PREFIX: hash}``
        and their disclosure is appended to ``ii_disclosures``; all other
        elements are processed recursively and kept in place.
        """
        output_user_claims = []
        for claim in user_claims:
            if isinstance(claim, SDObj):
                subtree_from_here = self._create_sd_claims(claim.value)
                # List elements have no key, hence key=None.
                disclosure = SDJWTDisclosure(
                    self,
                    key=None,
                    value=subtree_from_here,
                )
                self.ii_disclosures.append(disclosure)
                output_user_claims.append({SD_LIST_PREFIX: disclosure.hash})
            else:
                output_user_claims.append(self._create_sd_claims(claim))
        return output_user_claims

    def _create_sd_claims_object(self, user_claims: Dict):
        """Process a dict of claims.

        Entries whose key is an ``SDObj`` are moved into a disclosure and only
        their hash is listed under ``SD_DIGESTS_KEY``; other entries are kept
        under their key with a recursively processed value. Decoy digests are
        added when requested. The digest list is sorted, or removed if empty.
        """
        sd_claims = {SD_DIGESTS_KEY: []}
        for key, value in user_claims.items():
            subtree_from_here = self._create_sd_claims(value)
            if isinstance(key, SDObj):
                disclosure = SDJWTDisclosure(
                    self,
                    key=key.value,
                    value=subtree_from_here,
                )
                self.ii_disclosures.append(disclosure)
                sd_claims[SD_DIGESTS_KEY].append(disclosure.hash)
            else:
                sd_claims[key] = subtree_from_here

        # Add decoy claims if requested
        if self._add_decoy_claims:
            sr = secrets.SystemRandom()
            for _ in range(sr.randint(self.DECOY_MIN_ELEMENTS, self.DECOY_MAX_ELEMENTS)):
                sd_claims[SD_DIGESTS_KEY].append(self._create_decoy_claim_entry())

        if not sd_claims[SD_DIGESTS_KEY]:
            # Do not emit an empty digest list.
            del sd_claims[SD_DIGESTS_KEY]
        else:
            # Sorting removes any hint about the original claim order.
            sd_claims[SD_DIGESTS_KEY].sort()

        return sd_claims

    def _create_signed_jws(self):
        """
        Create the SD-JWT.

        If serialization_format is "compact", then the SD-JWT is a JWT (JWS in compact serialization).
        If serialization_format is "json", then the SD-JWT is a JWS in JSON serialization. The disclosures in this case
        will be added in a separate "disclosures" property of the JSON.
        """
        # Assemble protected headers starting with default
        _protected_headers = {"alg": self._sign_alg, "typ": self.SD_JWT_HEADER}
        if len(self._issuer_keys) == 1 and "kid" in self._issuer_keys[0]:
            _protected_headers["kid"] = self._issuer_keys[0]["kid"]
        # override if any
        _protected_headers.update(self._extra_header_parameters)

        # Only one unprotected header is handed to the signing helper. As
        # before, its "kid" is taken from the last issuer key (None when that
        # key has no "kid").
        _unprotected_headers = {}
        if self._issuer_keys:
            last_key = self._issuer_keys[-1]
            _unprotected_headers = {"kid": last_key["kid"]} if "kid" in last_key else None

        # In JSON serialization the disclosures travel in the unprotected
        # header. They must be attached regardless of how many issuer keys
        # were supplied; previously they were lost with more than one key.
        if self._serialization_format == "json":
            _unprotected_headers = _unprotected_headers or {}
            _unprotected_headers[JSON_SER_DISCLOSURE_KEY] = [d.b64 for d in self.ii_disclosures]

        self.sd_jwt = JWSHelper(jwks=self._issuer_keys)
        self.serialized_sd_jwt = self.sd_jwt.sign(
            self.sd_jwt_payload,
            protected=_protected_headers,
            unprotected=_unprotected_headers,
            serialization_format=self._serialization_format,
        )

    def _create_combined(self):
        """Set ``sd_jwt_issuance``.

        For compact serialization the signed JWT is combined with all
        disclosures and terminated by the separator; otherwise the serialized
        JWS is used unchanged.
        """
        if self._serialization_format == "compact":
            self.sd_jwt_issuance = self._combine(self.serialized_sd_jwt, *(d.b64 for d in self.ii_disclosures))
            self.sd_jwt_issuance += self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR
        else:
            self.sd_jwt_issuance = self.serialized_sd_jwt