# Source code for advanced_alchemy.routing.maker

"""Session maker factories for read/write routing.

This module provides session maker classes that create routing-aware sessions
with properly configured primary and replica engines.
"""

from typing import Any, Callable, Optional

from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from advanced_alchemy.config.routing import RoutingConfig, RoutingStrategy
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.routing.selectors import EngineSelector, RandomSelector, RoundRobinSelector
from advanced_alchemy.routing.session import RoutingAsyncSession, RoutingSyncSession

# Public names exported by this module: the sync and async routing session makers.
__all__ = ("RoutingAsyncSessionMaker", "RoutingSyncSessionMaker")


class RoutingSyncSessionMaker:
    """Factory for creating sync routing sessions.

    This class creates :class:`RoutingSyncSession` instances with properly
    configured engines and routing selectors.

    Example:
        Creating a routing session maker::

            maker = RoutingSyncSessionMaker(
                routing_config=RoutingConfig(
                    engines={
                        "writer": ["postgresql://primary"],
                        "reader": ["postgresql://replica1"],
                    }
                ),
                engine_config={"pool_size": 10},
            )
            session = maker()
    """

    __slots__ = (
        "_default_engine",
        "_engine_config",
        "_engines",
        "_routing_config",
        "_selectors",
        "_session_config",
    )

    def __init__(
        self,
        routing_config: RoutingConfig,
        engine_config: Optional[dict[str, Any]] = None,
        session_config: Optional[dict[str, Any]] = None,
        create_engine_callable: Callable[..., Engine] = create_engine,
    ) -> None:
        """Initialize the session maker.

        One engine is created per configured connection string. One selector
        is created per group that ends up with at least one engine.

        Args:
            routing_config: Configuration for read/write routing.
            engine_config: Keyword options passed to ``create_engine_callable``
                for every engine.
            session_config: Keyword options passed to every session created by
                :meth:`__call__` (any ``bind`` entry is ignored there).
            create_engine_callable: Callable used to create engines, invoked as
                ``create_engine_callable(connection_string, **engine_config)``
                (mainly useful for testing).

        Raises:
            ImproperConfigurationError: If the routing config's default group
                has no engines after initialization.
        """
        self._routing_config = routing_config
        self._engine_config = engine_config or {}
        self._session_config = session_config or {}
        self._engines: dict[str, list[Engine]] = {}
        self._selectors: dict[str, EngineSelector[Engine]] = {}

        # Initialize engines and selectors for every group that has connections;
        # groups with no engine configs get neither an engine list nor a selector.
        for group in routing_config.engines:
            engines_for_group: list[Engine] = [
                self._create_engine(config.connection_string, create_engine_callable)
                for config in routing_config.get_engine_configs(group)
            ]
            if engines_for_group:
                self._engines[group] = engines_for_group
                self._selectors[group] = self._create_selector(
                    engines_for_group,
                    routing_config.routing_strategy,
                )

        # The default group must have at least one engine. A legacy
        # ``primary_connection_string`` is expected to have already been mapped
        # into ``engines`` by the config, so only ``engines`` is checked here;
        # otherwise the lookup below would fail with a bare KeyError.
        default_group = routing_config.default_group
        default_engines = self._engines.get(default_group)
        if not default_engines:
            msg = (
                f"Default group '{default_group}' has no engines configured. "
                "Ensure 'engines' contains this group or 'primary_connection_string' is set."
            )
            raise ImproperConfigurationError(msg)
        self._default_engine = default_engines[0]

    def _create_engine(
        self,
        connection_string: str,
        create_engine_callable: Callable[..., Engine],
    ) -> Engine:
        """Create an engine with the configured options.

        If the factory raises :class:`TypeError` and the engine config contains
        ``json_serializer`` and/or ``json_deserializer``, creation is retried
        once without those two options (presumably because some factories or
        dialects reject them).

        Args:
            connection_string: Database connection string.
            create_engine_callable: Callable to create the engine.

        Returns:
            The created engine.

        Raises:
            TypeError: If the factory rejects the options and there are no JSON
                options to drop, or if the retry also fails.
        """
        try:
            return create_engine_callable(connection_string, **self._engine_config)
        except TypeError:
            json_keys = ("json_deserializer", "json_serializer")
            if not any(key in self._engine_config for key in json_keys):
                # Nothing to drop: a retry would repeat the exact same call.
                raise
            config = {key: value for key, value in self._engine_config.items() if key not in json_keys}
            return create_engine_callable(connection_string, **config)

    def _create_selector(
        self,
        engines: list[Engine],
        strategy: RoutingStrategy,
    ) -> EngineSelector[Engine]:
        """Create an engine selector for the given strategy.

        Args:
            engines: Engines belonging to one routing group.
            strategy: The routing strategy to use.

        Returns:
            A :class:`RandomSelector` for ``RoutingStrategy.RANDOM``; a
            :class:`RoundRobinSelector` for any other strategy.
        """
        if strategy == RoutingStrategy.RANDOM:
            return RandomSelector(engines)
        return RoundRobinSelector(engines)

    def __call__(self) -> RoutingSyncSession:
        """Create a new routing session.

        Any ``bind`` passed in the session config is ignored because routing
        controls bind selection.

        Returns:
            A new :class:`RoutingSyncSession` instance.
        """
        session_config = self._session_config.copy()
        session_config.pop("bind", None)
        return RoutingSyncSession(
            routing_config=self._routing_config,
            selectors=self._selectors,
            default_engine=self._default_engine,
            **session_config,
        )

    @property
    def primary_engine(self) -> Engine:
        """Get the primary (default) engine.

        Returns:
            The first engine of the routing config's default group.
        """
        return self._default_engine

    @property
    def replica_engines(self) -> list[Engine]:
        """Get the replica engines (from read_group).

        Returns:
            A new list of the read group's engines, or an empty list if that
            group has no engines. A copy is returned so callers cannot mutate
            the list shared with the routing selector.
        """
        return list(self._engines.get(self._routing_config.read_group, ()))

    def close_all(self) -> None:
        """Close all engines and release connections.

        Call this when shutting down to properly release database connections.
        """
        for engine_list in self._engines.values():
            for engine in engine_list:
                engine.dispose()
