import base64
import json
import re
from pyeudiw.jwt.exceptions import JWTDecodeError, JWTInvalidElementPosition
# Regexp matching the start of a compact JWT: three "."-separated segments made
# of word characters or "-". The pattern is deliberately not anchored at the
# end, so trailing content after the third segment is allowed; this lets it
# match a JWT, an SD-JWT and an SD-JWT with key binding.
# NOTE: for str patterns ``\w`` is Unicode-aware, so this is more permissive
# than the strict base64url alphabet.
JWT_REGEXP = r"^[\w-]+\.[\w-]+\.[\w-]+"
def decode_jwt_element(jwt: str, position: int) -> dict:
    """
    Decodes the JSON content of the JWT segment at the given position.

    The token is split on ``.``, and the selected segment is base64url-decoded
    (missing padding is tolerated) and parsed as JSON.

    :param jwt: a string that represents the jwt; ``bytes`` are also accepted
        and decoded to ``str`` before splitting.
    :type jwt: str
    :param position: the zero-based index of the ``.``-separated segment to
        decode (the payload is at position 1). Must be between 0 and 2.
    :type position: int
    :raises JWTInvalidElementPosition: If the position is negative, greater
        than 2, or the JWT has no segment at that position.
    :raises JWTDecodeError: If the JWT element cannot be decoded.
    :returns: a dict with the content of the decoded section.
    :rtype: dict
    """
    if position < 0:
        raise JWTInvalidElementPosition(
            f"Cannot accept negative position {position}"
        )
    if position > 2:
        raise JWTInvalidElementPosition(
            f"Cannot accept position greater than 2 {position}"
        )

    try:
        if isinstance(jwt, bytes):
            jwt = jwt.decode()
        segments = jwt.split(".")
    except Exception as e:
        # e.g. undecodable bytes or a non-string input
        raise JWTDecodeError(f"Unable to decode JWT element: {e}") from e

    if len(segments) <= position:
        raise JWTInvalidElementPosition(
            f"JWT has no element in position {position}"
        )

    try:
        return json.loads(base64_urldecode(segments[position]))
    except Exception as e:
        # Any base64 or JSON failure is reported uniformly to callers,
        # keeping the original error chained for debugging.
        raise JWTDecodeError(f"Unable to decode JWT element: {e}") from e
def decode_jwt_payload(jwt: str) -> dict:
    """
    Decodes the jwt payload, i.e. the segment at position 1.

    :param jwt: a string that represents the jwt.
    :type jwt: str
    :raises JWTDecodeError: If the JWT payload cannot be decoded.
    :raises JWTInvalidElementPosition: If the JWT has no segment at position 1.
    :returns: a dict with the content of the decoded payload.
    :rtype: dict
    """
    return decode_jwt_element(jwt, position=1)
def base64_urlencode(v: bytes) -> str:
    """Encode bytes as URL-safe base64 with the "=" padding removed.

    :param v: the raw data to encode.
    :returns: the encoded data as an ASCII string
    :rtype: str
    """
    encoded = base64.urlsafe_b64encode(v)
    # Padding only ever appears at the end of base64 output.
    unpadded = encoded.rstrip(b"=")
    return unpadded.decode("ascii")
def base64_urldecode(v: str) -> bytes:
    """Urlsafe base64 decoding. This function will handle missing
    padding symbols.

    :param v: the URL-safe base64 text to decode, with or without "="
        padding. ``bytes``/``bytearray`` input is accepted and decoded as
        ASCII first.
    :returns: the decoded data in bytes format; to convert to str use
        ``.decode("utf-8")`` on the result
    :rtype: bytes
    :raises ValueError: (``binascii.Error``) if the input is not valid base64.
    """
    if isinstance(v, (bytes, bytearray)):
        # Interpolating bytes into a str would yield "b'...'" and decode garbage.
        v = bytes(v).decode("ascii")
    # Base64 text must be a multiple of 4 characters long: append exactly the
    # number of "=" needed to reach the next multiple (not the remainder).
    padded = v + "=" * (-len(v) % 4)
    return base64.urlsafe_b64decode(padded)