Source code for advanced_alchemy.routing.session

"""Routing-aware session classes for read/write routing.

This module provides custom SQLAlchemy session classes that implement
read/write routing via the ``get_bind()`` method.
"""

from typing import TYPE_CHECKING, Any, Optional, Union

from sqlalchemy import Delete, Insert, Update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from advanced_alchemy.routing.context import (
    bind_group_var,
    force_primary_var,
    reset_routing_context,
    set_sticky_primary,
    stick_to_primary_var,
)

if TYPE_CHECKING:
    from sqlalchemy import Engine
    from sqlalchemy.ext.asyncio import AsyncEngine
    from sqlalchemy.orm import Mapper

    from advanced_alchemy.config.routing import RoutingConfig
    from advanced_alchemy.routing.selectors import EngineSelector


# Public API of this module: the sync and async routing session classes.
__all__ = ("RoutingAsyncSession", "RoutingSyncSession")


class RoutingSyncSession(Session):
    """Synchronous session with read/write routing via ``get_bind()``.

    This session class extends SQLAlchemy's :class:`Session` to provide
    automatic routing of operations to different engine groups
    (e.g. writer/reader).

    The routing decision is made in ``get_bind()`` in this order:

    0. An explicit ``bind`` forwarded from ``bind_arguments`` is honored as-is
    1. Statement execution options (``bind_group``)
    2. The ``bind_group`` context variable
    3. Routing disabled, the ``force_primary`` context flag, or sticky-primary
       state -> default group
    4. Operation type: flushes, INSERT/UPDATE/DELETE and ``FOR UPDATE``
       selects -> default group; everything else -> read group

    Attributes:
        _default_engine: The default (write) database engine.
        _selectors: Map of group names to engine selectors.
        _routing_config: Configuration for routing behavior.
    """

    _default_engine: "Engine"
    _selectors: "dict[str, EngineSelector[Engine]]"
    _routing_config: "RoutingConfig"

    def __init__(
        self,
        routing_config: "RoutingConfig",
        selectors: "dict[str, EngineSelector[Engine]]",
        default_engine: "Engine",
        **kwargs: Any,
    ) -> None:
        """Initialize the routing session.

        Args:
            routing_config: Configuration for routing behavior.
            selectors: Map of group names to engine selectors.
            default_engine: The default (fallback/write) engine.
            **kwargs: Additional arguments passed to the parent Session.
                Any ``bind``/``binds`` entries are discarded because engine
                selection is done by ``get_bind()`` instead.
        """
        kwargs.pop("bind", None)
        kwargs.pop("binds", None)
        super().__init__(**kwargs)
        self._default_engine = default_engine
        self._selectors = selectors
        self._routing_config = routing_config

    def get_bind(
        self,
        mapper: Optional[Union["Mapper[Any]", type[Any]]] = None,
        clause: Optional[Any] = None,
        **kwargs: Any,
    ) -> "Engine":
        """Route to appropriate engine based on operation and context.

        Args:
            mapper: Optional mapper for the operation.
            clause: The SQL clause being executed.
            **kwargs: Additional keyword arguments forwarded by SQLAlchemy
                (e.g. from ``bind_arguments``). A non-``None`` ``bind`` entry
                is returned unchanged, matching the base ``Session.get_bind``.

        Returns:
            The selected engine (or the explicitly supplied bind).
        """
        # 0. An explicit bind (session.execute(..., bind_arguments={"bind": x}))
        #    must win, exactly as in the base Session.get_bind; otherwise the
        #    caller's choice would be silently re-routed.
        explicit_bind = kwargs.get("bind")
        if explicit_bind is not None:
            return explicit_bind

        # 1. Check for explicit bind group in execution options
        if clause is not None and hasattr(clause, "_execution_options"):
            bind_group = clause._execution_options.get("bind_group")  # noqa: SLF001
            if bind_group:
                return self._get_engine_for_group(bind_group)

        # 2. Check context variable for bind group
        bind_group = bind_group_var.get()
        if bind_group:
            return self._get_engine_for_group(bind_group)

        # 3. Check if we should force/stick to default (writer)
        if self._should_use_default_group(clause):
            return self._get_engine_for_group(self._routing_config.default_group)

        # 4. Read operation -> use read group
        return self._get_engine_for_group(self._routing_config.read_group)

    def _get_engine_for_group(self, group: str) -> "Engine":
        """Get an engine for the specified group.

        Args:
            group: Name of the engine group.

        Returns:
            The next engine from the group's selector, or the default engine
            if the group has no selector or its selector has no engines.
        """
        selector = self._selectors.get(group)
        if selector is not None and selector.has_engines():
            return selector.next()
        # Fallback to default engine if group has no selector/engines
        return self._default_engine

    def _should_use_default_group(self, clause: Optional[Any]) -> bool:
        """Determine if the operation should use the default (writer) group.

        Side effect: when the operation is a write (flush or DML statement)
        and ``sticky_after_write`` is enabled, sticky-primary state is set
        via ``set_sticky_primary()``.

        Args:
            clause: The SQL clause being executed.

        Returns:
            ``True`` if default group should be used.
        """
        if not self._routing_config.enabled:
            return True

        if force_primary_var.get() or stick_to_primary_var.get():
            return True

        # isinstance(None, ...) is False, so no separate None check is needed.
        is_write = self._flushing or isinstance(clause, (Insert, Update, Delete))
        if is_write:
            if self._routing_config.sticky_after_write:
                set_sticky_primary()
            return True

        return self._has_for_update(clause)

    def _has_for_update(self, clause: Optional[Any]) -> bool:
        """Check if the clause has FOR UPDATE.

        Args:
            clause: The SQL clause to check (may be ``None``).

        Returns:
            ``True`` if FOR UPDATE is present.
        """
        # getattr on None (or a clause without the attribute) yields None.
        return getattr(clause, "_for_update_arg", None) is not None

    def commit(self) -> None:
        """Commit the transaction.

        Routing context is reset afterwards only when the routing config's
        ``reset_stickiness_on_commit`` is enabled.
        """
        super().commit()
        if self._routing_config.reset_stickiness_on_commit:
            reset_routing_context()

    def rollback(self) -> None:
        """Rollback the transaction and unconditionally reset routing state."""
        super().rollback()
        reset_routing_context()