class RoutingAsyncSessionMaker:
    """Factory for creating async routing sessions.

    This class creates :class:`RoutingAsyncSession` instances with properly
    configured async engines and routing selectors.

    Example:
        Creating an async routing session maker::

            maker = RoutingAsyncSessionMaker(
                routing_config=RoutingConfig(
                    engines={
                        "writer": ["postgresql+asyncpg://primary"],
                        "reader": ["postgresql+asyncpg://replica1"],
                    }
                ),
                engine_config={"pool_size": 10},
            )
            async with maker() as session:
                result = await session.execute(select(User))
    """

    __slots__ = (
        "_default_engine",
        "_engine_config",
        "_engines",
        "_routing_config",
        "_selectors",
        "_session_config",
    )

    def __init__(
        self,
        routing_config: RoutingConfig,
        engine_config: Optional[dict[str, Any]] = None,
        session_config: Optional[dict[str, Any]] = None,
        create_engine_callable: Callable[..., AsyncEngine] = create_async_engine,
    ) -> None:
        """Initialize the async session maker.

        One async engine is created per configured connection string. One
        selector is created per group that ends up with at least one engine.

        Args:
            routing_config: Configuration for read/write routing.
            engine_config: Keyword options passed to ``create_engine_callable``
                for every engine.
            session_config: Keyword options passed to every session created by
                :meth:`__call__` (any ``bind`` entry is ignored there).
            create_engine_callable: Callable used to create async engines,
                invoked as ``create_engine_callable(connection_string, **engine_config)``
                (mainly useful for testing).

        Raises:
            ImproperConfigurationError: If the routing config's default group
                has no engines after initialization.
        """
        self._routing_config = routing_config
        self._engine_config = engine_config or {}
        self._session_config = session_config or {}
        self._engines: dict[str, list[AsyncEngine]] = {}
        self._selectors: dict[str, EngineSelector[AsyncEngine]] = {}

        # Initialize engines and selectors for every group that has connections;
        # groups with no engine configs get neither an engine list nor a selector.
        for group in routing_config.engines:
            engines_for_group: list[AsyncEngine] = [
                self._create_engine(config.connection_string, create_engine_callable)
                for config in routing_config.get_engine_configs(group)
            ]
            if engines_for_group:
                self._engines[group] = engines_for_group
                self._selectors[group] = self._create_selector(
                    engines_for_group,
                    routing_config.routing_strategy,
                )

        # The default group must have at least one engine; only non-empty
        # groups are stored above, so a missing key covers both cases.
        default_group = routing_config.default_group
        default_engines = self._engines.get(default_group)
        if not default_engines:
            msg = (
                f"Default group '{default_group}' has no engines configured. "
                "Ensure 'engines' contains this group or 'primary_connection_string' is set."
            )
            raise ImproperConfigurationError(msg)
        self._default_engine = default_engines[0]

    def _create_engine(
        self,
        connection_string: str,
        create_engine_callable: Callable[..., AsyncEngine],
    ) -> AsyncEngine:
        """Create an async engine with the configured options.

        If the factory raises :class:`TypeError` and the engine config contains
        ``json_serializer`` and/or ``json_deserializer``, creation is retried
        once without those two options (presumably because some factories or
        dialects reject them).

        Args:
            connection_string: Database connection string.
            create_engine_callable: Callable to create the engine.

        Returns:
            The created async engine.

        Raises:
            TypeError: If the factory rejects the options and there are no JSON
                options to drop, or if the retry also fails.
        """
        try:
            return create_engine_callable(connection_string, **self._engine_config)
        except TypeError:
            json_keys = ("json_deserializer", "json_serializer")
            if not any(key in self._engine_config for key in json_keys):
                # Nothing to drop: a retry would repeat the exact same call.
                raise
            config = {key: value for key, value in self._engine_config.items() if key not in json_keys}
            return create_engine_callable(connection_string, **config)

    def _create_selector(
        self,
        engines: list[AsyncEngine],
        strategy: RoutingStrategy,
    ) -> EngineSelector[AsyncEngine]:
        """Create an engine selector for the given strategy.

        Args:
            engines: Async engines belonging to one routing group (any group,
                not only replicas).
            strategy: The routing strategy to use.

        Returns:
            A :class:`RandomSelector` for ``RoutingStrategy.RANDOM``; a
            :class:`RoundRobinSelector` for any other strategy.
        """
        if strategy == RoutingStrategy.RANDOM:
            return RandomSelector(engines)
        return RoundRobinSelector(engines)

    def __call__(self) -> RoutingAsyncSession:
        """Create a new async routing session.

        Any ``bind`` passed in the session config is ignored because routing
        controls bind selection.

        Returns:
            A new :class:`RoutingAsyncSession` instance.
        """
        session_config = self._session_config.copy()
        session_config.pop("bind", None)
        return RoutingAsyncSession(
            routing_config=self._routing_config,
            selectors=self._selectors,
            default_engine=self._default_engine,
            **session_config,
        )

    @property
    def primary_engine(self) -> AsyncEngine:
        """Get the primary (default) async engine.

        Returns:
            The first async engine of the routing config's default group.
        """
        return self._default_engine

    @property
    def replica_engines(self) -> list[AsyncEngine]:
        """Get the replica async engines (from read_group).

        Returns:
            A new list of the read group's async engines, or an empty list if
            that group has no engines. A copy is returned so callers cannot
            mutate the list shared with the routing selector.
        """
        return list(self._engines.get(self._routing_config.read_group, ()))

    async def close_all(self) -> None:
        """Close all engines and release connections.

        Call this when shutting down to properly release database connections.
        Engines are disposed one after another.
        """
        for engine_list in self._engines.values():
            for engine in engine_list:
                await engine.dispose()