# Source code for advanced_alchemy.config.asyncio

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union, cast

from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

from advanced_alchemy._listeners import set_async_context
from advanced_alchemy.config.common import (
    GenericAlembicConfig,
    GenericSessionConfig,
    GenericSQLAlchemyConfig,
)
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.utils.dataclass import Empty

if TYPE_CHECKING:
    from typing import Callable

    from sqlalchemy.orm import Session

    from advanced_alchemy.config.routing import RoutingConfig
    from advanced_alchemy.utils.dataclass import EmptyType

# Public names exported by this module.
__all__ = ("AlembicAsyncConfig", "AsyncSessionConfig", "SQLAlchemyAsyncConfig")


@dataclass
class AsyncSessionConfig(GenericSessionConfig[AsyncConnection, AsyncEngine, AsyncSession]):
    """Session options for SQLAlchemy's asyncio extension.

    Extends the generic session configuration with the settings specific to
    :class:`AsyncSession <sqlalchemy.ext.asyncio.AsyncSession>`.
    """

    sync_session_class: "Optional[Union[type[Session], EmptyType]]" = Empty
    """Optional :class:`Session <sqlalchemy.orm.Session>` subclass (or other callable) used to build the synchronous
    :class:`Session <sqlalchemy.orm.Session>` that the async session proxies.

    Set this to plug in a custom :class:`Session <sqlalchemy.orm.Session>` subclass. When left unset, the
    :attr:`AsyncSession.sync_session_class <sqlalchemy.ext.asyncio.AsyncSession.sync_session_class>` class-level
    attribute applies."""
@dataclass
class AlembicAsyncConfig(GenericAlembicConfig):
    """Settings for Alembic's ``Config`` object when paired with an async SQLAlchemy configuration.

    This class adds no options of its own; everything is inherited from the generic Alembic configuration.

    .. seealso::
        https://alembic.sqlalchemy.org/en/latest/api/config.html
    """
