Source code for pyeudiw.jwt.jwe_helper

import binascii
import json
import logging
from typing import Union

import cryptojwt
from cryptojwt.jwe.jwe import factory
from cryptojwt.jwe.jwe_ec import JWE_EC
from cryptojwt.jwe.jwe_rsa import JWE_RSA

from pyeudiw.jwt.exceptions import JWEDecryptionError, JWEEncryptionError
from pyeudiw.jwt.helper import JWHelperInterface
from pyeudiw.jwt.jws_helper import DEFAULT_ENC_ALG_MAP, DEFAULT_ENC_ENC_MAP
from pyeudiw.jwt.utils import decode_jwt_header

logger = logging.getLogger(__name__)


[docs] class JWEHelper(JWHelperInterface): """ The helper class for work with JWEs. """
[docs] def encrypt(self, plain_dict: Union[dict, str, int, None], **kwargs) -> str: """ Generate a encrypted JWE string. :param plain_dict: The payload of JWE. :type plain_dict: Union[dict, str, int, None] :param kwargs: Other optional fields to generate the JWE. :returns: A string that represents the JWE. :rtype: str """ if isinstance(plain_dict, dict): _payload = json.dumps(plain_dict).encode() elif not plain_dict: _payload = "" elif isinstance(plain_dict, (str, int)): _payload = plain_dict else: _payload = "" encryption_keys = [key for key in self.jwks if key.appropriate_for("encrypt")] if len(encryption_keys) == 0: raise JWEEncryptionError("unable to produce JWE: no available encryption key(s)") for key in self.jwks: if isinstance(key, cryptojwt.jwk.rsa.RSAKey): JWE_CLASS = JWE_RSA elif isinstance(key, cryptojwt.jwk.ec.ECKey): JWE_CLASS = JWE_EC else: # unsupported key: go to next one continue _keyobj = JWE_CLASS( _payload, alg=DEFAULT_ENC_ALG_MAP[key.kty], enc=DEFAULT_ENC_ENC_MAP[key.kty], kid=key.kid, **kwargs, ) if key.kty == "EC": _keyobj: JWE_EC cek, encrypted_key, iv, params, _ = _keyobj.enc_setup(msg=_payload, key=key) kwargs = { "params": params, "cek": cek, "iv": iv, "encrypted_key": encrypted_key, } return _keyobj.encrypt(**kwargs) else: return _keyobj.encrypt(key=key.public_key()) raise JWEEncryptionError("unable to produce JWE: no supported encryption key(s)")
[docs] def decrypt(self, jwe: str) -> dict: """ Generate a dict containing the content of decrypted JWE string. :param jwe: A string representing the jwe. :type jwe: str :raises JWEDecryptionError: if jwe field is not in a JWE Format :returns: A dict that represents the payload of decrypted JWE. :rtype: dict """ try: jwe_header = decode_jwt_header(jwe) except (binascii.Error, Exception) as e: raise JWEDecryptionError(f"Not a valid JWE format for the following reason: {e}") _alg = jwe_header.get("alg") _enc = jwe_header.get("enc") _kid = jwe_header.get("kid") _jwk = self.get_jwk_by_kid(_kid) _decryptor = factory(jwe, alg=_alg, enc=_enc) if isinstance(_jwk, cryptojwt.jwk.ec.ECKey): jwdec = JWE_EC() jwdec.dec_setup(_decryptor.jwt, key=_jwk.private_key()) msg = jwdec.decrypt(_decryptor.jwt) else: msg = _decryptor.decrypt(jwe, [_jwk]) try: msg_dict = json.loads(msg) except json.decoder.JSONDecodeError: msg_dict = msg return msg_dict