import datetime as dt
from datetime import datetime
from typing import Union
import pymongo
from pymongo.results import UpdateResult
from pyeudiw.storage.base_storage import (
BaseStorage,
TrustType,
trust_anchor_field_map,
trust_attestation_field_map,
trust_type_map,
)
from pyeudiw.storage.exceptions import ChainNotExist, StorageEntryUpdateFailed
# (Sphinx "[docs]" link marker left over from HTML extraction; not part of the module)
class MongoStorage(BaseStorage):
    """Storage backend that keeps sessions and trust data in MongoDB.

    The client is created lazily by ``_connect()``. Every operation goes
    through it before touching a collection, so a connection that was dropped
    (and detected by ``is_connected``) is replaced on the next call.
    """

    # Name of the index on ``creation_date`` that carries the session TTL;
    # used to detect and drop that index.
    _TTL_INDEX_NAME = "creation_date_1"

    def __init__(self, conf: dict, url: str, connection_params: dict | None = None) -> None:
        """Store the configuration and apply the configured session TTL.

        Applying the TTL connects to MongoDB, so construction performs I/O.

        :param conf: storage configuration; ``db_name`` and the four
            ``db_*_collection`` entries are read on connect, ``data_ttl`` is
            optional.
        :param url: connection URL handed to ``pymongo.MongoClient``.
        :param connection_params: extra keyword arguments for
            ``pymongo.MongoClient``; ``None`` means no extra arguments.
        """
        super().__init__()
        self.storage_conf = conf
        self.url = url
        # A fresh dict per instance instead of a shared mutable default.
        self.connection_params = {} if connection_params is None else connection_params
        self.client = None
        self.db = None
        # Collection handles are populated by _connect(); declared here so the
        # attributes always exist.
        self.sessions = None
        self.trust_attestations = None
        self.trust_anchors = None
        self.trust_sources = None
        self.set_session_retention_ttl(conf.get("data_ttl", None))

    @property
    def is_connected(self) -> bool:
        """Whether a client exists and the server answers ``server_info()``.

        When the probe fails with a connection error, the cached client and
        collection handles are cleared so the next ``_connect()`` rebuilds them.
        """
        if self.client is None:
            return False
        try:
            self.client.server_info()
        except (pymongo.errors.InvalidOperation, pymongo.errors.AutoReconnect, OSError):
            # MongoDB may drop connections under load (e.g. many test workers). Invalidate
            # so the next _connect() creates a fresh client instead of reusing a dead socket.
            self._reset_connection()
            return False
        return True

    def _reset_connection(self) -> None:
        """Clear client and db references so the next _connect() creates a fresh connection.

        Used when MongoDB closes the connection (AutoReconnect / "connection closed"),
        e.g. under load during pytest runs with multiple DBEngine instances.
        """
        self.client = None
        self.db = None
        self.sessions = None
        self.trust_attestations = None
        self.trust_anchors = None
        self.trust_sources = None

    def _connect(self) -> None:
        """Create the client and bind the collection handles if not connected."""
        if self.is_connected:
            return
        # Copy so the default below never leaks into the caller's dict.
        params = dict(self.connection_params or {})
        params.setdefault("maxPoolSize", 10)
        conf = self.storage_conf
        self.client = pymongo.MongoClient(self.url, **params)
        # Subscript access works for any configured name, including names that
        # would collide with attributes of the pymongo objects.
        self.db = self.client[conf["db_name"]]
        self.sessions = self.db[conf["db_sessions_collection"]]
        self.trust_attestations = self.db[conf["db_trust_attestations_collection"]]
        self.trust_anchors = self.db[conf["db_trust_anchors_collection"]]
        self.trust_sources = self.db[conf["db_trust_sources_collection"]]

    def _get_collection(self, name: str):
        """Return the collection called ``name`` in the configured database.

        Connects first, so the lookup also works after ``_reset_connection()``.

        :param name: collection name as found in the storage configuration.
        """
        self._connect()
        return self.db[name]

    def close(self) -> None:
        """Close the client, if one exists, and forget the cached handles."""
        if self.client is None:
            # Nothing to close; do not open a connection just to close it.
            return
        try:
            self.client.close()
        finally:
            self._reset_connection()

    def get_by_id(self, document_id: str) -> dict:
        """Return the session document whose ``document_id`` matches.

        :raises ValueError: when no such document exists.
        """
        self._connect()
        document = self.sessions.find_one({"document_id": document_id})
        if document is None:
            raise ValueError(f"Document with id {document_id} not found")
        return document

    def get_by_nonce_state(self, nonce: str, state: str | None) -> dict:
        """Return the session document matching ``nonce`` and, if given, ``state``.

        :param state: when falsy, the lookup is done on ``nonce`` alone.
        :raises ValueError: when no such document exists.
        """
        self._connect()
        query = {"nonce": nonce}
        if state:
            query["state"] = state
        document = self.sessions.find_one(query)
        if document is None:
            raise ValueError(f"Document with nonce {nonce} and state {state} not found")
        return document

    def get_by_session_id(self, session_id: str) -> dict:
        """Return the session document whose ``session_id`` matches.

        :raises ValueError: when no such document exists.
        """
        self._connect()
        document = self.sessions.find_one({"session_id": session_id})
        if document is None:
            raise ValueError(f"Document with session id {session_id} not found.")
        return document

    def get_by_state_and_session_id(self, state: str, session_id: str = "") -> dict:
        """Return the session document matching ``state`` and, if given, ``session_id``.

        :param session_id: when empty, the lookup is done on ``state`` alone.
        :raises ValueError: when no such document exists.
        """
        self._connect()
        query = {"state": state}
        if session_id:
            query["session_id"] = session_id
        document = self.sessions.find_one(query)
        if document is None:
            raise ValueError(f"Document with state {state} not found.")
        return document

    def init_session(self, document_id: str, session_id: str, state: str, remote_flow_typ: str) -> str:
        """Insert a new, non-finalized session document.

        The document is stamped with the current UTC time in ``creation_date``.

        :returns: ``document_id``, unchanged.
        """
        entity = {
            "document_id": document_id,
            "creation_date": dt.datetime.now(tz=dt.timezone.utc),
            "state": state,
            "session_id": session_id,
            "remote_flow_typ": remote_flow_typ,
            "finalized": False,
            "internal_response": None,
        }
        self._connect()
        self.sessions.insert_one(entity)
        return document_id

    def set_session_retention_ttl(self, ttl: int | None) -> None:
        """Create the TTL index on ``creation_date``, or drop it when ``ttl`` is falsy.

        :param ttl: seconds after which session documents expire; ``None`` or
            ``0`` removes an existing TTL index.
        """

        def _apply_ttl() -> None:
            self._connect()
            if ttl:
                # NOTE(review): if the index already exists with a different
                # expireAfterSeconds, create_index may fail with an options
                # conflict - confirm whether that case needs handling.
                self.sessions.create_index([("creation_date", pymongo.ASCENDING)], expireAfterSeconds=ttl)
            elif self.sessions.index_information().get(self._TTL_INDEX_NAME):
                self.sessions.drop_index(self._TTL_INDEX_NAME)

        # Runs in __init__; under load (e.g. second DBEngine in register_endpoints) the
        # socket can be closed by MongoDB. Retry once with a fresh connection.
        try:
            _apply_ttl()
        except (pymongo.errors.AutoReconnect, OSError) as e:
            # Only a closed connection is worth retrying; other OS errors propagate.
            if isinstance(e, OSError) and "connection closed" not in str(e).lower():
                raise
            self._reset_connection()
            _apply_ttl()

    def get_session_retention_ttl(self) -> dict | None:
        """Return the index information of the TTL index, or ``None`` if absent."""
        self._connect()
        return self.sessions.index_information().get(self._TTL_INDEX_NAME)

    def has_session_retention_ttl(self) -> bool:
        """Tell whether the TTL index exists on the sessions collection."""
        self._connect()
        return self.sessions.index_information().get(self._TTL_INDEX_NAME) is not None

    def add_dpop_proof_and_attestation(self, document_id: str, dpop_proof: dict, attestation: dict) -> UpdateResult:
        """Set ``dpop_proof`` and ``attestation`` on a session document.

        :raises ValueError: unless exactly one document was matched and modified.
        """
        self._connect()
        update_result: UpdateResult = self.sessions.update_one(
            {"document_id": document_id},
            {
                "$set": {
                    "dpop_proof": dpop_proof,
                    "attestation": attestation,
                }
            },
        )
        if update_result.matched_count != 1 or update_result.modified_count != 1:
            raise ValueError(f"Cannot update document {document_id}'.")
        return update_result

    def update_request_object(self, document_id: str, request_object: dict) -> UpdateResult:
        """Store ``request_object`` and copy its ``nonce`` and ``state`` to the session.

        :raises ValueError: when the document does not exist, or unless exactly
            one document was matched and modified.
        :raises KeyError: when ``request_object`` lacks ``nonce`` or ``state``.
        """
        # Connects and fails early when the document is missing.
        self.get_by_id(document_id)
        update_result = self.sessions.update_one(
            {"document_id": document_id},
            {
                "$set": {
                    "request_object": request_object,
                    "nonce": request_object["nonce"],
                    "state": request_object["state"],
                }
            },
        )
        if update_result.matched_count != 1 or update_result.modified_count != 1:
            raise ValueError(f"Cannot update document {document_id}')")
        return update_result

    def set_finalized(self, document_id: str) -> UpdateResult:
        """Mark a session document as finalized.

        Only ``matched_count`` is checked, so finalizing an already finalized
        session is accepted.

        :raises ValueError: when the document does not exist or was not matched.
        """
        # Connects and fails early when the document is missing.
        self.get_by_id(document_id)
        update_result: UpdateResult = self.sessions.update_one(
            {"document_id": document_id},
            {
                "$set": {"finalized": True},
            },
        )
        if update_result.matched_count != 1:
            raise ValueError(f"Cannot update document {document_id}'")
        return update_result

    def update_response_object(self, nonce: str, state: str, internal_response: dict, isError: bool = False) -> UpdateResult:
        """Store a response on the session identified by ``nonce`` and ``state``.

        :param isError: when true the payload is written to ``error_response``
            instead of ``internal_response``.
        :raises ValueError: when no matching session exists.
        """
        document = self.get_by_nonce_state(nonce, state)
        target_field = "error_response" if isError else "internal_response"
        return self.sessions.update_one(
            {"_id": document["_id"]},
            {
                "$set": {target_field: internal_response},
            },
        )

    def _get_db_entity(self, collection: str, entity_id: str) -> dict | None:
        """Return the document with ``entity_id`` from ``collection``, or ``None``.

        :param collection: collection name as found in the storage configuration.
        """
        return self._get_collection(collection).find_one({"entity_id": entity_id})

    def get_trust_source(self, entity_id: str) -> dict | None:
        """Return the trust source document for ``entity_id``, or ``None``."""
        return self._get_db_entity(self.storage_conf["db_trust_sources_collection"], entity_id)

    def get_trust_attestation(self, entity_id: str) -> dict | None:
        """Return the trust attestation document for ``entity_id``, or ``None``."""
        return self._get_db_entity(self.storage_conf["db_trust_attestations_collection"], entity_id)

    def get_trust_anchor(self, entity_id: str) -> dict | None:
        """Return the trust anchor document for ``entity_id``, or ``None``."""
        return self._get_db_entity(self.storage_conf["db_trust_anchors_collection"], entity_id)

    def _has_db_entity(self, collection: str, entity_id: str) -> bool:
        """Tell whether ``collection`` holds a document with ``entity_id``."""
        return self._get_db_entity(collection, entity_id) is not None

    def has_trust_attestation(self, entity_id: str) -> bool:
        """Tell whether a trust attestation exists for ``entity_id``."""
        return self._has_db_entity(self.storage_conf["db_trust_attestations_collection"], entity_id)

    def has_trust_anchor(self, entity_id: str) -> bool:
        """Tell whether a trust anchor exists for ``entity_id``."""
        return self._has_db_entity(self.storage_conf["db_trust_anchors_collection"], entity_id)

    def has_trust_source(self, entity_id: str) -> bool:
        """Tell whether a trust source exists for ``entity_id``."""
        return self._has_db_entity(self.storage_conf["db_trust_sources_collection"], entity_id)

    def _upsert_entry(self, key_label: str, collection: str, data: dict) -> UpdateResult:
        """Upsert ``data`` into ``collection``, matching on ``data[key_label]``.

        :param key_label: field used as the match key; must be present in ``data``.
        :param collection: collection name as found in the storage configuration.
        :raises StorageEntryUpdateFailed: when the write is not acknowledged.
        """
        result = self._get_collection(collection).update_one(
            {key_label: data[key_label]},
            {"$set": data},
            upsert=True,
        )
        if not result.acknowledged:
            raise StorageEntryUpdateFailed(f"Upsert on collection {collection} was not acknowledged")
        return result

    def _get_entry_by_key(self, key_label: str, collection: str, key_value: str) -> dict:
        """Return the document in ``collection`` whose ``key_label`` equals ``key_value``.

        :raises ValueError: when no such document exists.
        """
        document = self._get_collection(collection).find_one({key_label: key_value})
        if document is None:
            raise ValueError(f"Document with {key_label} {key_value} not found.")
        return document

    def _update_attestation_metadata(
        self,
        entity: dict,
        attestation: list[str],
        exp: datetime,
        trust_type: TrustType,
        jwks: list[dict],
    ) -> dict:
        """Merge attestation data for ``trust_type`` into ``entity``.

        ``entity`` is modified in place and also returned. ``exp`` and ``jwks``
        are written only when truthy; the attestation is written only when it
        is truthy and ``trust_type`` has an attestation field mapping.
        """
        trust_name = trust_type_map[trust_type]
        trust_field = trust_attestation_field_map.get(trust_type, None)
        trust_entity = entity.get(trust_name, {})
        if trust_field and attestation:
            trust_entity[trust_field] = attestation
        if exp:
            trust_entity["exp"] = exp
        if jwks:
            trust_entity["jwks"] = jwks
        entity[trust_name] = trust_entity
        return entity

    def _update_anchor_metadata(
        self,
        entity: dict,
        attestation: list[str],
        exp: datetime,
        trust_type: TrustType,
        entity_id: str,
    ) -> dict:
        """Merge anchor data for ``trust_type`` into ``entity``.

        ``entity`` is modified in place and also returned. ``entity_id`` is
        filled in when missing, and ``exp`` is always written.
        """
        if entity.get("entity_id", None) is None:
            entity["entity_id"] = entity_id
        trust_name = trust_type_map[trust_type]
        trust_field = trust_anchor_field_map.get(trust_type, None)
        trust_entity = entity.get(trust_name, {})
        if trust_field and attestation:
            trust_entity[trust_field] = attestation
        trust_entity["exp"] = exp
        entity[trust_name] = trust_entity
        return entity

    def upsert_session(self, session_id: str, data: dict) -> UpdateResult:
        """Upsert ``data`` into the session document keyed by ``session_id``.

        A ``session_id`` entry inside ``data`` takes precedence over the argument.
        """
        return self._upsert_entry(
            "session_id",
            self.storage_conf["db_sessions_collection"],
            {"session_id": session_id, **data},
        )

    def search_session_by_field(self, field: str, value: str) -> dict:
        """Return the session document whose ``field`` equals ``value``.

        :raises ValueError: when no such document exists.
        """
        return self._get_entry_by_key(field, self.storage_conf["db_sessions_collection"], value)

    def add_trust_attestation(
        self,
        entity_id: str,
        attestation: list[str],
        exp: datetime,
        trust_type: TrustType,
        jwks: list[dict],
    ) -> str:
        """Upsert a trust attestation built from an empty template.

        :returns: ``entity_id``, unchanged.
        """
        entity = {
            "entity_id": entity_id,
            "federation": {},
            "x509": {},
            "direct_trust_sd_jwt_vc": {},
            "metadata": {},
        }
        updated_entity = self._update_attestation_metadata(entity, attestation, exp, trust_type, jwks)
        self._upsert_entry(
            "entity_id",
            self.storage_conf["db_trust_attestations_collection"],
            updated_entity,
        )
        return entity_id

    def add_trust_source(self, trust_source: dict) -> UpdateResult:
        """Upsert ``trust_source`` keyed by its ``entity_id`` entry."""
        return self._upsert_entry("entity_id", self.storage_conf["db_trust_sources_collection"], trust_source)

    def add_empty_trust_anchor(self, entity_id: str) -> str:
        """Upsert a trust anchor with empty ``federation`` and ``x509`` sections.

        :returns: ``entity_id``, unchanged.
        """
        entity = {"entity_id": entity_id, "federation": {}, "x509": {}}
        self._upsert_entry(
            "entity_id",
            self.storage_conf["db_trust_anchors_collection"],
            entity,
        )
        return entity_id

    def add_trust_anchor(
        self,
        entity_id: str,
        entity_configuration: str,
        exp: datetime,
        trust_type: TrustType,
    ) -> str:
        """Upsert a trust anchor built from an empty template.

        :returns: ``entity_id``, unchanged.
        """
        entity = {"entity_id": entity_id, "federation": {}, "x509": {}}
        updated_entity = self._update_anchor_metadata(entity, entity_configuration, exp, trust_type, entity_id)
        self._upsert_entry(
            "entity_id",
            self.storage_conf["db_trust_anchors_collection"],
            updated_entity,
        )
        return entity_id

    def update_trust_attestation(
        self,
        entity_id: str,
        attestation: list[str],
        exp: datetime,
        trust_type: TrustType,
        jwks: list[dict],
    ) -> UpdateResult:
        """Merge new attestation data into the stored attestation and upsert it.

        When no attestation is stored yet, the merge starts from an empty
        document. NOTE(review): in that case the upsert key ``entity_id`` is
        missing from the data - confirm callers only update existing entries.
        """
        collection = self.storage_conf["db_trust_attestations_collection"]
        old_entity = self._get_db_entity(collection, entity_id) or {}
        upd_entity = self._update_attestation_metadata(old_entity, attestation, exp, trust_type, jwks)
        return self._upsert_entry("entity_id", collection, upd_entity)

    def update_trust_anchor(
        self,
        entity_id: str,
        entity_configuration: str,
        exp: datetime,
        trust_type: TrustType,
    ) -> UpdateResult:
        """Merge new anchor data into the stored trust anchor and upsert it.

        :raises ChainNotExist: when no trust anchor is stored for ``entity_id``.
        """
        collection = self.storage_conf["db_trust_anchors_collection"]
        # Load the anchor itself (not the attestation) so the merge starts from
        # the data that is actually being updated.
        old_entity = self._get_db_entity(collection, entity_id)
        if old_entity is None:
            raise ChainNotExist(f"Chain with entity id {entity_id} not exist")
        upd_entity = self._update_anchor_metadata(old_entity, entity_configuration, exp, trust_type, entity_id)
        return self._upsert_entry("entity_id", collection, upd_entity)