# Source code for advanced_alchemy.extensions.litestar.store

import base64
import datetime
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, Final, Generic, Optional, TypeVar, Union, cast

from litestar.exceptions import ImproperlyConfiguredException
from litestar.stores.base import NamespacedStore
from litestar.types import Empty, EmptyType
from litestar.utils.empty import value_or_default
from sqlalchemy import (
    BooleanClauseList,
    Dialect,
    Index,
    LargeBinary,
    String,
    UniqueConstraint,
    delete,
    func,
    insert,
    select,
    update,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped, Session, declarative_mixin, declared_attr, mapped_column

from advanced_alchemy.base import UUIDv7Base

# Import config types and async_ utility
from advanced_alchemy.extensions.litestar.plugins.init import (
    SQLAlchemyAsyncConfig,
    SQLAlchemySyncConfig,
)
from advanced_alchemy.operations import OnConflictUpsert
from advanced_alchemy.utils.sync_tools import async_

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Generator

    from sqlalchemy.ext.asyncio import AsyncSession
    from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType  # pyright: ignore[reportPrivateUsage]
    from sqlalchemy.sql.elements import BooleanClauseList

# Type variable for the store's configuration object; bounded by the two
# config classes imported above (async or sync).
SQLAlchemyConfigT = TypeVar("SQLAlchemyConfigT", bound=Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig])

# Names exported by ``from <this module> import *``.
__all__ = ("SQLAlchemyStore", "StoreModelMixin")

# Lowest PostgreSQL major version that ``SQLAlchemyStore.supports_merge``
# accepts; it is compared against ``dialect.server_version_info[0]``.
_POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15

# Temporary toggle to disable PostgreSQL MERGE due to locking concerns
# While this is ``True``, ``SQLAlchemyStore.supports_merge`` never reports
# PostgreSQL as MERGE-capable, whatever the server version.
_DISABLE_POSTGRES_MERGE: Final = True


@declarative_mixin
class StoreModelMixin(UUIDv7Base):
    """Mixin for session storage.

    Declares the ``key``, ``namespace``, ``value`` and ``expires_at`` columns
    and a uniqueness rule over ``(key, namespace)``. The class itself is
    abstract (``__abstract__ = True``).
    """

    __abstract__ = True

    @declared_attr.directive
    @classmethod
    def __table_args__(cls) -> "TableArgsType":
        """Return the table-level uniqueness rule for ``(key, namespace)``.

        Two alternatives are declared and gated with ``ddl_if``: a
        ``UniqueConstraint`` (emitted unless the dialect name contains
        ``"spanner"``) and a unique ``Index`` (emitted only when it does).
        """
        return (
            UniqueConstraint(
                cls.key,
                cls.namespace,
                name=f"uq_{cls.__tablename__}_key_namespace",
            ).ddl_if(callable_=cls._create_unique_store_key_namespace_constraint),
            Index(
                f"ix_{cls.__tablename__}_key_namespace_unique",
                cls.key,
                cls.namespace,
                unique=True,
            ).ddl_if(callable_=cls._create_unique_store_key_namespace_index),
        )

    @declared_attr
    def key(cls) -> Mapped[str]:
        """Entry key; at most 255 characters, required."""
        return mapped_column(String(length=255), nullable=False)

    @declared_attr
    def namespace(cls) -> Mapped[str]:
        """Namespace the key belongs to; at most 255 characters, required."""
        return mapped_column(String(length=255), nullable=False)

    @declared_attr
    def value(cls) -> Mapped[bytes]:
        """Stored payload as raw bytes, required."""
        return mapped_column(LargeBinary, nullable=False)

    @declared_attr
    def expires_at(cls) -> Mapped[datetime.datetime]:
        """Expiry timestamp (indexed)."""
        # NOTE(review): the annotation is non-optional, yet ``SQLAlchemyStore``
        # writes ``None`` here for entries without an expiry and filters on
        # ``expires_at IS NULL``. Confirm the resulting column is nullable.
        return mapped_column(index=True)

    @classmethod
    def _create_unique_store_key_namespace_index(cls, *_: Any, **kwargs: Any) -> bool:
        """``ddl_if`` callback: emit the unique index only for Spanner dialects."""
        dialect_name = kwargs.get("dialect", {}).name if "dialect" in kwargs else ""
        return "spanner" in dialect_name.lower()

    @classmethod
    def _create_unique_store_key_namespace_constraint(cls, *_: Any, **kwargs: Any) -> bool:
        """``ddl_if`` callback: emit the unique constraint for every non-Spanner dialect."""
        dialect_name = kwargs.get("dialect", {}).name if "dialect" in kwargs else ""
        return "spanner" not in dialect_name.lower()

    @hybrid_property
    def is_expired(self) -> bool:  # pyright: ignore
        """Boolean indicating if the session has expired.

        An entry whose ``expires_at`` is ``None`` has no expiry and is
        therefore reported as not expired instead of raising ``TypeError``
        on the comparison.

        Returns:
            `True` if the session has expired, otherwise `False`
        """
        expires_at = self.expires_at
        if expires_at is None:  # pyright: ignore[reportUnnecessaryComparison]
            return False
        return datetime.datetime.now(datetime.timezone.utc) > expires_at

    @is_expired.expression  # type: ignore[no-redef]
    def is_expired(cls) -> "BooleanClauseList":  # noqa: N805
        """SQL-Expression to check if the session has expired.

        Returns:
            SQL-Expression to check if the session has expired.
        """
        return cast("BooleanClauseList", func.now() > cls.expires_at)