@dataclass
class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, async_sessionmaker[AsyncSession]]):
    """SQLAlchemy configuration for asyncio engines and sessions.

    Note:
        The Alembic-related options are described in the Alembic documentation.

    Example:
        A minimal async configuration::

            config = SQLAlchemyAsyncConfig(
                connection_string="postgresql+asyncpg://user:pass@localhost/db",
            )

        Splitting reads and writes between a primary and its replicas::

            from advanced_alchemy.config.routing import RoutingConfig

            config = SQLAlchemyAsyncConfig(
                routing_config=RoutingConfig(
                    primary_connection_string="postgresql+asyncpg://user:pass@primary/db",
                    read_replicas=[
                        "postgresql+asyncpg://user:pass@replica/db"
                    ],
                ),
            )
    """

    create_engine_callable: "Callable[[str], AsyncEngine]" = create_async_engine
    """Factory that builds an :class:`AsyncEngine <sqlalchemy.ext.asyncio.AsyncEngine>` (or an instance of a subclass).
    """
    session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig)  # pyright: ignore[reportIncompatibleVariableOverride]
    """Options forwarded to the :class:`async_sessionmaker<sqlalchemy.ext.asyncio.async_sessionmaker>`."""
    session_maker_class: "type[async_sessionmaker[AsyncSession]]" = async_sessionmaker  # pyright: ignore[reportIncompatibleVariableOverride]
    """Class used to build the session factory."""
    alembic_config: "AlembicAsyncConfig" = field(default_factory=AlembicAsyncConfig)
    """Settings for SQLAlchemy Alembic migrations.

    The individual options are described in the Alembic documentation.
    """
    routing_config: "Optional[RoutingConfig]" = None
    """Optional read/write routing setup.

    When provided, read operations are routed to replicas and write operations to the primary database.

    .. note::
        Leave ``connection_string`` unset when using ``routing_config``; the primary
        connection comes from the routing config instead.
    """

    def __post_init__(self) -> None:
        """Reconcile ``routing_config`` with ``connection_string``, then run the base initialisation.

        Raises:
            ImproperConfigurationError: If both ``connection_string`` and ``routing_config`` are given.
        """
        routing = self.routing_config
        if routing is not None:
            # The routing config owns the primary URL, so an explicit connection
            # string next to it would be ambiguous.
            if self.connection_string is not None:
                msg = "Provide either 'connection_string' or 'routing_config', not both"
                raise ImproperConfigurationError(msg)
            # Mirror the primary URL into connection_string so code that only
            # knows about connection_string keeps working.
            self.connection_string = routing.primary_connection_string
            if self.connection_string is None:
                # No explicit primary: fall back to the first engine of the default group.
                default_engines = routing.get_engine_configs(routing.default_group)
                if default_engines:
                    self.connection_string = default_engines[0].connection_string
        super().__post_init__()

    # Both dunders are declared explicitly: @dataclass does not overwrite an
    # explicitly defined __eq__/__hash__. Without them it would generate __eq__
    # and set __hash__ to None for this non-frozen class. Delegating keeps the
    # base-class semantics.
    def __hash__(self) -> int:
        return super().__hash__()

    def __eq__(self, other: object) -> bool:
        return super().__eq__(other)

    def create_session_maker(self) -> "Callable[[], AsyncSession]":
        """Build (once) and return the session factory.

        With ``routing_config`` set, the factory is a routing-aware session maker.
        Otherwise, it is the factory produced by the base configuration. When that
        factory is an ``async_sessionmaker``, the file-object, updated-timestamp
        (each subject to its enable flag) and cache listeners are attached to it.
        Later calls return the cached factory.

        Returns:
            A callable producing session instances.

        Raises:
            ImproperConfigurationError: If no session maker ended up assigned.
        """
        if self.session_maker:
            return self.session_maker

        from sqlalchemy import event

        from advanced_alchemy._listeners import (
            AsyncCacheListener,
            AsyncFileObjectListener,
            touch_updated_timestamp,
        )

        if self.routing_config is None:
            self.session_maker = cast("Callable[[], AsyncSession]", super().create_session_maker())  # type: ignore[redundant-cast]
        else:
            from advanced_alchemy.routing import RoutingAsyncSessionMaker

            router: Callable[[], AsyncSession] = RoutingAsyncSessionMaker(
                routing_config=self.routing_config,
                engine_config=self.engine_config_dict,
                session_config=self.session_config_dict,
            )
            self.session_maker = router

        if isinstance(self.session_maker, async_sessionmaker):
            target = cast(
                "async_sessionmaker[AsyncSession]",
                self.session_maker,  # pyright: ignore[reportUnknownMemberType]
            )
            # Ordered (event, handler) pairs. Registration order matters because
            # several handlers share the same event.
            hooks: "list[tuple[str, Callable[..., object]]]" = []
            if self.enable_file_object_listener:
                hooks.extend(
                    (
                        ("before_flush", AsyncFileObjectListener.before_flush),
                        ("after_commit", AsyncFileObjectListener.after_commit),
                        ("after_rollback", AsyncFileObjectListener.after_rollback),
                    )
                )
            if self.enable_touch_updated_timestamp_listener:
                hooks.append(("before_flush", touch_updated_timestamp))
            hooks.extend(
                (
                    ("after_commit", AsyncCacheListener.after_commit),
                    ("after_rollback", AsyncCacheListener.after_rollback),
                )
            )
            for event_name, handler in hooks:
                event.listen(target, event_name, handler)

        if self.session_maker is None:  # pyright: ignore
            msg = "Session maker was not initialized."  # type: ignore[unreachable]
            raise ImproperConfigurationError(msg)
        return cast("async_sessionmaker[AsyncSession]", self.session_maker)  # pyright: ignore[reportUnknownMemberType]

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Open a session from :meth:`create_session_maker` for the duration of an ``async with`` block.

        Yields:
            AsyncSession: The session; its own ``async with`` protocol manages its lifetime.
        """
        factory = self.create_session_maker()
        # NOTE(review): presumably flags an async context for advanced_alchemy._listeners — confirm.
        set_async_context(True)
        async with factory() as session:
            yield session