class RoutingAsyncSession(AsyncSession):
    """Async session with read/write routing support.

    Wraps :class:`RoutingSyncSession` (or the class named by
    ``sync_session_class``) to provide async routing capabilities.
    """

    sync_session_class: "type[Session]" = RoutingSyncSession

    def __init__(
        self,
        routing_config: "RoutingConfig",
        selectors: "dict[str, EngineSelector[AsyncEngine]]",
        default_engine: "AsyncEngine",
        **kwargs: Any,
    ) -> None:
        """Initialize the async routing session.

        Args:
            routing_config: Configuration for routing behavior.
            selectors: Map of group names to async engine selectors.
            default_engine: The default (fallback/write) async engine.
            **kwargs: Additional arguments passed to the parent AsyncSession.
                ``bind``/``binds`` are discarded. An optional
                ``sync_session_class`` overrides the class-level default and
                must accept the routing arguments (e.g. a
                :class:`RoutingSyncSession` subclass).
        """
        kwargs.pop("bind", None)
        kwargs.pop("binds", None)
        # Honor an explicit sync_session_class kwarg, else the (possibly
        # subclass-overridden) class attribute. Hard-coding RoutingSyncSession
        # would ignore subclass overrides and raise a duplicate-keyword
        # TypeError when callers pass sync_session_class themselves.
        sync_session_class = kwargs.pop("sync_session_class", None) or self.sync_session_class
        # Convert async selectors to sync selectors for the wrapped session
        sync_selectors = {name: _SyncEngineSelectorWrapper(selector) for name, selector in selectors.items()}
        super().__init__(
            sync_session_class=sync_session_class,
            routing_config=routing_config,
            selectors=sync_selectors,
            default_engine=default_engine.sync_engine,
            **kwargs,
        )
        self._default_engine = default_engine
        self._selectors = selectors
        self._routing_config = routing_config

    @property
    def primary_engine(self) -> "AsyncEngine":
        """Get the primary (default) async engine.

        Returns:
            The default database engine.
        """
        return self._default_engine

    @property
    def routing_config(self) -> "RoutingConfig":
        """Get the routing configuration.

        Returns:
            The routing configuration.
        """
        return self._routing_config
class _SyncEngineSelectorWrapper:
    """Adapter exposing an async engine selector through the sync selector API.

    The wrapped selector yields async engines; this adapter hands back each
    selected engine's ``sync_engine`` attribute so a synchronous session can
    use it directly.
    """

    __slots__ = ("_async_selector",)

    def __init__(self, async_selector: "EngineSelector[AsyncEngine]") -> None:
        """Remember the async selector that all calls delegate to.

        Args:
            async_selector: The async engine selector to wrap.
        """
        self._async_selector = async_selector

    def has_engines(self) -> bool:
        """Report whether the wrapped selector has any engines.

        Returns:
            The wrapped selector's ``has_engines()`` result.
        """
        wrapped = self._async_selector
        return wrapped.has_engines()

    def has_replicas(self) -> bool:
        """Alias of :meth:`has_engines`.

        Returns:
            Same value as :meth:`has_engines`.
        """
        return self.has_engines()

    def next(self) -> "Engine":
        """Advance the wrapped selector and unwrap its choice.

        Returns:
            The ``sync_engine`` of the async engine picked by the wrapped
            selector.
        """
        chosen = self._async_selector.next()
        return chosen.sync_engine