"""Litestar serialization plugin and type encoder/decoder registration for Advanced Alchemy (``advanced_alchemy.extensions.litestar.plugins.serialization``)."""

import contextlib
from typing import TYPE_CHECKING, Any

from litestar.plugins import InitPluginProtocol, SerializationPlugin
from litestar.typing import FieldDefinition
from sqlalchemy.orm import DeclarativeBase

from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO
from advanced_alchemy.extensions.litestar.plugins import _slots_base
from advanced_alchemy.utils.serialization import DEFAULT_TYPE_ENCODERS

if TYPE_CHECKING:
    from collections.abc import Callable

    from litestar.config.app import AppConfig


def _get_aa_type_encoders() -> "dict[type, Callable[[Any], Any]]":
    """Build the full set of Advanced Alchemy type encoders for Litestar.

    The shared ``DEFAULT_TYPE_ENCODERS`` form the base layer. The
    Litestar-specific database UUID encoders (asyncpg's ``pgproto.UUID`` and
    ``uuid_utils.UUID``, when installed) are layered on top. The result is
    merged into ``AppConfig.type_encoders`` beneath any user-supplied encoders.

    Returns:
        A fresh dict mapping types to their encoder callables.
    """
    # Later unpacking wins on key collisions, so the Litestar-specific entries
    # override any shared default registered for the same type.
    return {**DEFAULT_TYPE_ENCODERS, **_get_aa_litestar_type_encoders()}


def _get_aa_litestar_type_encoders() -> "dict[type, Callable[[Any], Any]]":
    """Return Litestar-only compatibility encoders for database UUID types."""
    encoders: dict[type, Callable[[Any], Any]] = {}

    with contextlib.suppress(ImportError):
        from asyncpg.pgproto import pgproto  # pyright: ignore[reportMissingImports]

        encoders[pgproto.UUID] = str

    with contextlib.suppress(ImportError):
        import uuid_utils  # pyright: ignore[reportMissingImports]

        encoders[uuid_utils.UUID] = str  # pyright: ignore[reportUnknownMemberType]

    return encoders


def _is_uuid_utils_uuid_type(value: Any) -> bool:
    with contextlib.suppress(ImportError):
        import uuid_utils  # pyright: ignore[reportMissingImports]

        return value is uuid_utils.UUID  # pyright: ignore[reportUnknownMemberType]
    return False


def _decode_uuid_utils_uuid(target_type: type, value: Any) -> Any:
    return target_type(str(value))


def _get_aa_type_decoders() -> "list[tuple[Callable[[Any], bool], Callable[[type, Any], Any]]]":
    """Return Advanced Alchemy's built-in Litestar type decoders.

    Currently covers ``uuid_utils.UUID`` for request-side parsing.  Decoders
    are merged into ``AppConfig.type_decoders`` with lower precedence than
    user-supplied decoders.
    """
    decoders: list[tuple[Callable[[Any], bool], Callable[[type, Any], Any]]] = []

    with contextlib.suppress(ImportError):
        import uuid_utils  # pyright: ignore[reportMissingImports]

        if uuid_utils is not None:
            decoders.append((_is_uuid_utils_uuid_type, _decode_uuid_utils_uuid))

    return decoders


def merge_aa_litestar_type_encoders(app_config: "AppConfig", *, include_default_encoders: bool) -> None:
    """Merge AA's Litestar encoders/decoders into app config.

    User-supplied encoders and decoders keep precedence.  ``include_default_encoders``
    is enabled by ``SQLAlchemySerializationPlugin``; ``SQLAlchemyInitPlugin`` uses
    the UUID-only path to preserve its released direct-registration behavior.

    Args:
        app_config: The Litestar application config. Its ``type_encoders`` and
            ``type_decoders`` attributes are replaced with merged copies.
        include_default_encoders: When true, register the shared default
            encoders plus the UUID encoders; otherwise only the UUID encoders.
    """
    if include_default_encoders:
        builtin_encoders = _get_aa_type_encoders()
    else:
        builtin_encoders = _get_aa_litestar_type_encoders()
    builtin_decoders = _get_aa_type_decoders()

    # Start from AA's encoders and let anything the user configured overwrite them.
    combined_encoders = dict(builtin_encoders)
    combined_encoders.update(app_config.type_encoders or {})
    app_config.type_encoders = combined_encoders

    # User decoders stay at the front. AA decoders are appended only if not
    # already present, so repeated calls do not add duplicate entries.
    combined_decoders = list(app_config.type_decoders or [])
    for builtin_decoder in builtin_decoders:
        if builtin_decoder in combined_decoders:
            continue
        combined_decoders.append(builtin_decoder)
    app_config.type_decoders = combined_decoders


class SQLAlchemySerializationPlugin(SerializationPlugin, InitPluginProtocol, _slots_base.SlotsBase):
    """Litestar plugin that serializes SQLAlchemy declarative models through ``SQLAlchemyDTO``.

    It also acts as an init plugin: ``on_app_init`` registers Advanced
    Alchemy's built-in type encoders and decoders on the application config.
    """

    def __init__(self) -> None:
        # Maps each model annotation to the DTO class generated for it, so
        # repeated requests for the same model reuse one DTO type.
        self._type_dto_map: dict[type[DeclarativeBase], type[SQLAlchemyDTO[Any]]] = {}

    def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
        """Register Advanced Alchemy's built-in type encoders and decoders.

        AA encoders/decoders are added with lower precedence so user-supplied
        ``type_encoders`` / ``type_decoders`` on the application config win.

        Args:
            app_config: The application config to update in place.

        Returns:
            The same ``app_config`` instance.
        """
        merge_aa_litestar_type_encoders(app_config, include_default_encoders=True)
        return app_config

    def supports_type(self, field_definition: FieldDefinition) -> bool:
        """Return whether this plugin handles *field_definition*.

        A field is supported if it is a ``DeclarativeBase`` subclass, or a
        collection with an inner type that is a ``DeclarativeBase`` subclass.
        """
        return (
            field_definition.is_collection and field_definition.has_inner_subclass_of(DeclarativeBase)
        ) or field_definition.is_subclass_of(DeclarativeBase)

    def create_dto_for_type(self, field_definition: FieldDefinition) -> type[SQLAlchemyDTO[Any]]:
        """Return the (cached) ``SQLAlchemyDTO`` type for *field_definition*.

        The model is the first inner type that subclasses ``DeclarativeBase``.
        If there is none, the field's own annotation is used.

        Args:
            field_definition: The field whose type should be serialized.

        Returns:
            A ``SQLAlchemyDTO`` specialized for the resolved model annotation.
        """
        # assumes that the type is a container of SQLAlchemy models or a single SQLAlchemy model
        annotation = next(
            (
                inner_type.annotation
                for inner_type in field_definition.inner_types
                if inner_type.is_subclass_of(DeclarativeBase)
            ),
            field_definition.annotation,
        )
        if annotation in self._type_dto_map:
            return self._type_dto_map[annotation]

        self._type_dto_map[annotation] = dto_type = SQLAlchemyDTO[annotation]  # type:ignore[valid-type]
        return dto_type