import logging
import os
import random
import secrets
from base64 import urlsafe_b64decode, urlsafe_b64encode
from dataclasses import dataclass
from hashlib import sha256
from json import loads
from typing import Any, List, Union

from pyeudiw.sd_jwt import JSON_SER_DISCLOSURE_KEY, JSON_SER_KB_JWT_KEY, SD_DIGESTS_KEY
from pyeudiw.sd_jwt.exceptions import SDJWTHasSDClaimException
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@dataclass
class SDObj:
    """Wrapper marking the wrapped part of an object as selectively disclosable.

    Instances compare equal when their wrapped values are equal, and hash
    like the wrapped value.
    """

    # The wrapped claim name or value; any type is accepted.
    value: Any

    def __hash__(self) -> int:
        """
        Hash the object by its wrapped value.

        ``@dataclass`` with ``eq=True`` and ``frozen=False`` would otherwise
        set ``__hash__`` to ``None``. Defining it explicitly keeps instances
        hashable, e.g. so they can be used as dict keys.

        :return: ``hash(self.value)``.
        :rtype: int
        :raises TypeError: If the wrapped value itself is unhashable.
        """
        return hash(self.value)
class SDJWTCommon:
    """Common helpers for SD-JWT processing.

    Provides base64url and hash utilities, combining and splitting of the
    ``~``-separated compact serialization, and parsing of compact and JSON
    (RFC 7515 flattened or general) serialized SD-JWTs.
    """

    # Default ``typ`` header, configurable through the SD_JWT_HEADER
    # environment variable; overwriteable with
    # extra_header_parameters = {"typ": "other-example+sd-jwt"}.
    # TODO: dc is only for digital credential, while you might use another typ ...
    SD_JWT_HEADER = os.getenv("SD_JWT_HEADER", "dc+sd-jwt")
    # ``typ`` header value for Key Binding JWTs.
    KB_JWT_TYP_HEADER = "kb+jwt"
    # Hash algorithm used by _b64hash (disclosure digests and the key-binding hash).
    HASH_ALG = {"name": "sha-256", "fn": sha256}
    # Separator between the parts of the combined (compact) serialization.
    COMBINED_SERIALIZATION_FORMAT_SEPARATOR = "~"

    def __init__(self, serialization_format):
        """
        Initialize the helper for a given serialization format.

        :param serialization_format: Either ``"compact"`` or ``"json"``.
        :type serialization_format: str
        :raises ValueError: If the serialization format is not supported.
        """
        if serialization_format not in ("compact", "json"):
            raise ValueError(f"Unknown serialization format: {serialization_format}")
        self._serialization_format = serialization_format

    def _b64hash(self, raw: bytes) -> str:
        """
        Calculate the SHA 256 hash and output it base64url encoded (unpadded).

        :param raw: The raw data to hash.
        :type raw: bytes
        :return: The base64url encoded hash.
        :rtype: str
        """
        return self._base64url_encode(self.HASH_ALG["fn"](raw).digest())

    def _combine(self, *parts) -> str:
        """
        Combine the parts with the separator.

        :param parts: The parts to combine.
        :type parts: str
        :return: The combined string.
        :rtype: str
        """
        return self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR.join(parts)

    def _split(self, combined: str) -> List[str]:
        """
        Split the combined string on the separator.

        :param combined: The combined string.
        :type combined: str
        :return: The parts (empty strings are kept, e.g. for a trailing ``~``).
        :rtype: List[str]
        """
        return combined.split(self.COMBINED_SERIALIZATION_FORMAT_SEPARATOR)

    @staticmethod
    def _base64url_encode(data: bytes) -> str:
        """
        Encode the data in base64url encoding without ``=`` padding.

        :param data: The data to encode.
        :type data: bytes
        :return: The base64url encoded data.
        :rtype: str
        """
        return urlsafe_b64encode(data).decode("ascii").strip("=")

    @staticmethod
    def _base64url_decode(b64data: str) -> bytes:
        """
        Decode (possibly unpadded) base64url encoded data.

        :param b64data: The base64url encoded data.
        :type b64data: str
        :return: The decoded data.
        :rtype: bytes
        """
        # Restore exactly the padding needed to reach a multiple of 4 chars
        # (0, 2 or 1 '=' for lengths ≡ 0, 2, 3 mod 4) instead of relying on
        # the decoder tolerating excess '=' characters.
        padded = b64data + "=" * (-len(b64data) % 4)
        return urlsafe_b64decode(padded)

    def _generate_salt(self) -> str:
        """
        Generate a random salt from 16 cryptographically secure random bytes.

        :return: The base64url encoded salt.
        :rtype: str
        """
        return self._base64url_encode(secrets.token_bytes(16))

    def _create_hash_mappings(self, disclosurses_list: List) -> None:
        """
        Create the hash mappings for the disclosures.

        Populates ``self._hash_to_decoded_disclosure`` (digest -> JSON-decoded
        disclosure) and ``self._hash_to_disclosure`` (digest -> raw disclosure
        string). Each digest is computed over the ASCII bytes of the raw
        base64url disclosure string.

        :param disclosurses_list: The list of base64url encoded disclosures.
        :type disclosurses_list: List
        :raises ValueError: If two disclosures produce the same digest.
        """
        # Mapping from hash of disclosure to the decoded disclosure
        self._hash_to_decoded_disclosure = {}
        # Mapping from hash of disclosure to the raw disclosure
        self._hash_to_disclosure = {}
        for disclosure in disclosurses_list:
            decoded_disclosure = loads(
                self._base64url_decode(disclosure).decode("utf-8")
            )
            _hash = self._b64hash(disclosure.encode("ascii"))
            if _hash in self._hash_to_decoded_disclosure:
                raise ValueError(
                    f"Duplicate disclosure hash {_hash} for disclosure {decoded_disclosure}"
                )
            self._hash_to_decoded_disclosure[_hash] = decoded_disclosure
            self._hash_to_disclosure[_hash] = disclosure

    def _check_for_sd_claim(self, obj: Union[dict, list, Any]) -> None:
        """
        Recursively check that no dict in ``obj`` uses the ``_sd`` digests key.

        Nested dicts and lists are traversed; any other value is ignored.

        :param obj: The object to check.
        :type obj: Union[dict, list, Any]
        :raises SDJWTHasSDClaimException: If a dict containing the key is found.
        """
        if isinstance(obj, dict):
            for key, value in obj.items():
                if key == SD_DIGESTS_KEY:
                    raise SDJWTHasSDClaimException(obj)
                self._check_for_sd_claim(value)
        elif isinstance(obj, list):
            for item in obj:
                self._check_for_sd_claim(item)

    def _parse_sd_jwt(self, sd_jwt: str) -> None:
        """
        Parse the SD-JWT without verifying any signature.

        Sets ``_unverified_input_sd_jwt``, ``_input_disclosures``,
        ``_unverified_input_key_binding_jwt`` (``""`` when absent),
        ``_unverified_input_sd_jwt_payload`` (the JSON-decoded JWT payload)
        and ``_unverified_compact_serialized_input_sd_jwt`` (the compact
        ``protected.payload.signature`` form). For the JSON serialization it
        also sets ``_unverified_input_sd_jwt_parsed``.

        :param sd_jwt: The SD-JWT to parse.
        :type sd_jwt: str
        :raises ValueError: If the JSON serialization is neither flattened nor
            general, or the compact form has an unexpected number of parts.
        """
        if self._serialization_format == "compact":
            (
                self._unverified_input_sd_jwt,
                *self._input_disclosures,
                self._unverified_input_key_binding_jwt,
            ) = self._split(sd_jwt)
            # Extract only the body from SD-JWT without verifying the signature
            _, jwt_body, _ = self._unverified_input_sd_jwt.split(".")
            # JSON-decode the payload so this attribute has the same type as
            # in the JSON serialization branch below.
            self._unverified_input_sd_jwt_payload = loads(
                self._base64url_decode(jwt_body)
            )
            self._unverified_compact_serialized_input_sd_jwt = (
                self._unverified_input_sd_jwt
            )
            return

        # JSON serialization: parse the JSON and extract the disclosures.
        self._unverified_input_sd_jwt = sd_jwt
        parsed = loads(sd_jwt)
        self._unverified_input_sd_jwt_parsed = parsed
        self._unverified_input_sd_jwt_payload = loads(
            self._base64url_decode(parsed["payload"])
        )
        # Distinguish between flattened and general JSON serialization
        # (RFC 7515): flattened keeps "protected"/"header"/"signature" at the
        # top level, general nests them in "signatures" (first entry is used).
        if "signature" in parsed:
            signature_entry = parsed
        elif "signatures" in parsed:
            signature_entry = parsed["signatures"][0]
        else:
            raise ValueError("Invalid JSON serialization of SD-JWT")
        unprotected_header = signature_entry["header"]
        self._input_disclosures = unprotected_header[JSON_SER_DISCLOSURE_KEY]
        self._unverified_input_key_binding_jwt = unprotected_header.get(
            JSON_SER_KB_JWT_KEY, ""
        )
        self._unverified_compact_serialized_input_sd_jwt = ".".join(
            [
                signature_entry["protected"],
                parsed["payload"],
                signature_entry["signature"],
            ]
        )

    def _calculate_kb_hash(self, disclosures: List[str]) -> str:
        """
        Calculate the hash over the key binding input.

        The hashed string is the compact SD-JWT followed by the given
        disclosures, each terminated by ``~``. Requires ``_parse_sd_jwt`` to
        have been called first.

        :param disclosures: The list of disclosures.
        :type disclosures: List[str]
        :return: The base64url encoded hash over the key binding input.
        :rtype: str
        """
        # Temporarily create the combined presentation in order to create the hash over it
        # Note: For JSON Serialization, the compact representation of the SD-JWT is restored from the parsed JSON (see _parse_sd_jwt)
        string_to_hash = self._combine(
            self._unverified_compact_serialized_input_sd_jwt, *disclosures, ""
        )
        return self._b64hash(string_to_hash.encode("ascii"))