class SQLAlchemyStore(NamespacedStore, Generic[SQLAlchemyConfigT]):
    """SQLAlchemy based, thread and process safe asynchronous key/value store.

    Supports both synchronous and asynchronous SQLAlchemy configurations.
    Every public coroutine dispatches on the configuration type: with an
    async config it awaits the ``_*_async`` implementation directly, otherwise
    it runs the ``_*_sync`` implementation through ``async_``.
    """

    __slots__ = ("_config", "_is_async", "_model", "_session_maker")

    def __init__(
        self,
        config: "SQLAlchemyConfigT",
        model: "type[StoreModelMixin]" = StoreModelMixin,
        namespace: "Optional[Union[str, EmptyType]]" = Empty,
    ) -> None:
        """Initialize :class:`SQLAlchemyStore`.

        Args:
            config: An instance of ``SQLAlchemyAsyncConfig`` or ``SQLAlchemySyncConfig``.
            model: The SQLAlchemy model to use for storing data. Defaults to
                :class:`StoreModelMixin`.
                NOTE(review): that default is declared abstract, while this class
                reads ``model.__table__``; presumably callers are expected to pass
                a concrete mapped subclass -- confirm.
            namespace: A virtual namespace for keys. If not given, defaults to ``LITESTAR``.
                Namespacing can be explicitly disabled by passing ``None``. This will
                make :meth:`.delete_all` unavailable.
        """
        self._config = config
        self._model = model
        self._is_async = isinstance(config, SQLAlchemyAsyncConfig)
        self.namespace: Optional[str] = value_or_default(namespace, "LITESTAR")

    @asynccontextmanager
    async def _get_async_session(self) -> "AsyncGenerator[AsyncSession, None]":
        """Yield a session from the config, used as an async context manager.

        Raises:
            ImproperlyConfiguredException: If the store was created with a sync config.
        """
        if not self._is_async:
            # This should ideally not be called if configured for sync,
            # but provides a safeguard.
            msg = "Store configured for synchronous operation."
            raise ImproperlyConfiguredException(msg)
        async with cast("AsyncSession", self._config.get_session()) as session:
            yield session

    @contextmanager
    def _get_sync_session(self) -> "Generator[Session, None, None]":
        """Yield a session from the config, used as a regular context manager.

        Raises:
            ImproperlyConfiguredException: If the store was created with an async config.
        """
        if self._is_async:
            msg = "Store configured for asynchronous operation."
            raise ImproperlyConfiguredException(msg)
        with cast("Session", self._config.get_session()) as session:
            yield session

    @staticmethod
    def supports_merge(dialect: Optional[Dialect] = None, force_disable_merge: bool = False) -> bool:
        """Report whether a MERGE-based upsert should be used for ``dialect``.

        True for Oracle, and for PostgreSQL at or above
        ``_POSTGRES_VERSION_SUPPORTING_MERGE`` unless ``_DISABLE_POSTGRES_MERGE``
        is set. Always False when no dialect is given or ``force_disable_merge``
        is true.
        """
        return bool(
            dialect
            and (
                (
                    dialect.server_version_info is not None
                    and dialect.server_version_info[0] >= _POSTGRES_VERSION_SUPPORTING_MERGE
                    and dialect.name == "postgresql"
                    and not _DISABLE_POSTGRES_MERGE  # Temporary PostgreSQL MERGE disable
                )
                or dialect.name == "oracle"
            )
            and not force_disable_merge
        )

    @staticmethod
    def supports_upsert(dialect: Optional[Dialect] = None, force_disable_upsert: bool = False) -> bool:
        """Report whether ``dialect`` is one of the names listed as upsert-capable.

        Always False when no dialect is given or ``force_disable_upsert`` is true.
        """
        return bool(
            dialect
            and (dialect.name in {"postgresql", "cockroachdb", "sqlite", "mysql", "mariadb", "duckdb"})
            and not force_disable_upsert
        )

    @staticmethod
    def _as_timedelta(duration: Union[int, datetime.timedelta]) -> datetime.timedelta:
        """Return ``duration`` unchanged if it is a ``timedelta``, else treat it as seconds."""
        if isinstance(duration, datetime.timedelta):
            return duration
        return datetime.timedelta(seconds=duration)

    @classmethod
    def _expires_at_from(
        cls, expires_in: Optional[Union[int, datetime.timedelta]]
    ) -> Optional[datetime.datetime]:
        """Return the UTC expiry for ``expires_in`` counted from now, or ``None`` if it is ``None``."""
        if expires_in is None:
            return None
        return datetime.datetime.now(datetime.timezone.utc) + cls._as_timedelta(expires_in)

    def _make_key(self, key: str) -> tuple[str, Optional[str]]:
        """Return the key and namespace tuple, handling potential None namespace."""
        return key, self.namespace

    def _decode_base64_value(self, value: Optional[bytes], dialect_name: str) -> Optional[bytes]:
        """Decode base64 encoded value from Spanner if needed.

        Spanner automatically base64-encodes binary data when storing it,
        so we need to decode it when retrieving.

        Args:
            value: The raw value read from the database, or ``None``.
            dialect_name: Name of the dialect the value was read with.

        Returns:
            ``value`` unchanged when it is ``None``, when the dialect name does not
            start with ``"spanner"``, or when decoding fails; otherwise the decoded bytes.
        """
        if value is None or not dialect_name.startswith("spanner"):
            return value
        try:
            return base64.b64decode(value)
        except Exception:  # noqa: BLE001
            # Deliberate best effort: fall back to the raw value.
            return value

    def _set_sync(
        self, key: str, value: Union[bytes, str], expires_in: Optional[Union[int, datetime.timedelta]] = None
    ) -> None:
        """Synchronous implementation for setting a value.

        ``str`` values are stored UTF-8 encoded. The write strategy is chosen per
        dialect: native upsert, then MERGE, then a select-then-update/insert fallback.

        Raises:
            ImproperlyConfiguredException: If the session has no bind.
        """
        db_key, db_namespace = self._make_key(key)
        serialized_value = value if isinstance(value, bytes) else value.encode("utf-8")
        expires_at = self._expires_at_from(expires_in)

        with self._get_sync_session() as session, session.begin():
            if session.bind is None:
                msg = "Database connection is not available"
                raise ImproperlyConfiguredException(msg)
            dialect = session.bind.dialect
            dialect_name = dialect.name

            values = {
                "key": db_key,
                "namespace": db_namespace,
                "value": serialized_value,
                "expires_at": expires_at,
            }
            conflict_columns = ["key", "namespace"]
            update_columns = ["value", "expires_at"]

            if OnConflictUpsert.supports_native_upsert(dialect_name):
                upsert_stmt = OnConflictUpsert.create_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=values,
                    conflict_columns=conflict_columns,
                    update_columns=update_columns,
                    dialect_name=dialect_name,
                    validate_identifiers=False,
                )
                session.execute(upsert_stmt)
            elif self.supports_merge(dialect):
                merge_stmt, additional_params = OnConflictUpsert.create_merge_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=values,
                    conflict_columns=conflict_columns,
                    update_columns=update_columns,
                    dialect_name=dialect_name,
                    validate_identifiers=False,
                )
                # Merge additional Oracle parameters with original values
                merge_values = {**values, **additional_params}
                session.execute(merge_stmt, merge_values)
            else:
                # Fallback logic: Check existence, then update or insert
                existing = session.execute(
                    select(1).where(self._model.key == db_key, self._model.namespace == db_namespace)
                ).scalar_one_or_none()
                if existing is not None:
                    session.execute(
                        update(self._model)
                        .where(self._model.key == db_key, self._model.namespace == db_namespace)
                        .values(value=serialized_value, expires_at=expires_at)
                    )
                else:
                    session.execute(
                        insert(self._model).values(
                            key=db_key, namespace=db_namespace, value=serialized_value, expires_at=expires_at
                        )
                    )
            session.commit()

    async def _set_async(
        self, key: str, value: Union[bytes, str], expires_in: Optional[Union[int, datetime.timedelta]] = None
    ) -> None:
        """Asynchronous implementation for setting a value.

        ``str`` values are stored UTF-8 encoded. The write strategy is chosen per
        dialect: native upsert, then MERGE, then a select-then-update/insert fallback.

        Raises:
            ImproperlyConfiguredException: If the session has no bind.
        """
        db_key, db_namespace = self._make_key(key)
        serialized_value = value if isinstance(value, bytes) else value.encode("utf-8")
        expires_at = self._expires_at_from(expires_in)

        async with self._get_async_session() as session, session.begin():
            if session.bind is None:  # pyright: ignore[reportUnnecessaryComparison]
                msg = "Database connection is not available"  # type: ignore[unreachable]
                raise ImproperlyConfiguredException(msg)
            dialect = session.bind.dialect
            dialect_name = dialect.name

            values = {
                "key": db_key,
                "namespace": db_namespace,
                "value": serialized_value,
                "expires_at": expires_at,
            }
            conflict_columns = ["key", "namespace"]
            update_columns = ["value", "expires_at"]

            if OnConflictUpsert.supports_native_upsert(dialect_name):
                upsert_stmt = OnConflictUpsert.create_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=values,
                    conflict_columns=conflict_columns,
                    update_columns=update_columns,
                    dialect_name=dialect_name,
                    validate_identifiers=False,
                )
                await session.execute(upsert_stmt)
            elif self.supports_merge(dialect):
                merge_stmt, additional_params = OnConflictUpsert.create_merge_upsert(
                    table=self._model.__table__,  # type: ignore[arg-type]
                    values=values,
                    conflict_columns=conflict_columns,
                    update_columns=update_columns,
                    dialect_name=dialect_name,
                    validate_identifiers=False,
                )
                # Merge additional Oracle parameters with original values
                merge_values = {**values, **additional_params}
                await session.execute(merge_stmt, merge_values)
            else:
                # Fallback logic: Check existence, then update or insert
                existing_id = (
                    await session.execute(
                        select(self._model.id).where(
                            self._model.key == db_key,
                            self._model.namespace == db_namespace,
                        )
                    )
                ).scalar_one_or_none()
                if existing_id is not None:
                    await session.execute(
                        update(self._model)
                        .where(self._model.id == existing_id)
                        .values(value=serialized_value, expires_at=expires_at)
                    )
                else:
                    await session.execute(
                        insert(self._model).values(
                            key=db_key, namespace=db_namespace, value=serialized_value, expires_at=expires_at
                        )
                    )
            await session.commit()

    async def set(
        self, key: str, value: Union[bytes, str], expires_in: Optional[Union[int, datetime.timedelta]] = None
    ) -> None:
        """Set a value. Handles both sync and async backends.

        Args:
            key: The key to store the value under.
            value: The value to store; ``str`` is UTF-8 encoded before storing.
            expires_in: Lifetime in seconds or as a ``timedelta``; ``None`` stores
                the entry without an expiry.
        """
        if self._is_async:
            await self._set_async(key, value, expires_in)
        else:
            await async_(self._set_sync)(key, value, expires_in)

    def _get_sync(self, key: str, renew_for: Optional[Union[int, datetime.timedelta]] = None) -> Optional[bytes]:
        """Synchronous implementation for reading (and optionally renewing) a value.

        Raises:
            ImproperlyConfiguredException: If the session has no bind.
        """
        db_key, db_namespace = self._make_key(key)
        now = datetime.datetime.now(datetime.timezone.utc)

        with self._get_sync_session() as session, session.begin():
            if session.bind is None:
                msg = "Database connection is not available"
                raise ImproperlyConfiguredException(msg)
            dialect_name = session.bind.dialect.name

            value = session.execute(
                select(self._model.value).where(
                    self._model.key == db_key,
                    self._model.namespace == db_namespace,
                    (self._model.expires_at.is_(None)) | (self._model.expires_at > now),
                )
            ).scalar_one_or_none()

            # Compare against None: an empty payload (b"") is a valid stored value.
            if value is not None:
                if renew_for:
                    # Only entries that already carry an expiry are renewed;
                    # entries stored without one keep living indefinitely.
                    session.execute(
                        update(self._model)
                        .where(
                            self._model.key == db_key,
                            self._model.namespace == db_namespace,
                            self._model.expires_at.is_not(None),
                        )
                        .values(expires_at=now + self._as_timedelta(renew_for))
                    )
                session.commit()
                return self._decode_base64_value(value, dialect_name)
            session.commit()  # Commit even if not found
            return None

    async def _get_async(self, key: str, renew_for: Optional[Union[int, datetime.timedelta]] = None) -> Optional[bytes]:
        """Asynchronous implementation for reading (and optionally renewing) a value.

        Raises:
            ImproperlyConfiguredException: If the session has no bind.
        """
        db_key, db_namespace = self._make_key(key)
        now = datetime.datetime.now(datetime.timezone.utc)

        async with self._get_async_session() as session, session.begin():
            if session.bind is None:  # pyright: ignore[reportUnnecessaryComparison]
                msg = "Database connection is not available"  # type: ignore[unreachable]
                raise ImproperlyConfiguredException(msg)
            dialect_name = session.bind.dialect.name

            value = (
                await session.execute(
                    select(self._model.value).where(
                        self._model.key == db_key,
                        self._model.namespace == db_namespace,
                        (self._model.expires_at.is_(None)) | (self._model.expires_at > now),
                    )
                )
            ).scalar_one_or_none()

            # Compare against None: an empty payload (b"") is a valid stored value.
            if value is not None:
                if renew_for:
                    # Only entries that already carry an expiry are renewed;
                    # entries stored without one keep living indefinitely.
                    await session.execute(
                        update(self._model)
                        .where(
                            self._model.key == db_key,
                            self._model.namespace == db_namespace,
                            self._model.expires_at.is_not(None),
                        )
                        .values(expires_at=now + self._as_timedelta(renew_for))
                    )
                await session.commit()
                return self._decode_base64_value(value, dialect_name)
            await session.commit()  # Commit even if not found
            return None

    async def get(self, key: str, renew_for: Optional[Union[int, datetime.timedelta]] = None) -> Optional[bytes]:
        """Get a value. Handles both sync and async backends.

        Args:
            key: The key to get the value for.
            renew_for: If given and the entry was stored with an expiry, the expiry
                is moved to ``renew_for`` (seconds or ``timedelta``) from now.
                Entries stored without an expiry are left untouched.

        Returns:
            The value of the key, or None if the key does not exist or has expired.
        """
        if self._is_async:
            return await self._get_async(key, renew_for)
        return await async_(self._get_sync)(key, renew_for)

    def _delete_sync(self, key: str) -> None:
        """Synchronous implementation for deleting ``key`` in the current namespace."""
        db_key, db_namespace = self._make_key(key)
        with self._get_sync_session() as session, session.begin():
            session.execute(delete(self._model).where(self._model.key == db_key, self._model.namespace == db_namespace))
            session.commit()

    async def _delete_async(self, key: str) -> None:
        """Asynchronous implementation for deleting ``key`` in the current namespace."""
        db_key, db_namespace = self._make_key(key)
        async with self._get_async_session() as session, session.begin():
            await session.execute(
                delete(self._model).where(self._model.key == db_key, self._model.namespace == db_namespace)
            )
            await session.commit()

    async def delete(self, key: str) -> None:
        """Delete a value. Handles both sync and async backends.

        Args:
            key: The key to delete.
        """
        if self._is_async:
            await self._delete_async(key)
        else:
            await async_(self._delete_sync)(key)

    def _delete_all_sync(self) -> None:
        """Synchronous implementation for deleting every entry in the namespace.

        Raises:
            ImproperlyConfiguredException: If no namespace is configured.
        """
        if self.namespace is None:
            msg = "Cannot perform delete operation: No namespace configured"
            raise ImproperlyConfiguredException(msg)
        db_namespace = self.namespace
        with self._get_sync_session() as session, session.begin():
            session.execute(delete(self._model).where(self._model.namespace == db_namespace))
            session.commit()

    async def _delete_all_async(self) -> None:
        """Asynchronous implementation for deleting every entry in the namespace.

        Raises:
            ImproperlyConfiguredException: If no namespace is configured.
        """
        if self.namespace is None:
            msg = "Cannot perform delete operation: No namespace configured"
            raise ImproperlyConfiguredException(msg)
        db_namespace = self.namespace
        async with self._get_async_session() as session, session.begin():
            await session.execute(delete(self._model).where(self._model.namespace == db_namespace))
            await session.commit()

    async def delete_all(self) -> None:
        """Delete all values in the namespace. Handles both sync and async backends.

        Raises:
            ImproperlyConfiguredException: If the store has no namespace configured.
        """
        if self._is_async:
            await self._delete_all_async()
        else:
            await async_(self._delete_all_sync)()

    def _exists_sync(self, key: str) -> bool:
        """Synchronous implementation: True if ``key`` has a non-expired entry."""
        db_key, db_namespace = self._make_key(key)
        now = datetime.datetime.now(datetime.timezone.utc)
        with self._get_sync_session() as session:
            # Only existence matters, so select a single column and stop at the first row.
            stmt = (
                select(self._model.key)
                .where(
                    self._model.key == db_key,
                    self._model.namespace == db_namespace,
                    (self._model.expires_at.is_(None)) | (self._model.expires_at > now),
                )
                .limit(1)
            )  # limit 1 is important for performance
            return session.execute(stmt).scalar_one_or_none() is not None

    async def _exists_async(self, key: str) -> bool:
        """Asynchronous implementation: True if ``key`` has a non-expired entry."""
        db_key, db_namespace = self._make_key(key)
        now = datetime.datetime.now(datetime.timezone.utc)
        async with self._get_async_session() as session:
            # Only existence matters, so select a single column and stop at the first row.
            stmt = (
                select(self._model.key)
                .where(
                    self._model.key == db_key,
                    self._model.namespace == db_namespace,
                    (self._model.expires_at.is_(None)) | (self._model.expires_at > now),
                )
                .limit(1)
            )
            result = await session.execute(stmt)
            return result.scalar_one_or_none() is not None

    async def exists(self, key: str) -> bool:
        """Check if a key exists. Handles both sync and async backends.

        Args:
            key: The key to check if it exists.

        Returns:
            True if the key exists, False otherwise.
        """
        if self._is_async:
            return await self._exists_async(key)
        return await async_(self._exists_sync)(key)

    def _expires_in_sync(self, key: str) -> Optional[int]:
        """Synchronous implementation: whole seconds until ``key`` expires, or ``None``."""
        db_key, db_namespace = self._make_key(key)
        now = datetime.datetime.now(datetime.timezone.utc)
        with self._get_sync_session() as session:
            stmt = select(self._model.expires_at).where(
                self._model.key == db_key,
                self._model.namespace == db_namespace,
                self._model.expires_at.is_not(None),  # Explicitly check for non-null expiry
                self._model.expires_at > now,
            )
            expires_at = session.execute(stmt).scalar_one_or_none()
            if expires_at:
                return int((expires_at - now).total_seconds())
            return None

    async def _expires_in_async(self, key: str) -> Optional[int]:
        """Asynchronous implementation: whole seconds until ``key`` expires, or ``None``."""
        db_key, db_namespace = self._make_key(key)
        now = datetime.datetime.now(datetime.timezone.utc)
        async with self._get_async_session() as session:
            stmt = select(self._model.expires_at).where(
                self._model.key == db_key,
                self._model.namespace == db_namespace,
                self._model.expires_at.is_not(None),  # Explicitly check for non-null expiry
                self._model.expires_at > now,
            )
            result = await session.execute(stmt)
            expires_at = result.scalar_one_or_none()
            if expires_at:
                return int((expires_at - now).total_seconds())
            return None

    async def expires_in(self, key: str) -> Optional[int]:
        """Get expiration time. Handles both sync and async backends.

        Args:
            key: The key to get the expiration time for.

        Returns:
            The expiration time in seconds, or None if the key does not exist or has no expiration time.
        """
        if self._is_async:
            return await self._expires_in_async(key)
        return await async_(self._expires_in_sync)(key)

    def with_namespace(self, namespace: str) -> "SQLAlchemyStore[SQLAlchemyConfigT]":
        """Return a new :class:`SQLAlchemyStore` with a nested virtual key namespace.

        The new namespace is ``"<current>_<namespace>"``, or just ``namespace``
        when the current namespace is empty or ``None``. The new instance shares
        this store's config and model.
        """
        new_namespace = f"{self.namespace}_{namespace}" if self.namespace else namespace
        # We pass the original config, model type to the new instance
        return type(self)(
            config=self._config,  # Pass the original config
            model=self._model,
            namespace=new_namespace,
        )

    async def __aexit__(
        self,
        exc_type: object,
        exc_val: object,
        exc_tb: object,
    ) -> None:
        """Do nothing on context exit."""
        # Session lifecycle is managed by the context managers (_get_sync_session/_get_async_session)
        # or externally via the provided config's lifespan manager. No explicit cleanup needed here.