import contextlib
import datetime
import random
import string
from collections.abc import Iterable, Sequence
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
Optional,
Protocol,
Union,
cast,
runtime_checkable,
)
from sqlalchemy import (
Delete,
Result,
Row,
Select,
TextClause,
Update,
any_,
delete,
inspect,
over,
select,
text,
update,
)
from sqlalchemy import func as sql_func
from sqlalchemy.exc import MissingGreenlet, NoInspectionAvailable
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.scoping import async_scoped_session
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.dml import ReturningDelete, ReturningUpdate
from sqlalchemy.sql.selectable import ForUpdateArg, ForUpdateParameter
from advanced_alchemy.exceptions import ErrorMessages, NotFoundError, RepositoryError, wrap_sqlalchemy_exception
from advanced_alchemy.filters import StatementFilter, StatementTypeT
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
DEFAULT_SAFE_TYPES,
FilterableRepository,
FilterableRepositoryProtocol,
LoadSpec,
_build_list_cache_key, # pyright: ignore
column_has_defaults,
compare_values,
get_abstract_loader_options,
get_instrumented_attr,
was_attribute_set,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, T
from advanced_alchemy.service.typing import schema_dump
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.text import slugify
if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import _CoreSingleExecuteParams # pyright: ignore[reportPrivateUsage]
from advanced_alchemy.cache.manager import CacheManager
# Fallback chunk size (ids per DELETE statement) used by ``delete_many`` via
# ``_get_insertmanyvalues_max_parameters`` when no ``chunk_size`` is given.
# NOTE(review): presumably chosen to stay under common driver bind-parameter limits — confirm.
DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS: Final = 950
# Presumably the minimum PostgreSQL major version with ``MERGE`` support (MERGE first shipped in
# PostgreSQL 15); its use is not visible in this part of the module.
POSTGRES_VERSION_SUPPORTING_MERGE: Final = 15
[docs]
@runtime_checkable
class SQLAlchemyAsyncRepositoryProtocol(FilterableRepositoryProtocol[ModelT], Protocol[ModelT]):
    """Structural interface for async SQLAlchemy repositories.

    Declares the attributes and coroutine signatures an async repository exposes:
    single and bulk CRUD, filter-based deletion, existence and count checks, upserts,
    listing (with or without a total count) and a session health check.
    :class:`SQLAlchemyAsyncRepository` in this module implements it.

    Type Parameters:
        ModelT: The SQLAlchemy model type the repository handles.
    """

    id_attribute: str
    match_fields: Optional[Union[list[str], str]] = None
    statement: Select[tuple[ModelT]]
    session: Union[AsyncSession, async_scoped_session[AsyncSession]]
    auto_expunge: bool
    auto_refresh: bool
    auto_commit: bool
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None
    error_messages: Optional[ErrorMessages] = None
    wrap_exceptions: bool = True

    def __init__(
        self,
        *,
        statement: Optional[Select[tuple[ModelT]]] = None,
        session: Union[AsyncSession, async_scoped_session[AsyncSession]],
        auto_expunge: bool = False,
        auto_refresh: bool = True,
        auto_commit: bool = False,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        wrap_exceptions: bool = True,
        **kwargs: Any,
    ) -> None: ...

    @classmethod
    def get_id_attribute_value(
        cls,
        item: Union[ModelT, type[ModelT]],
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    ) -> Any: ...

    @classmethod
    def set_id_attribute_value(
        cls,
        item_id: Any,
        item: ModelT,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    ) -> ModelT: ...

    @staticmethod
    def check_not_found(item_or_none: Optional[ModelT]) -> ModelT: ...

    async def add(
        self,
        data: ModelT,
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        bind_group: Optional[str] = None,
    ) -> ModelT: ...

    async def add_many(
        self,
        data: list[ModelT],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        bind_group: Optional[str] = None,
    ) -> Sequence[ModelT]: ...

    async def delete(
        self,
        item_id: Any,
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
    ) -> ModelT: ...

    async def delete_many(
        self,
        item_ids: list[Any],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        chunk_size: Optional[int] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
    ) -> Sequence[ModelT]: ...

    async def delete_where(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        load: Optional[LoadSpec] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        execution_options: Optional[dict[str, Any]] = None,
        sanity_check: bool = True,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Sequence[ModelT]: ...

    async def exists(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        load: Optional[LoadSpec] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> bool: ...

    async def get(
        self,
        item_id: Any,
        *,
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        with_for_update: ForUpdateParameter = None,
        bind_group: Optional[str] = None,
    ) -> ModelT: ...

    async def get_one(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        with_for_update: ForUpdateParameter = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> ModelT: ...

    async def get_one_or_none(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        with_for_update: ForUpdateParameter = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Optional[ModelT]: ...

    async def get_or_upsert(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        match_fields: Optional[Union[list[str], str]] = None,
        upsert: bool = True,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[ModelT, bool]: ...

    async def get_and_update(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        match_fields: Optional[Union[list[str], str]] = None,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[ModelT, bool]: ...

    async def count(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        statement: Optional[Select[tuple[ModelT]]] = None,
        load: Optional[LoadSpec] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> int: ...

    async def update(
        self,
        data: ModelT,
        *,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
    ) -> ModelT: ...

    async def update_many(
        self,
        data: list[ModelT],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
    ) -> list[ModelT]: ...

    def _get_update_many_statement(
        self,
        model_type: type[ModelT],
        supports_returning: bool,
        loader_options: Optional[list[_AbstractLoad]],
        execution_options: Optional[dict[str, Any]],
    ) -> Union[Update, ReturningUpdate[tuple[ModelT]]]: ...

    async def upsert(
        self,
        data: ModelT,
        *,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_expunge: Optional[bool] = None,
        auto_commit: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        match_fields: Optional[Union[list[str], str]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
    ) -> ModelT: ...

    async def upsert_many(
        self,
        data: list[ModelT],
        *,
        auto_expunge: Optional[bool] = None,
        auto_commit: Optional[bool] = None,
        no_merge: bool = False,
        match_fields: Optional[Union[list[str], str]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
    ) -> list[ModelT]: ...

    async def list_and_count(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        count_with_window_function: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
        use_cache: bool = True,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[list[ModelT], int]: ...

    # Defined after every other signature that annotates with ``list[...]``: from here on the
    # name ``list`` in this class body refers to this method, not the builtin.
    async def list(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Optional[Select[tuple[ModelT]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
        use_cache: bool = True,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> list[ModelT]: ...

    @classmethod
    async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool: ...
[docs]
@runtime_checkable
class SQLAlchemyAsyncSlugRepositoryProtocol(SQLAlchemyAsyncRepositoryProtocol[ModelT], Protocol[ModelT]):
    """Protocol for async SQLAlchemy repositories that support slug-based operations.

    Extends :class:`SQLAlchemyAsyncRepositoryProtocol` with lookup by slug and
    generation of an available slug.

    Type Parameters:
        ModelT: The SQLAlchemy model type this repository handles.
    """

    async def get_by_slug(
        self,
        slug: str,
        *,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Optional[ModelT]:
        """Get a model instance by its slug.

        Args:
            slug: The slug value to search for.
            error_messages: Optional custom error message templates.
            load: Specification for eager loading of relationships.
            execution_options: Options for statement execution.
            bind_group: Optional routing group to use for the operation.
            **kwargs: Additional filtering criteria.

        Returns:
            The matching model instance, or ``None`` if none was found.
        """
        ...

    async def get_available_slug(
        self,
        value_to_slugify: str,
        **kwargs: Any,
    ) -> str:
        """Generate a unique slug for a given value.

        Args:
            value_to_slugify: The string to convert to a slug.
            **kwargs: Additional parameters for slug generation.

        Returns:
            A unique slug derived from ``value_to_slugify``.
        """
        ...
[docs]
class SQLAlchemyAsyncRepository(SQLAlchemyAsyncRepositoryProtocol[ModelT], FilterableRepository[ModelT]):
    """Async SQLAlchemy repository implementation.

    Provides a complete implementation of async database operations using SQLAlchemy,
    including CRUD operations, filtering, and relationship loading.

    Type Parameters:
        ModelT: The SQLAlchemy model type this repository handles.

    .. seealso::
        :class:`~advanced_alchemy.repository._util.FilterableRepository`
    """

    id_attribute: str = "id"
    """Name of the unique identifier for the model."""
    loader_options: Optional[LoadSpec] = None
    """Default loader options for the repository."""
    error_messages: Optional[ErrorMessages] = None
    """Default error messages for the repository."""
    wrap_exceptions: bool = True
    """Wrap SQLAlchemy exceptions in a ``RepositoryError``. When set to ``False``, the original exception will be raised."""
    inherit_lazy_relationships: bool = True
    """Whether the model's relationship ``lazy`` configuration is honoured alongside the loader options.

    Set to ``False`` to ignore the model-level ``lazy`` settings, e.g. when the relationships given in
    ``load`` or ``loader_options`` should replace, rather than merge with, the model's defaults."""
    merge_loader_options: bool = True
    """Whether loader options passed via ``load`` are merged with the repository's default loader options.

    Set to ``False`` to have per-call ``load`` options replace the defaults entirely
    (forwarded to ``get_abstract_loader_options`` as ``merge_with_default``)."""
    execution_options: Optional[dict[str, Any]] = None
    """Default execution options for the repository."""
    match_fields: Optional[Union[list[str], str]] = None
    """Repository-level field name(s) for matching existing rows in upsert-style operations.

    ``get_or_upsert``, ``get_and_update``, ``upsert`` and ``upsert_many`` also accept a per-call
    ``match_fields`` argument of the same type. (Dialects that use ``field = ANY(...)`` instead of
    ``field IN (...)`` are configured through ``prefer_any_dialects``, not here.)"""
    uniquify: bool = False
    """Optionally apply the ``unique()`` method to results before returning.

    This is useful for certain SQLAlchemy use cases, such as applying ``contains_eager`` to a query
    containing a one-to-many relationship.
    """
    count_with_window_function: bool = True
    """Use an analytical window function to count results. This allows the count to be performed in a single query.
    """
    _cache_manager: Optional["CacheManager"] = None
    """Cache manager instance for repository-level caching. Set via ``cache_manager`` kwarg or retrieved from ``session.info``."""
    _bind_group: Optional[str] = None
    """Default bind group for routing operations (e.g., to read replicas). Can be overridden per-method."""
[docs]
def __init__(
    self,
    *,
    statement: Optional[Select[tuple[ModelT]]] = None,
    session: Union[AsyncSession, async_scoped_session[AsyncSession]],
    auto_expunge: bool = False,
    auto_refresh: bool = True,
    auto_commit: bool = False,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    wrap_exceptions: bool = True,
    uniquify: Optional[bool] = None,
    count_with_window_function: Optional[bool] = None,
    cache_manager: Optional["CacheManager"] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Create a repository bound to ``session``.

    Per-instance arguments left as ``None`` fall back to the class-level defaults
    (``uniquify``, ``count_with_window_function``, ``loader_options``, ``execution_options``).

    Args:
        statement: Base select used for queries; ``select(self.model_type)`` when omitted.
        session: Session managing the unit-of-work for the operation.
        auto_expunge: Remove objects from the session before returning them.
        auto_refresh: Refresh objects from the session before returning them.
        auto_commit: Commit objects before returning.
        order_by: Default ordering applied to queries.
        error_messages: Custom error message templates, layered over the class defaults.
        load: Default relationships to load.
        execution_options: Default statement execution options.
        wrap_exceptions: Wrap SQLAlchemy exceptions in a ``RepositoryError``. When ``False``, the original exception is raised.
        uniquify: Apply ``unique()`` to results before returning them.
        count_with_window_function: When ``False``, list-and-count issues two queries instead of one windowed query.
        cache_manager: Cache manager for repository-level caching; read from ``session.info["cache_manager"]`` when omitted.
        bind_group: Default routing group for all operations; individual methods may override it.
        **kwargs: Additional arguments (accepted and ignored here).
    """
    self.auto_expunge, self.auto_refresh, self.auto_commit = auto_expunge, auto_refresh, auto_commit
    self.order_by = order_by
    self.session = session
    # The right-hand side still sees the class-level ``error_messages`` default.
    self.error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    self.wrap_exceptions = wrap_exceptions
    self._uniquify = self.uniquify if uniquify is None else uniquify
    self.count_with_window_function = (
        self.count_with_window_function if count_with_window_function is None else count_with_window_function
    )

    initial_load = self.loader_options if load is None else load
    self._default_loader_options, self._loader_options_have_wildcards = get_abstract_loader_options(
        loader_options=initial_load,
        inherit_lazy_relationships=self.inherit_lazy_relationships,
        merge_with_default=self.merge_loader_options,
    )

    if execution_options is None:
        execution_options = self.execution_options
    self._default_execution_options = execution_options or {}

    self.statement = statement if statement is not None else select(self.model_type)

    # Prefer the session's own bind; otherwise ask the session to resolve one.
    session_bind = self.session.bind
    self._dialect = (session_bind if session_bind is not None else self.session.get_bind()).dialect
    dialect_name = self._dialect.name
    self._prefer_any = any(dialect_name == candidate for candidate in (self.prefer_any_dialects or ()))

    # Explicit cache manager wins; otherwise use whatever the session config stored in ``info``.
    self._cache_manager = session.info.get("cache_manager") if cache_manager is None else cache_manager
    # Routing default; methods resolve their own ``bind_group`` against it.
    self._bind_group = bind_group
def _get_uniquify(self, uniquify: Optional[bool] = None) -> bool:
"""Get the uniquify value, preferring the method parameter over instance setting.
Args:
uniquify: Optional override for the uniquify setting.
Returns:
bool: The uniquify value to use.
"""
return bool(uniquify) if uniquify is not None else self._uniquify
def _resolve_bind_group(self, bind_group: Optional[str] = None) -> Optional[str]:
"""Resolve the bind_group to use, preferring method parameter over instance default.
Args:
bind_group: Optional override for the bind_group setting.
Returns:
The bind_group to use, or None if not set.
"""
return bind_group if bind_group is not None else self._bind_group
def _queue_cache_invalidation(self, entity_id: Any, bind_group: Optional[str] = None) -> None:
"""Queue a cache invalidation for an entity.
The invalidation will be processed after the transaction commits.
If the transaction rolls back, the pending invalidation is discarded.
This uses cache listeners which must be set up via setup_cache_listeners()
during application initialization, or via scoped listeners in SQLAlchemyConfig.
Args:
entity_id: The primary key value of the entity to invalidate.
bind_group: Optional routing group for multi-master configurations.
When provided, only the cache entry for that bind_group is
invalidated.
"""
if self._cache_manager is not None:
from advanced_alchemy._listeners import get_cache_tracker
# Check if model_type has __tablename__ (may not exist in mock scenarios)
model_name = getattr(self.model_type, "__tablename__", None)
if model_name is None:
return
tracker = get_cache_tracker(self.session, self._cache_manager)
if tracker is not None:
tracker.add_invalidation(cast("str", model_name), entity_id, bind_group)
def _type_must_use_in_instead_of_any(self, matched_values: "list[Any]", field_type: "Any" = None) -> bool:
"""Determine if field.in_() should be used instead of any_() for compatibility.
Uses SQLAlchemy's type introspection to detect types that may have DBAPI
serialization issues with the ANY() operator. Checks if actual values match
the column's expected python_type - mismatches indicate complex types that
need the safer IN() operator. Falls back to Python type checking when
SQLAlchemy type information is unavailable.
Args:
matched_values: Values to be used in the filter
field_type: Optional SQLAlchemy TypeEngine from the column
Returns:
bool: True if field.in_() should be used instead of any_()
"""
if not matched_values:
return False
if field_type is not None:
try:
expected_python_type = getattr(field_type, "python_type", None)
if expected_python_type is not None:
for value in matched_values:
if value is not None and not isinstance(value, expected_python_type):
return True
except (AttributeError, NotImplementedError):
return True
return any(value is not None and type(value) not in DEFAULT_SAFE_TYPES for value in matched_values)
def _get_unique_values(self, values: "list[Any]") -> "list[Any]":
"""Get unique values from a list, handling unhashable types safely.
Args:
values: List of values to deduplicate
Returns:
list[Any]: List of unique values preserving order
"""
if not values:
return []
try:
# Fast path for hashable types
seen: set[Any] = set()
unique_values: list[Any] = []
for value in values:
if value not in seen:
unique_values.append(value)
seen.add(value)
except TypeError:
# Fallback for unhashable types (e.g., dicts from JSONB)
unique_values = []
for value in values:
if value not in unique_values:
unique_values.append(value)
return unique_values
@staticmethod
def _get_error_messages(
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
) -> Optional[ErrorMessages]:
    """Build the effective error-message templates.

    Layers, from lowest to highest precedence: ``DEFAULT_ERROR_MESSAGE_TEMPLATES``, then
    ``default_messages`` (only if it is a non-empty dict), then ``error_messages`` (if truthy).
    ``Empty`` for either argument means "not supplied".

    Args:
        error_messages: Per-call overrides.
        default_messages: Repository-level overrides.

    Returns:
        A new dict of merged templates.
    """
    overrides = None if error_messages == Empty else error_messages
    defaults = None if default_messages == Empty else default_messages
    merged = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
    if defaults and isinstance(defaults, dict):
        merged.update(defaults)
    if overrides:
        merged.update(cast("ErrorMessages", overrides))
    return merged
[docs]
@classmethod
def get_id_attribute_value(
    cls,
    item: Union[ModelT, type[ModelT]],
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> Any:
    """Read the identifier attribute from ``item``.

    Args:
        item: Any object (instance or model class) carrying the identifier attribute.
        id_attribute: Attribute to read, as a name or an ``InstrumentedAttribute``;
            ``None`` falls back to :attr:`id_attribute`.

    Returns:
        The value of the resolved attribute on ``item``.
    """
    attribute_name = id_attribute.key if isinstance(id_attribute, InstrumentedAttribute) else id_attribute
    return getattr(item, cls.id_attribute if attribute_name is None else attribute_name)
[docs]
@classmethod
def set_id_attribute_value(
    cls,
    item_id: Any,
    item: ModelT,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
) -> ModelT:
    """Assign ``item_id`` to the identifier attribute of ``item`` and return ``item``.

    Args:
        item_id: Identifier value to store.
        item: Object whose identifier attribute is set (mutated in place).
        id_attribute: Attribute to write, as a name or an ``InstrumentedAttribute``;
            ``None`` falls back to :attr:`id_attribute`.

    Returns:
        The same ``item`` instance.
    """
    attribute_name = id_attribute.key if isinstance(id_attribute, InstrumentedAttribute) else id_attribute
    setattr(item, cls.id_attribute if attribute_name is None else attribute_name, item_id)
    return item
[docs]
@staticmethod
def check_not_found(item_or_none: Optional[ModelT]) -> ModelT:
    """Return ``item_or_none`` unchanged, or raise if it is ``None``.

    Args:
        item_or_none: Lookup result to validate. Falsy values other than ``None`` pass through.

    Raises:
        NotFoundError: If ``item_or_none`` is ``None``.

    Returns:
        The non-``None`` item.
    """
    if item_or_none is not None:
        return item_or_none
    msg = "No item found when one was expected"
    raise NotFoundError(msg)
def _get_execution_options(
self,
execution_options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
if execution_options is None:
return self._default_execution_options
return execution_options
def _get_loader_options(
    self,
    loader_options: Optional[LoadSpec],
) -> Union[tuple[list[_AbstractLoad], bool], tuple[None, bool]]:
    """Resolve loader options for one operation.

    Args:
        loader_options: Per-call load specification; ``None`` reuses the defaults computed in ``__init__``.

    Returns:
        A ``(options, flag)`` pair. The flag is the wildcard indicator OR-ed with ``self._uniquify``
        for the defaults; for a per-call spec it comes from ``get_abstract_loader_options``.
    """
    if loader_options is not None:
        return get_abstract_loader_options(
            loader_options=loader_options,
            default_loader_options=self._default_loader_options,
            default_options_have_wildcards=self._loader_options_have_wildcards or self._uniquify,
            inherit_lazy_relationships=self.inherit_lazy_relationships,
            merge_with_default=self.merge_loader_options,
        )
    return self._default_loader_options, self._loader_options_have_wildcards or self._uniquify
[docs]
async def add(
    self,
    data: ModelT,
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    bind_group: Optional[str] = None,
) -> ModelT:
    """Persist a single new instance.

    The instance is attached to the session, then flushed or committed, optionally
    refreshed, and optionally expunged before being returned.

    Args:
        data: Instance to add.
        auto_commit: Commit instead of only flushing; ``None`` uses the repository default.
        auto_expunge: Remove the instance from the session before returning.
        auto_refresh: Refresh the instance from the database before returning.
        error_messages: Optional templates for friendlier error messages.
        bind_group: Routing group for multi-master setups (accepted but not used yet).

    Returns:
        The added instance.
    """
    del bind_group  # Reserved for future multi-master routing
    messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        attached = await self._attach_to_session(data)
        await self._flush_or_commit(auto_commit=auto_commit)
        await self._refresh(attached, auto_refresh=auto_refresh)
        self._expunge(attached, auto_expunge=auto_expunge)
        return attached
[docs]
async def add_many(
    self,
    data: list[ModelT],
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    bind_group: Optional[str] = None,
) -> Sequence[ModelT]:
    """Persist several new instances in one unit of work.

    Args:
        data: Instances to add.
        auto_commit: Commit instead of only flushing; ``None`` uses the repository default.
        auto_expunge: Remove each instance from the session before returning.
        error_messages: Optional templates for friendlier error messages.
        bind_group: Routing group for multi-master setups (accepted but not used yet).

    Returns:
        The same list object that was passed in.
    """
    del bind_group  # Reserved for future multi-master routing
    messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=messages,
        dialect_name=self._dialect.name,
        wrap_exceptions=self.wrap_exceptions,
    ):
        self.session.add_all(data)
        await self._flush_or_commit(auto_commit=auto_commit)
        for added in data:
            self._expunge(added, auto_expunge=auto_expunge)
        return data
[docs]
async def delete(
    self,
    item_id: Any,
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
) -> ModelT:
    """Delete the instance identified by ``item_id``.

    Args:
        item_id: Identifier of the instance to delete.
        auto_commit: Commit instead of only flushing; ``None`` uses the repository default.
        auto_expunge: Remove the instance from the session before returning.
        id_attribute: Attribute used to look the instance up. Defaults to :attr:`id_attribute`,
            but may be any surrogate or candidate key of the table.
        error_messages: Optional templates for friendlier error messages.
        load: Relationships to load on the fetched instance.
        execution_options: Execution options for this call; ``None`` uses the repository defaults.
        uniquify: Apply ``unique()`` to results for this call only.
        bind_group: Routing group for multi-master configurations.

    Returns:
        The deleted instance.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    # ``uniquify`` is a per-call override: apply it only for the duration of this call so it
    # does not leak into later operations on the same repository instance.
    previous_uniquify = self._uniquify
    self._uniquify = self._get_uniquify(uniquify)
    try:
        with wrap_sqlalchemy_exception(
            error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
        ):
            # Resolve the defaults first and add the routing key to a copy: this keeps the
            # repository's default options when the caller passed none, and never mutates
            # the caller's dict or the shared defaults.
            execution_options = self._get_execution_options(execution_options)
            resolved_bind_group = self._resolve_bind_group(bind_group)
            if resolved_bind_group:
                execution_options = {**execution_options, "bind_group": resolved_bind_group}
            instance = await self.get(
                item_id,
                id_attribute=id_attribute,
                load=load,
                execution_options=execution_options,
                bind_group=bind_group,
            )
            await self.session.delete(instance)
            await self._flush_or_commit(auto_commit=auto_commit)
            self._expunge(instance, auto_expunge=auto_expunge)
            # Queue cache invalidation (processed on commit)
            self._queue_cache_invalidation(item_id, bind_group)
            return instance
    finally:
        self._uniquify = previous_uniquify
[docs]
async def delete_many(
    self,
    item_ids: list[Any],
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    chunk_size: Optional[int] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
) -> Sequence[ModelT]:
    """Delete the instances identified by ``item_ids``.

    Ids are processed in chunks. On dialects that support ``DELETE ... RETURNING`` for
    executemany, each chunk is deleted and returned in one statement; otherwise the rows are
    selected first and then deleted.

    Args:
        item_ids: Identifiers of the instances to delete.
        auto_commit: Commit instead of only flushing; ``None`` uses the repository default.
        auto_expunge: Remove the deleted instances from the session before returning.
        id_attribute: Attribute used to match ids. Defaults to :attr:`id_attribute`, but may be
            any surrogate or candidate key of the table.
        chunk_size: Number of ids per statement. Defaults to ``950`` if left unset; ignored on
            dialects listed in ``prefer_any_dialects`` (all ids go in one chunk there).
        error_messages: Optional templates for friendlier error messages.
        load: Relationships to load on the returned instances.
        execution_options: Execution options for this call; ``None`` uses the repository defaults.
        uniquify: Apply ``unique()`` to results for this call only.
        bind_group: Routing group for multi-master configurations.

    Returns:
        The deleted instances.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    # ``uniquify`` is a per-call override: apply it only for the duration of this call so it
    # does not leak into later operations on the same repository instance.
    previous_uniquify = self._uniquify
    self._uniquify = self._get_uniquify(uniquify)
    try:
        with wrap_sqlalchemy_exception(
            error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
        ):
            # Resolve the defaults first and add the routing key to a copy, so the repository
            # defaults survive when the caller passed none and no dict is mutated in place.
            execution_options = self._get_execution_options(execution_options)
            resolved_bind_group = self._resolve_bind_group(bind_group)
            if resolved_bind_group:
                execution_options = {**execution_options, "bind_group": resolved_bind_group}
            loader_options, _ = self._get_loader_options(load)
            id_attr = get_instrumented_attr(
                self.model_type,
                id_attribute if id_attribute is not None else self.id_attribute,
            )
            supports_returning = self._dialect.delete_executemany_returning
            build_statement = partial(
                self._get_delete_many_statement,
                model_type=self.model_type,
                id_attribute=id_attr,
                supports_returning=supports_returning,
                loader_options=loader_options,
                execution_options=execution_options,
            )
            if self._prefer_any:
                # ``= ANY(...)`` dialects take the whole id list at once (presumably bound as a
                # single array parameter by ``_get_delete_many_statement``), so use one chunk.
                chunk_size = len(item_ids) + 1
            chunk_size = self._get_insertmanyvalues_max_parameters(chunk_size)

            instances: list[ModelT] = []
            for start in range(0, len(item_ids), chunk_size):
                chunk = item_ids[start : start + chunk_size]
                if supports_returning:
                    # DELETE ... RETURNING hands back the deleted rows directly.
                    instances.extend(
                        await self.session.scalars(build_statement(statement_type="delete", id_chunk=chunk)),
                    )
                else:
                    # Without RETURNING, load the rows first, then delete them.
                    instances.extend(
                        await self.session.scalars(build_statement(statement_type="select", id_chunk=chunk)),
                    )
                    await self.session.execute(build_statement(statement_type="delete", id_chunk=chunk))
            await self._flush_or_commit(auto_commit=auto_commit)
            for instance in instances:
                self._expunge(instance, auto_expunge=auto_expunge)
                # Queue cache invalidation (processed on commit)
                self._queue_cache_invalidation(self.get_id_attribute_value(instance), bind_group)
            return instances
    finally:
        self._uniquify = previous_uniquify
@staticmethod
def _get_insertmanyvalues_max_parameters(chunk_size: Optional[int] = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS
[docs]
async def delete_where(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    sanity_check: bool = True,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> Sequence[ModelT]:
    """Delete every row matched by ``filters`` and ``kwargs``.

    When the dialect reports ``delete_executemany_returning`` support, the deleted rows are
    obtained from the DELETE statement itself via ``RETURNING``. Otherwise the matching rows
    are loaded first with :meth:`list` (cache bypassed) and the DELETE is executed afterwards.

    Args:
        *filters: Types for specific filtering operations.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        sanity_check: When true (and the backend reports a non-negative row count), the
            number of pre-fetched instances is compared to the deleted row count. Only
            applies on the non-RETURNING path.
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group for multi-master configurations.
        **kwargs: Arguments to apply to a delete

    Raises:
        RepositoryError: If the number of deleted rows does not match the number of
            selected instances. The session is rolled back before raising.

    Returns:
        The deleted instances.
    """
    self._uniquify = self._get_uniquify(uniquify)
    messages = self._get_error_messages(error_messages=error_messages, default_messages=self.error_messages)
    with wrap_sqlalchemy_exception(
        error_messages=messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        resolved_bind_group = self._resolve_bind_group(bind_group)
        if resolved_bind_group:
            # Copy so the caller's dict is never mutated.
            execution_options = {**(execution_options or {}), "bind_group": resolved_bind_group}
        execution_options = self._get_execution_options(execution_options)
        loader_options, _ = self._get_loader_options(load)
        model_type = self.model_type

        delete_stmt = self._get_base_stmt(
            statement=delete(model_type),
            loader_options=loader_options,
            execution_options=execution_options,
        )
        delete_stmt = self._filter_select_by_kwargs(statement=delete_stmt, kwargs=kwargs)
        delete_stmt = self._apply_filters(*filters, statement=delete_stmt, apply_pagination=False)

        deleted: list[ModelT] = []
        if self._dialect.delete_executemany_returning:
            # Single round trip: the DELETE hands back the removed rows.
            deleted.extend(await self.session.scalars(delete_stmt.returning(model_type)))
        else:
            # No RETURNING support: load the matching rows first, then delete them.
            deleted.extend(
                await self.list(
                    *filters,
                    load=load,
                    execution_options=execution_options,
                    auto_expunge=auto_expunge,
                    use_cache=False,  # Always fetch from DB for delete_where
                    bind_group=bind_group,
                    **kwargs,
                ),
            )
            result = await self.session.execute(delete_stmt)
            row_count = getattr(result, "rowcount", -2)
            # Backends return -1 when they cannot determine the affected row count, so the
            # comparison is only meaningful for non-negative values.
            if sanity_check and 0 <= row_count != len(deleted):  # pyright: ignore
                await self.session.rollback()
                raise RepositoryError(detail="Deleted count does not match fetched count. Rollback issued.")

        await self._flush_or_commit(auto_commit=auto_commit)
        for instance in deleted:
            self._expunge(instance, auto_expunge=auto_expunge)
            # Queue cache invalidation (processed on commit)
            self._queue_cache_invalidation(self.get_id_attribute_value(instance), resolved_bind_group)
        return deleted
[docs]
async def exists(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> bool:
    """Return true if at least one row matches ``filters`` and ``kwargs``.

    Implemented on top of :meth:`count`; every option is forwarded to it unchanged.

    Args:
        *filters: Types for specific filtering operations.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group to use for the operation.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        True if a matching row was found, False otherwise.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    existing = await self.count(
        *filters,
        load=load,
        execution_options=execution_options,
        error_messages=error_messages,
        # Previously dropped: forward the documented option instead of silently ignoring it.
        uniquify=uniquify,
        bind_group=bind_group,
        **kwargs,
    )
    return existing > 0
@staticmethod
def _get_base_stmt(
    *,
    statement: StatementTypeT,
    loader_options: Optional[list[_AbstractLoad]],
    execution_options: Optional[dict[str, Any]],
) -> StatementTypeT:
    """Attach loader and execution options to ``statement``.

    Each kind of option is applied only when a non-empty value is given; otherwise the
    statement passes through unchanged.

    Args:
        statement: The statement (select/delete/update) to modify.
        loader_options: Relationship loader options passed to ``.options()``.
        execution_options: Keyword options passed to ``.execution_options()``.

    Returns:
        The statement with the requested options applied.
    """
    result = statement
    if loader_options:
        result = cast("StatementTypeT", result.options(*loader_options))
    if not execution_options:
        return result
    return cast("StatementTypeT", result.execution_options(**execution_options))
def _apply_for_update_options(
    self,
    statement: Select[tuple[ModelT]],
    with_for_update: ForUpdateParameter,
) -> Select[tuple[ModelT]]:
    """Return ``statement`` with a ``FOR UPDATE`` clause when one is requested.

    Accepted forms of ``with_for_update``:

    * ``None`` / ``False``: no locking; the statement is returned unchanged.
    * ``True``: a plain ``FOR UPDATE``.
    * ``ForUpdateArg``: its ``nowait``/``read``/``skip_locked``/``key_share`` flags (plus
      ``of`` when it is set) are re-applied via ``Select.with_for_update``.
    * ``dict``: unpacked as keyword arguments to ``Select.with_for_update``.

    Any other value leaves the statement unchanged.
    """
    if with_for_update in (None, False):
        return statement
    if with_for_update is True:
        return statement.with_for_update()
    if isinstance(with_for_update, ForUpdateArg):
        lock_flags: dict[str, Any] = dict(
            nowait=with_for_update.nowait,
            read=with_for_update.read,
            skip_locked=with_for_update.skip_locked,
            key_share=with_for_update.key_share,
        )
        lock_targets = getattr(with_for_update, "of", None)
        if lock_targets:
            lock_flags["of"] = lock_targets
        return statement.with_for_update(**lock_flags)
    if isinstance(with_for_update, dict):  # pyright: ignore
        return statement.with_for_update(**with_for_update)
    return statement
def _get_delete_many_statement(
    self,
    *,
    model_type: type[ModelT],
    id_attribute: InstrumentedAttribute[Any],
    id_chunk: list[Any],
    supports_returning: bool,
    statement_type: Literal["delete", "select"] = "delete",
    loader_options: Optional[list[_AbstractLoad]],
    execution_options: Optional[dict[str, Any]],
) -> Union[Select[tuple[ModelT]], Delete, ReturningDelete[tuple[ModelT]]]:
    """Build the statement that deletes (or pre-selects) one chunk of identifiers.

    Args:
        model_type: Mapped model the statement targets.
        id_attribute: Column attribute compared against ``id_chunk``.
        id_chunk: Identifier values for this chunk.
        supports_returning: When true and ``statement_type`` is ``"delete"``, a
            ``RETURNING`` clause for ``model_type`` is added.
        statement_type: ``"delete"`` builds a DELETE; ``"select"`` builds a SELECT of the
            rows that are about to be deleted.
        loader_options: Loader options applied to the statement.
        execution_options: Execution options applied to the statement.

    Returns:
        The filtered statement. It uses ``IN (...)`` when the dialect does not prefer
        ``ANY()`` or the id type is incompatible with it, and ``= ANY(...)`` otherwise.
    """
    # ``_get_base_stmt`` already attaches both loader and execution options; they are not
    # re-applied afterwards (the previous second ``execution_options`` call was redundant).
    statement = self._get_base_stmt(
        statement=delete(model_type) if statement_type == "delete" else select(model_type),
        loader_options=loader_options,
        execution_options=execution_options,
    )
    if supports_returning and statement_type != "select":
        statement = cast("ReturningDelete[tuple[ModelT]]", statement.returning(model_type))  # type: ignore[union-attr,assignment] # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
    # Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
    use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(id_chunk, id_attribute.type)
    if use_in:
        return statement.where(id_attribute.in_(id_chunk))  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
    return statement.where(any_(id_chunk) == id_attribute)  # type: ignore[arg-type]
async def _get_from_db(
    self,
    item_id: Any,
    *,
    auto_expunge: Optional[bool],
    statement: Optional[Select[tuple[ModelT]]],
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]],
    error_messages: Optional[ErrorMessages],
    load: Optional[LoadSpec],
    execution_options: Optional[dict[str, Any]],
    with_for_update: ForUpdateParameter,
    bind_group: Optional[str] = None,
) -> ModelT:
    """Fetch a single entity by identifier straight from the database (cache bypassed).

    Args:
        item_id: Value matched against ``id_attribute``.
        auto_expunge: Forwarded to ``_expunge`` for the returned instance.
        statement: Base SELECT; defaults to ``self.statement``.
        id_attribute: Attribute to match; defaults to ``self.id_attribute``.
        error_messages: Templates used by ``wrap_sqlalchemy_exception``.
        load: Loader specification resolved via ``_get_loader_options``.
        execution_options: Extra execution options merged via ``_get_execution_options``.
        with_for_update: Optional FOR UPDATE settings for the SELECT.
        bind_group: Routing group; the resolved value is added to the execution options.

    Returns:
        The matching instance. When no row matches, ``check_not_found`` decides the
        outcome (presumably raises; see that helper).
    """
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        target_bind_group = self._resolve_bind_group(bind_group)
        if target_bind_group:
            # Copy so the caller's dict is never mutated.
            execution_options = {**(execution_options or {}), "bind_group": target_bind_group}
        merged_execution_options = self._get_execution_options(execution_options)
        base_statement = statement if statement is not None else self.statement
        loader_options, has_wildcard = self._get_loader_options(load)
        lookup_attribute = self.id_attribute if id_attribute is None else id_attribute

        query = self._get_base_stmt(
            statement=base_statement,
            loader_options=loader_options,
            execution_options=merged_execution_options,
        )
        query = self._filter_select_by_kwargs(query, [(lookup_attribute, item_id)])
        query = self._apply_for_update_options(query, with_for_update)

        result = await self._execute(query, uniquify=has_wildcard)
        found = self.check_not_found(result.scalar_one_or_none())
        self._expunge(found, auto_expunge=auto_expunge)
        return found
async def _get_cached_creator(
    self,
    model_name: str,
    item_id: Any,
    *,
    auto_expunge: Optional[bool],
    statement: Optional[Select[tuple[ModelT]]],
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]],
    error_messages: Optional[ErrorMessages],
    load: Optional[LoadSpec],
    execution_options: Optional[dict[str, Any]],
    with_for_update: ForUpdateParameter,
    bind_group: Optional[str] = None,
) -> ModelT:
    """Singleflight creator for get(id) caching (async).

    Without a cache manager this is a plain database read. Otherwise the entity cache is
    consulted first; on a miss the entity is read from the database and stored in the
    cache before being returned. (The re-check presumably catches entries stored while
    this call waited on the singleflight key.)
    """
    load_from_db = partial(
        self._get_from_db,
        item_id,
        auto_expunge=auto_expunge,
        statement=statement,
        id_attribute=id_attribute,
        error_messages=error_messages,
        load=load,
        execution_options=execution_options,
        with_for_update=with_for_update,
        bind_group=bind_group,
    )
    manager = self._cache_manager
    if manager is None:
        return await load_from_db()

    cached = await manager.get_entity_async(model_name, item_id, self.model_type, bind_group=bind_group)
    if cached is not None:
        return cached

    instance = await load_from_db()
    await manager.set_entity_async(model_name, item_id, instance, bind_group=bind_group)
    return instance
async def _list_from_db(
    self,
    *,
    filters: Sequence[Union[StatementFilter, ColumnElement[bool]]],
    auto_expunge: Optional[bool],
    statement: Optional[Select[tuple[ModelT]]],
    order_by: Optional[Union[list[OrderingPair], OrderingPair]],
    error_messages: Optional[ErrorMessages],
    load: Optional[LoadSpec],
    execution_options: Optional[dict[str, Any]],
    kwargs: dict[str, Any],
    uniquify: Optional[bool],
    bind_group: Optional[str] = None,
) -> list[ModelT]:
    """Fetch a list of entities straight from the database (cache bypassed).

    Ordering is applied first (``order_by`` or, when ``None``, ``self.order_by`` or no
    ordering), then ``filters``, then the ``kwargs`` equality filters. Every returned
    instance is passed through ``_expunge``.
    """
    self._uniquify = self._get_uniquify(uniquify)
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        target_bind_group = self._resolve_bind_group(bind_group)
        if target_bind_group:
            # Copy so the caller's dict is never mutated.
            execution_options = {**(execution_options or {}), "bind_group": target_bind_group}
        merged_execution_options = self._get_execution_options(execution_options)
        base_statement = statement if statement is not None else self.statement
        loader_options, has_wildcard = self._get_loader_options(load)

        query = self._get_base_stmt(
            statement=base_statement,
            loader_options=loader_options,
            execution_options=merged_execution_options,
        )
        if order_by is None:
            order_by = [] if self.order_by is None else self.order_by
        query = self._apply_order_by(statement=query, order_by=order_by)
        query = self._apply_filters(*filters, statement=query)
        query = self._filter_select_by_kwargs(query, kwargs)

        rows = await self._execute(query, uniquify=has_wildcard)
        instances = list(rows.scalars())
        for instance in instances:
            self._expunge(instance, auto_expunge=auto_expunge)
        return cast("list[ModelT]", instances)
async def _list_cached_creator(
    self,
    cache_key: str,
    *,
    filters: Sequence[Union[StatementFilter, ColumnElement[bool]]],
    auto_expunge: Optional[bool],
    statement: Optional[Select[tuple[ModelT]]],
    order_by: Optional[Union[list[OrderingPair], OrderingPair]],
    error_messages: Optional[ErrorMessages],
    load: Optional[LoadSpec],
    execution_options: Optional[dict[str, Any]],
    kwargs: dict[str, Any],
    uniquify: Optional[bool],
    bind_group: Optional[str] = None,
) -> list[ModelT]:
    """Singleflight creator for list caching (async).

    Without a cache manager this is a plain database read. Otherwise the list cache entry
    for ``cache_key`` is checked first. On a miss the rows are read from the database and
    a copy is stored under ``cache_key``; a separate copy is returned to the caller.
    """
    load_from_db = partial(
        self._list_from_db,
        filters=filters,
        auto_expunge=auto_expunge,
        statement=statement,
        order_by=order_by,
        error_messages=error_messages,
        load=load,
        execution_options=execution_options,
        kwargs=kwargs,
        uniquify=uniquify,
        bind_group=bind_group,
    )
    manager = self._cache_manager
    if manager is None:
        return await load_from_db()

    cached = await manager.get_list_async(cache_key, self.model_type)
    if cached is not None:
        return cached

    fresh = await load_from_db()
    await manager.set_list_async(cache_key, list(fresh))
    return list(fresh)
async def _list_and_count_from_db(
    self,
    *,
    filters: Sequence[Union[StatementFilter, ColumnElement[bool]]],
    auto_expunge: Optional[bool],
    statement: Optional[Select[tuple[ModelT]]],
    count_with_window_function: bool,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]],
    error_messages: Optional[ErrorMessages],
    load: Optional[LoadSpec],
    execution_options: Optional[dict[str, Any]],
    kwargs: dict[str, Any],
    uniquify: Optional[bool],
    bind_group: Optional[str] = None,
) -> tuple[list[ModelT], int]:
    """Fetch a ``(instances, total_count)`` pair straight from the database (cache bypassed).

    Delegates to ``_list_and_count_window`` when ``count_with_window_function`` is true and
    the dialect is not Spanner; otherwise to ``_list_and_count_basic``. The resolved bind
    group travels to either helper inside ``execution_options``.
    """
    self._uniquify = self._get_uniquify(uniquify)
    target_bind_group = self._resolve_bind_group(bind_group)
    if target_bind_group:
        # Copy so the caller's dict is never mutated.
        execution_options = {**(execution_options or {}), "bind_group": target_bind_group}

    use_window = self._dialect.name not in {"spanner", "spanner+spanner"} and count_with_window_function
    if use_window:
        return await self._list_and_count_window(
            *filters,
            auto_expunge=auto_expunge,
            statement=statement,
            load=load,
            execution_options=execution_options,
            error_messages=error_messages,
            order_by=order_by,
            **kwargs,
        )
    return await self._list_and_count_basic(
        *filters,
        auto_expunge=auto_expunge,
        statement=statement,
        load=load,
        execution_options=execution_options,
        order_by=order_by,
        error_messages=error_messages,
        **kwargs,
    )
async def _list_and_count_cached_creator(
    self,
    cache_key: str,
    *,
    filters: Sequence[Union[StatementFilter, ColumnElement[bool]]],
    auto_expunge: Optional[bool],
    statement: Optional[Select[tuple[ModelT]]],
    count_with_window_function: bool,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]],
    error_messages: Optional[ErrorMessages],
    load: Optional[LoadSpec],
    execution_options: Optional[dict[str, Any]],
    kwargs: dict[str, Any],
    uniquify: Optional[bool],
    bind_group: Optional[str] = None,
) -> tuple[list[ModelT], int]:
    """Singleflight creator for list_and_count caching (async).

    Without a cache manager this is a plain database read. Otherwise the cached
    ``(instances, count)`` pair for ``cache_key`` is checked first. On a miss the pair is
    read from the database; a copy of the list is cached and another copy returned.
    """
    load_from_db = partial(
        self._list_and_count_from_db,
        filters=filters,
        auto_expunge=auto_expunge,
        statement=statement,
        count_with_window_function=count_with_window_function,
        order_by=order_by,
        error_messages=error_messages,
        load=load,
        execution_options=execution_options,
        kwargs=kwargs,
        uniquify=uniquify,
        bind_group=bind_group,
    )
    manager = self._cache_manager
    if manager is None:
        return await load_from_db()

    cached = await manager.get_list_and_count_async(cache_key, self.model_type)
    if cached is not None:
        return cached

    fresh_instances, fresh_total = await load_from_db()
    await manager.set_list_and_count_async(cache_key, list(fresh_instances), fresh_total)
    return list(fresh_instances), fresh_total
[docs]
async def get(
    self,
    item_id: Any,
    *,
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    with_for_update: ForUpdateParameter = None,
    use_cache: bool = True,
    bind_group: Optional[str] = None,
) -> ModelT:
    """Get instance identified by `item_id`.

    The entity cache is used only for "plain" lookups. That means ``use_cache`` is true, a
    cache manager is configured, and the effective auto-expunge is truthy. It also
    requires no custom ``statement``, ``load``, ``with_for_update`` or
    ``execution_options``, the default id attribute, and no repository-level default
    loader/execution options. Every other call goes straight to the database.

    Args:
        item_id: Identifier of the instance to be retrieved.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        id_attribute: Allows customization of the unique identifier to use for model fetching.
            Defaults to `id`, but can reference any surrogate or candidate key for the table.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        with_for_update: Optional FOR UPDATE clause / parameters to apply to the SELECT statement.
        use_cache: Whether to use caching for this query (default True).
        bind_group: Optional routing group to use for the operation.

    Returns:
        The retrieved instance.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    effective_auto_expunge = auto_expunge if auto_expunge is not None else self.auto_expunge
    # The cache is keyed on attribute *names*, so normalise an InstrumentedAttribute to its key.
    cache_id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = (
        id_attribute.key if isinstance(id_attribute, InstrumentedAttribute) else id_attribute
    )
    cache_manager = self._cache_manager
    # Resolve bind_group for cache key namespacing
    resolved_bind_group = self._resolve_bind_group(bind_group)

    cacheable = (
        use_cache
        and cache_manager is not None
        and bool(effective_auto_expunge)
        and statement is None
        and load is None
        and with_for_update is None
        and (cache_id_attribute is None or cache_id_attribute == self.id_attribute)
        and not self._default_loader_options
        and not self._default_execution_options
        and execution_options is None
    )
    if not cacheable or cache_manager is None:
        return await self._get_from_db(
            item_id,
            auto_expunge=auto_expunge,
            statement=statement,
            id_attribute=id_attribute,
            error_messages=resolved_error_messages,
            load=load,
            execution_options=execution_options,
            with_for_update=with_for_update,
            bind_group=bind_group,
        )

    model_name = cast("str", self.model_type.__tablename__)  # type: ignore[attr-defined]
    cached = await cache_manager.get_entity_async(model_name, item_id, self.model_type, bind_group=resolved_bind_group)
    if cached is not None:
        return cached

    # Include bind_group in singleflight key to prevent cross-shard cache pollution
    key_scope = f"{model_name}:{resolved_bind_group}" if resolved_bind_group else model_name
    singleflight_key = f"{key_scope}:get:{item_id}"
    creator = partial(
        self._get_cached_creator,
        model_name,
        item_id,
        auto_expunge=auto_expunge,
        statement=statement,
        id_attribute=cache_id_attribute,
        error_messages=resolved_error_messages,
        load=load,
        execution_options=execution_options,
        with_for_update=with_for_update,
        bind_group=resolved_bind_group,
    )
    return await cache_manager.singleflight_async(singleflight_key, creator)
[docs]
async def get_one(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    with_for_update: ForUpdateParameter = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> ModelT:
    """Get instance identified by ``kwargs``.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        with_for_update: Optional FOR UPDATE clause / parameters to apply to the SELECT statement.
        bind_group: Optional routing group to use for the operation. Resolved through
            ``_resolve_bind_group`` like the other read paths.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        The retrieved instance. When no row matches, ``check_not_found`` decides the
        outcome (presumably raises; see that helper).
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        # Resolve the bind group the same way ``_get_from_db``/``_list_from_db`` do, so that
        # ``get_one`` routes to the same bind group as ``get`` and ``list`` for equal input.
        resolved_bind_group = self._resolve_bind_group(bind_group)
        if resolved_bind_group:
            execution_options = dict(execution_options) if execution_options else {}
            execution_options["bind_group"] = resolved_bind_group
        execution_options = self._get_execution_options(execution_options)
        statement = self.statement if statement is None else statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        statement = self._get_base_stmt(
            statement=statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
        statement = self._filter_select_by_kwargs(statement, kwargs)
        statement = self._apply_for_update_options(statement, with_for_update)
        instance = (await self._execute(statement, uniquify=loader_options_have_wildcard)).scalar_one_or_none()
        instance = self.check_not_found(instance)
        self._expunge(instance, auto_expunge=auto_expunge)
        return instance
[docs]
async def get_one_or_none(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    with_for_update: ForUpdateParameter = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> Union[ModelT, None]:
    """Get instance identified by ``kwargs`` or None if not found.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        with_for_update: Optional FOR UPDATE clause / parameters to apply to the SELECT statement.
        bind_group: Optional routing group to use for the operation. Resolved through
            ``_resolve_bind_group`` like the other read paths.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        The retrieved instance or None
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        # Resolve the bind group the same way ``_get_from_db``/``_list_from_db`` do, so that
        # this method routes to the same bind group as ``get`` and ``list`` for equal input.
        resolved_bind_group = self._resolve_bind_group(bind_group)
        if resolved_bind_group:
            execution_options = dict(execution_options) if execution_options else {}
            execution_options["bind_group"] = resolved_bind_group
        execution_options = self._get_execution_options(execution_options)
        statement = self.statement if statement is None else statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        statement = self._get_base_stmt(
            statement=statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
        statement = self._filter_select_by_kwargs(statement, kwargs)
        statement = self._apply_for_update_options(statement, with_for_update)
        instance = cast(
            "Result[tuple[ModelT]]",
            (await self._execute(statement, uniquify=loader_options_have_wildcard)),
        ).scalar_one_or_none()
        # Compare against None explicitly: a model defining ``__bool__``/``__len__`` could be
        # falsy and would otherwise skip the expunge.
        if instance is not None:
            self._expunge(instance, auto_expunge=auto_expunge)
        return instance
[docs]
async def get_or_upsert(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    match_fields: Optional[Union[list[str], str]] = None,
    upsert: bool = True,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Union[bool, None] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> tuple[ModelT, bool]:
    """Get instance identified by ``kwargs`` or create if it doesn't exist.

    Args:
        *filters: Types for specific filtering operations.
        match_fields: a list of keys to use to match the existing model. When
            empty, all fields are matched. Keys whose value in ``kwargs`` is ``None``
            are left out of the match.
        upsert: When using match_fields and actual model values differ from
            `kwargs`, automatically perform an update operation on the model.
        attribute_names: an iterable of attribute names to pass into the ``update``
            method.
        with_for_update: FOR UPDATE settings passed to ``_refresh`` after an upsert of an
            existing row. They are not applied to the initial lookup.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients. Also applied to the lookup query.
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group for multi-master configurations.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        a tuple that includes the instance and whether it needed to be created.
        When using match_fields and actual model values differ from ``kwargs``, the
        model value will be updated.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        if match_fields := self._get_match_fields(match_fields=match_fields):
            match_filter = {
                field_name: field_value
                for field_name in match_fields
                if (field_value := kwargs.get(field_name)) is not None
            }
        else:
            match_filter = kwargs
        # Forward the resolved error templates and uniquify setting so the lookup behaves
        # like the rest of this call instead of silently falling back to defaults.
        existing = await self.get_one_or_none(
            *filters,
            **match_filter,
            load=load,
            execution_options=execution_options,
            error_messages=error_messages,
            uniquify=uniquify,
            bind_group=bind_group,
        )
        if existing is None:
            return (
                await self.add(
                    self.model_type(**kwargs),
                    auto_commit=auto_commit,
                    auto_expunge=auto_expunge,
                    auto_refresh=auto_refresh,
                    bind_group=bind_group,
                ),
                True,
            )
        if upsert:
            for field_name, new_field_value in kwargs.items():
                field = getattr(existing, field_name, MISSING)
                if field is not MISSING and not compare_values(field, new_field_value):  # pragma: no cover
                    setattr(existing, field_name, new_field_value)
            existing = await self._attach_to_session(existing, strategy="merge")
            await self._flush_or_commit(auto_commit=auto_commit)
            await self._refresh(
                existing,
                attribute_names=attribute_names,
                with_for_update=with_for_update,
                auto_refresh=auto_refresh,
            )
            self._expunge(existing, auto_expunge=auto_expunge)
        return existing, False
[docs]
async def get_and_update(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    match_fields: Optional[Union[list[str], str]] = None,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> tuple[ModelT, bool]:
    """Get instance identified by ``kwargs`` and update the model if the arguments are different.

    Args:
        *filters: Types for specific filtering operations.
        match_fields: a list of keys to use to match the existing model. When
            empty, all fields are matched. Keys whose value in ``kwargs`` is ``None``
            are left out of the match.
        attribute_names: an iterable of attribute names to pass into the ``update``
            method.
        with_for_update: FOR UPDATE settings passed to ``_refresh`` after the update.
            They are not applied to the initial lookup.
        auto_commit: Commit objects before returning.
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients. Also applied to the lookup query.
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group for multi-master configurations.
        **kwargs: Identifier of the instance to be retrieved.

    Returns:
        a tuple that includes the instance and whether it needed to be updated.
        When using match_fields and actual model values differ from ``kwargs``, the
        model value will be updated. If no row matches, the outcome is whatever
        :meth:`get_one` does in that case.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        if match_fields := self._get_match_fields(match_fields=match_fields):
            match_filter = {
                field_name: field_value
                for field_name in match_fields
                if (field_value := kwargs.get(field_name)) is not None
            }
        else:
            match_filter = kwargs
        # Forward the resolved error templates and uniquify setting so the lookup behaves
        # like the rest of this call instead of silently falling back to defaults.
        existing = await self.get_one(
            *filters,
            **match_filter,
            load=load,
            execution_options=execution_options,
            error_messages=error_messages,
            uniquify=uniquify,
            bind_group=bind_group,
        )
        updated = False
        for field_name, new_field_value in kwargs.items():
            field = getattr(existing, field_name, MISSING)
            if field is not MISSING and not compare_values(field, new_field_value):  # pragma: no cover
                updated = True
                setattr(existing, field_name, new_field_value)
        existing = await self._attach_to_session(existing, strategy="merge")
        await self._flush_or_commit(auto_commit=auto_commit)
        await self._refresh(
            existing,
            attribute_names=attribute_names,
            with_for_update=with_for_update,
            auto_refresh=auto_refresh,
        )
        self._expunge(existing, auto_expunge=auto_expunge)
        return existing, updated
[docs]
async def count(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    statement: Optional[Select[tuple[ModelT]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> int:
    """Count the records matched by a query.

    The filtered SELECT is rewritten into a ``COUNT`` query. Pagination filters
    are not applied, so the result is the total number of matching rows.

    Args:
        *filters: Statement filters or SQL boolean expressions to apply.
        statement: Optional SELECT to use instead of the repository default ``self.statement``.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group. When given, it is added under the
            ``"bind_group"`` key to a copy of ``execution_options``.
        **kwargs: Instance attribute value filters.

    Returns:
        Count of records returned by query, ignoring pagination.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        if bind_group:
            # Build a fresh dict so the caller's mapping is never mutated.
            execution_options = {**(execution_options or {}), "bind_group": bind_group}
        execution_options = self._get_execution_options(execution_options)
        base_statement = statement if statement is not None else self.statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        base_statement = self._get_base_stmt(
            statement=base_statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        filtered_statement = self._filter_select_by_kwargs(
            self._apply_filters(*filters, apply_pagination=False, statement=base_statement),
            kwargs,
        )
        count_statement = self._get_count_stmt(
            statement=filtered_statement,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        result = await self._execute(statement=count_statement, uniquify=loader_options_have_wildcard)
        return cast("int", result.scalar_one())
[docs]
async def update(
    self,
    data: ModelT,
    *,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
) -> ModelT:
    """Update instance with the attribute values present on `data`.

    The stored instance is loaded with ``self.get`` using the identifier read from
    ``data``. Then:

    - each mapped column value that was explicitly set on ``data`` and differs from
      the stored value is copied onto the loaded instance. ``None`` is never copied
      into a column that has a default/onupdate.
    - each relationship value present on ``data`` is merged into the session
      (``merge(..., load=False)``) and assigned to the loaded instance. Relationships
      that are view-only or use a ``write_only``/``dynamic``/``raise``/``raise_on_sql``
      loader are skipped.

    The instance is then attached to the session (merge strategy) and flushed or
    committed. It is optionally refreshed and expunged, and a cache invalidation
    is queued.

    Args:
        data: An instance that should have a value for `self.id_attribute` that
            exists in the collection.
        attribute_names: an iterable of attribute names to pass into the ``update``
            method.
        with_for_update: indicating FOR UPDATE should be used, or may be a
            dictionary containing flags to indicate a more specific set of
            FOR UPDATE flags for the SELECT
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        auto_commit: Commit objects before returning.
        id_attribute: Allows customization of the unique identifier to use for model fetching.
            Defaults to `id`, but can reference any surrogate or candidate key for the table.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group for multi-master configurations.

    Returns:
        The updated instance.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        item_id = self.get_id_attribute_value(
            data,
            id_attribute=id_attribute,
        )
        # Load the stored row (honouring FOR UPDATE); changes are applied onto it.
        existing_instance = await self.get(
            item_id,
            id_attribute=id_attribute,
            load=load,
            execution_options=execution_options,
            with_for_update=with_for_update,
            bind_group=bind_group,
        )
        mapper = None
        # no_autoflush keeps the session from flushing while values are being copied.
        # If ``data`` cannot be inspected (or a lazy load needs a greenlet), the
        # exception is suppressed.
        # NOTE(review): a suppressed exception abandons any remaining copies silently.
        with (
            self.session.no_autoflush,
            contextlib.suppress(MissingGreenlet, NoInspectionAvailable),
        ):
            # inspect() on an instance returns its InstanceState; ``.mapper`` is the class Mapper.
            mapper = inspect(data)
            if mapper is not None:
                for column in mapper.mapper.columns:
                    field_name = column.key
                    new_field_value = getattr(data, field_name, MISSING)
                    if new_field_value is not MISSING:
                        # Skip setting columns with defaults/onupdate to None during updates
                        # This prevents overwriting columns that should use their defaults
                        if new_field_value is None and column_has_defaults(column):
                            continue
                        # Only copy attributes that were explicitly set on the input instance
                        # This prevents overwriting existing values with uninitialized None values
                        if not was_attribute_set(data, mapper, field_name):
                            continue
                        existing_field_value = getattr(existing_instance, field_name, MISSING)
                        # Only assign when the value actually changed, so unchanged columns stay clean.
                        if existing_field_value is not MISSING and not compare_values(
                            existing_field_value, new_field_value
                        ):
                            setattr(existing_instance, field_name, new_field_value)
                # Handle relationships by merging objects into session first
                for relationship in mapper.mapper.relationships:
                    if relationship.viewonly or relationship.lazy in {  # pragma: no cover
                        "write_only",
                        "dynamic",
                        "raise",
                        "raise_on_sql",
                    }:
                        # Skip relationships with incompatible lazy loading strategies
                        continue
                    if (new_value := getattr(data, relationship.key, MISSING)) is not MISSING:
                        # Merge related objects into this session (without emitting a load)
                        # before assigning them to the persistent instance.
                        if isinstance(new_value, list):
                            merged_values = [  # pyright: ignore
                                await self.session.merge(item, load=False)  # pyright: ignore
                                for item in new_value  # pyright: ignore
                            ]
                            setattr(existing_instance, relationship.key, merged_values)
                        elif new_value is not None:
                            merged_value = await self.session.merge(new_value, load=False)
                            setattr(existing_instance, relationship.key, merged_value)
                        else:
                            # Explicit None clears a scalar relationship.
                            setattr(existing_instance, relationship.key, new_value)
        instance = await self._attach_to_session(existing_instance, strategy="merge")
        await self._flush_or_commit(auto_commit=auto_commit)
        await self._refresh(
            instance,
            attribute_names=attribute_names,
            with_for_update=with_for_update,
            auto_refresh=auto_refresh,
        )
        self._expunge(instance, auto_expunge=auto_expunge)
        # Queue cache invalidation (processed on commit)
        # NOTE(review): uses the repository's default id attribute, not the ``id_attribute``
        # argument. Confirm that this matches how cache entries are keyed.
        self._queue_cache_invalidation(self.get_id_attribute_value(instance), bind_group)
        return instance
[docs]
async def update_many(
    self,
    data: list[ModelT],
    *,
    auto_commit: Optional[bool] = None,
    auto_expunge: Optional[bool] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
) -> list[ModelT]:
    """Update one or more instances with the attribute values present on `data`.

    This function has an optimized bulk update based on the configured SQL dialect:

    - For backends supporting ``RETURNING`` with ``executemany`` (Oracle excluded), a single
      bulk update with a returning clause is executed.
    - For other backends, a bulk update is executed, and the updated rows are then re-read
      from the database. The re-read always bypasses the query cache.

    If the model has an ``updated_at`` attribute and a payload does not supply a value,
    that payload gets ``updated_at`` set to the current UTC time. Cache invalidation is
    queued for every updated instance on both code paths.

    Args:
        data: A list of instances to update. Each should have a value for `self.id_attribute` that exists in the
            collection.
        auto_expunge: Remove object from session before returning.
        auto_commit: Commit objects before returning.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group for multi-master configurations.

    Returns:
        The updated instances. An empty ``data`` list returns ``[]`` without issuing an UPDATE.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    supports_updated_at = hasattr(self.model_type, "updated_at")
    data_to_update: list[dict[str, Any]] = []
    for v in data:
        if isinstance(v, self.model_type) or (hasattr(v, "to_dict") and callable(v.to_dict)):
            update_payload = v.to_dict()
        else:
            update_payload = cast("dict[str, Any]", schema_dump(v))
        if supports_updated_at and (update_payload.get("updated_at") is None):
            update_payload["updated_at"] = datetime.datetime.now(datetime.timezone.utc)
        data_to_update.append(update_payload)
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        if not data_to_update:
            # Nothing to update: do not execute an UPDATE with an empty parameter list,
            # but still honour ``auto_commit`` like the non-empty paths do.
            await self._flush_or_commit(auto_commit=auto_commit)
            return []
        resolved_bind_group = self._resolve_bind_group(bind_group)
        if resolved_bind_group:
            execution_options = dict(execution_options) if execution_options else {}
            execution_options["bind_group"] = resolved_bind_group
        execution_options = self._get_execution_options(execution_options)
        loader_options = self._get_loader_options(load)[0]
        supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle"
        statement = self._get_update_many_statement(
            self.model_type,
            supports_returning,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        if supports_returning:
            instances = list(
                await self.session.scalars(
                    statement,
                    cast("_CoreSingleExecuteParams", data_to_update),  # this is not correct but the only way
                    # currently to deal with an SQLAlchemy typing issue. See
                    # https://github.com/sqlalchemy/sqlalchemy/discussions/9925
                    execution_options=execution_options,
                ),
            )
            # Read the ids while the freshly returned instances are still attached and loaded.
            # After a commit they may be expired (depending on expire_on_commit), and after
            # expunge they are detached.
            returned_ids = [self.get_id_attribute_value(instance) for instance in instances]
            await self._flush_or_commit(auto_commit=auto_commit)
            for instance in instances:
                self._expunge(instance, auto_expunge=auto_expunge)
            # Queue cache invalidation (processed on commit). This path must invalidate
            # just like the non-RETURNING path below.
            for returned_id in returned_ids:
                self._queue_cache_invalidation(returned_id, bind_group)
            return instances
        await self.session.execute(statement, data_to_update, execution_options=execution_options)
        await self._flush_or_commit(auto_commit=auto_commit)
        # For non-RETURNING backends, fetch updated instances from database
        updated_ids: list[Any] = [item[self.id_attribute] for item in data_to_update]
        updated_instances = await self.list(
            getattr(self.model_type, self.id_attribute).in_(updated_ids),
            load=loader_options,
            execution_options=execution_options,
            # Read-your-writes: this re-read must reflect the UPDATE just executed,
            # never a cached pre-update result.
            use_cache=False,
            bind_group=bind_group,
        )
        for instance in updated_instances:
            self._expunge(instance, auto_expunge=auto_expunge)
            # Queue cache invalidation (processed on commit)
            self._queue_cache_invalidation(self.get_id_attribute_value(instance), bind_group)
        return updated_instances
def _get_update_many_statement(
    self,
    model_type: type[ModelT],
    supports_returning: bool,
    loader_options: Union[list[_AbstractLoad], None],
    execution_options: Union[dict[str, Any], None],
) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
    """Build the bulk ``UPDATE`` statement used by ``update_many``.

    Args:
        model_type: Mapped class whose table is updated.
        supports_returning: When true, a ``RETURNING`` clause selecting the full
            entity is appended.
        loader_options: Loader options forwarded to ``_get_base_stmt``.
        execution_options: Execution options forwarded to ``_get_base_stmt``.

    Returns:
        The UPDATE statement, with ``RETURNING`` when requested.
    """
    base_update = update(table=model_type)
    prepared = self._get_base_stmt(
        statement=base_update,
        loader_options=loader_options,
        execution_options=execution_options,
    )
    if not supports_returning:
        return prepared
    return prepared.returning(model_type)
[docs]
async def list_and_count(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    statement: Optional[Select[tuple[ModelT]]] = None,
    auto_expunge: Optional[bool] = None,
    count_with_window_function: Optional[bool] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    use_cache: bool = True,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> tuple[list[ModelT], int]:
    """List records together with the total number of matching records.

    The cache manager is consulted only when all of these hold:

    - ``use_cache`` is true and auto-expunge resolves to true;
    - a cache manager is configured;
    - no custom ``statement``, no ``load`` and no default loader options are used;
    - a cache key can be built.

    Concurrent misses for the same key share one database load through the cache
    manager's single-flight helper. Every other call goes straight to the database.

    Args:
        *filters: Types for specific filtering operations.
        statement: To facilitate customization of the underlying select query.
        auto_expunge: Remove object from session before returning.
        count_with_window_function: When false, list and count will use two queries instead of an analytical window function.
        order_by: Set default order options for queries.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        use_cache: Whether to use the cache for this query. Defaults to ``True``.
        bind_group: Optional routing group to use for the operation.
        **kwargs: Instance attribute value filters.

    Returns:
        A ``(instances, total_count)`` tuple; the count ignores pagination.
    """
    if count_with_window_function is None:
        count_with_window_function = self.count_with_window_function
    self._uniquify = self._get_uniquify(uniquify)
    resolved_error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    resolved_auto_expunge = auto_expunge if auto_expunge is not None else self.auto_expunge
    resolved_execution_options = self._get_execution_options(execution_options)
    if order_by is not None:
        resolved_order_by = order_by
    else:
        resolved_order_by = self.order_by if self.order_by is not None else []
    cache_manager = self._cache_manager

    # Arguments shared by every path that ends up loading from the database.
    db_kwargs: dict[str, Any] = {
        "filters": filters,
        "auto_expunge": auto_expunge,
        "statement": statement,
        "count_with_window_function": count_with_window_function,
        "order_by": order_by,
        "error_messages": resolved_error_messages,
        "load": load,
        "execution_options": execution_options,
        "kwargs": kwargs,
        "uniquify": uniquify,
        "bind_group": bind_group,
    }

    cacheable = (
        use_cache
        and bool(resolved_auto_expunge)
        and cache_manager is not None
        and statement is None
        and load is None
        and not self._default_loader_options
    )
    if not cacheable or cache_manager is None:
        return await self._list_and_count_from_db(**db_kwargs)

    model_name = cast("str", self.model_type.__tablename__)  # type: ignore[attr-defined]
    version_token = await cache_manager.get_model_version_async(model_name)
    cache_key = _build_list_cache_key(
        model_name=model_name,
        version_token=version_token,
        method="list_and_count",
        filters=filters,
        kwargs=kwargs,
        order_by=resolved_order_by,
        execution_options=resolved_execution_options,
        uniquify=self._uniquify,
        count_with_window_function=count_with_window_function,
    )
    if cache_key is None:
        # The inputs could not be turned into a stable key; skip caching.
        return await self._list_and_count_from_db(**db_kwargs)

    cached = await cache_manager.get_list_and_count_async(cache_key, self.model_type)
    if cached is not None:
        return cached
    creator = partial(self._list_and_count_cached_creator, cache_key, **db_kwargs)
    return await cache_manager.singleflight_async(cache_key, creator)
def _expunge(self, instance: "ModelT", auto_expunge: "Optional[bool]") -> None:
    """Detach ``instance`` from the session when auto-expunge is in effect.

    Args:
        instance: The model instance to detach.
        auto_expunge: Whether to expunge; ``None`` defers to ``self.auto_expunge``.

    Note:
        Instances whose state is already ``deleted`` or ``detached`` are left
        alone. For example, rows returned by DELETE ... RETURNING become
        detached once committed. Expunging them would raise ``InvalidRequestError``.
    """
    should_expunge = self.auto_expunge if auto_expunge is None else auto_expunge
    if not should_expunge:
        return
    state = inspect(instance)
    already_gone = state is not None and (state.deleted or state.detached)
    if not already_gone:
        self.session.expunge(instance)
async def _flush_or_commit(self, auto_commit: Optional[bool]) -> None:
    """Commit the session when auto-commit is in effect, otherwise just flush it.

    Args:
        auto_commit: Whether to commit; ``None`` defers to ``self.auto_commit``.
    """
    should_commit = self.auto_commit if auto_commit is None else auto_commit
    if should_commit:
        return await self.session.commit()
    return await self.session.flush()
async def _refresh(
    self,
    instance: ModelT,
    auto_refresh: Optional[bool],
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
) -> None:
    """Reload ``instance`` from the database when auto-refresh is in effect.

    Args:
        instance: The instance to refresh.
        auto_refresh: Whether to refresh; ``None`` defers to ``self.auto_refresh``.
        attribute_names: Optional subset of attribute names to refresh.
        with_for_update: FOR UPDATE options forwarded to ``session.refresh``.
    """
    should_refresh = self.auto_refresh if auto_refresh is None else auto_refresh
    if not should_refresh:
        return None
    return await self.session.refresh(
        instance=instance,
        attribute_names=attribute_names,
        with_for_update=with_for_update,
    )
async def _list_and_count_window(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> tuple[list[ModelT], int]:
    """List records and their total count in a single query.

    A ``COUNT(*) OVER ()`` window column is added to the SELECT, and the total is
    read from the first returned row. When no rows come back, the count is ``0``.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        order_by: Ordering to apply; ``None`` falls back to ``self.order_by``.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        bind_group: Optional routing group to use for the operation.
        **kwargs: Instance attribute value filters.

    Returns:
        A ``(instances, total_count)`` tuple; the count ignores pagination.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        if bind_group:
            # Build a fresh dict so the caller's mapping is never mutated.
            execution_options = {**(execution_options or {}), "bind_group": bind_group}
        execution_options = self._get_execution_options(execution_options)
        select_stmt = statement if statement is not None else self.statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        select_stmt = self._get_base_stmt(
            statement=select_stmt,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        if order_by is None:
            order_by = [] if self.order_by is None else self.order_by
        select_stmt = self._apply_order_by(statement=select_stmt, order_by=order_by)
        select_stmt = self._apply_filters(*filters, statement=select_stmt)
        select_stmt = self._filter_select_by_kwargs(select_stmt, kwargs)
        windowed_stmt = select_stmt.add_columns(over(sql_func.count()))
        result = await self._execute(windowed_stmt, uniquify=loader_options_have_wildcard)
        instances: list[ModelT] = []
        total = 0
        for instance, row_total in result:
            self._expunge(instance, auto_expunge=auto_expunge)
            if not instances:
                # Every row carries the same window total; keep the first one.
                total = row_total
            instances.append(instance)
        return instances, total
async def _list_and_count_basic(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> tuple[list[ModelT], int]:
    """List records and their total count using two queries.

    A ``COUNT`` query runs first. When it reports zero matching rows, the record
    query is skipped and ``([], 0)`` is returned.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        order_by: Ordering to apply; ``None`` falls back to ``self.order_by``.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        bind_group: Optional routing group to use for the operation.
        **kwargs: Instance attribute value filters.

    Returns:
        A ``(instances, total_count)`` tuple; the count ignores pagination.
    """
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        if bind_group:
            # Build a fresh dict so the caller's mapping is never mutated.
            execution_options = {**(execution_options or {}), "bind_group": bind_group}
        execution_options = self._get_execution_options(execution_options)
        select_stmt = statement if statement is not None else self.statement
        loader_options, loader_options_have_wildcard = self._get_loader_options(load)
        select_stmt = self._get_base_stmt(
            statement=select_stmt,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        if order_by is None:
            order_by = [] if self.order_by is None else self.order_by
        select_stmt = self._apply_order_by(statement=select_stmt, order_by=order_by)
        select_stmt = self._apply_filters(*filters, statement=select_stmt)
        select_stmt = self._filter_select_by_kwargs(select_stmt, kwargs)
        count_stmt = self._get_count_stmt(
            select_stmt,
            loader_options=loader_options,
            execution_options=execution_options,
        )
        total = (await self.session.execute(count_stmt)).scalar_one()
        if total == 0:
            return [], 0
        rows = await self._execute(select_stmt, uniquify=loader_options_have_wildcard)
        instances: list[ModelT] = []
        for (instance,) in rows:
            self._expunge(instance, auto_expunge=auto_expunge)
            instances.append(instance)
        return instances, total
@staticmethod
def _get_count_stmt(
    statement: Select[tuple[ModelT]],
    loader_options: Optional[list[_AbstractLoad]],  # noqa: ARG004
    execution_options: Optional[dict[str, Any]],  # noqa: ARG004
) -> Select[tuple[int]]:
    """Turn a SELECT into a ``COUNT(1)`` query over the same FROM/WHERE clauses.

    LIMIT, OFFSET and ORDER BY are cleared, so the count ignores pagination and
    ordering. ``loader_options`` and ``execution_options`` are accepted only for
    signature compatibility and are not used.
    """
    count_only = statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True)
    unpaginated = count_only.limit(None).offset(None)
    return unpaginated.order_by(None)
[docs]
async def upsert(
    self,
    data: ModelT,
    *,
    attribute_names: Optional[Iterable[str]] = None,
    with_for_update: ForUpdateParameter = None,
    auto_expunge: Optional[bool] = None,
    auto_commit: Optional[bool] = None,
    auto_refresh: Optional[bool] = None,
    match_fields: Optional[Union[list[str], str]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
) -> ModelT:
    """Modify or create instance.

    Updates instance with the attribute values present on `data`, or creates a new instance if
    one doesn't exist.

    The existing row is looked up by, in order of preference:

    1. the non-``None`` values of ``data`` for the resolved match fields;
    2. otherwise the value of ``self.id_attribute`` on ``data``, when it is not ``None``;
    3. otherwise all fields of ``data`` except the id attribute.

    The lookup is therefore never issued without filters. An unfiltered lookup
    would silently match an arbitrary existing row.

    Args:
        data: Instance to update existing, or be created. Identifier used to determine if an
            existing instance exists is the value of an attribute on `data` named as value of
            `self.id_attribute`.
        attribute_names: an iterable of attribute names to pass into the ``update`` method.
        with_for_update: indicating FOR UPDATE should be used, or may be a
            dictionary containing flags to indicate a more specific set of
            FOR UPDATE flags for the SELECT
        auto_expunge: Remove object from session before returning.
        auto_refresh: Refresh object from session before returning.
        auto_commit: Commit objects before returning.
        match_fields: a list of keys to use to match the existing model. When
            empty, all fields are matched.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group for multi-master configurations.

    Returns:
        The updated or created instance.
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    match_filter: dict[str, Any] = {}
    if resolved_match_fields := self._get_match_fields(match_fields=match_fields):
        match_filter = {
            field_name: getattr(data, field_name, None)
            for field_name in resolved_match_fields
            if getattr(data, field_name, None) is not None
        }
    if not match_filter:
        # Either no match fields were configured, or none of them carry a value on
        # ``data``. Fall back to the id (or to all non-id fields) instead of an
        # unfiltered lookup.
        item_id = getattr(data, self.id_attribute, None)
        if item_id is not None:
            match_filter = {self.id_attribute: item_id}
        else:
            match_filter = data.to_dict(exclude={self.id_attribute})
    existing = await self.get_one_or_none(
        load=load, execution_options=execution_options, bind_group=bind_group, **match_filter
    )
    if existing is None:
        return await self.add(
            data,
            auto_commit=auto_commit,
            auto_expunge=auto_expunge,
            auto_refresh=auto_refresh,
            bind_group=bind_group,
        )
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        for field_name, new_field_value in data.to_dict(exclude={self.id_attribute}).items():
            field = getattr(existing, field_name, MISSING)
            # Only assign changed values so unchanged attributes are not marked dirty.
            if field is not MISSING and not compare_values(field, new_field_value):  # pragma: no cover
                setattr(existing, field_name, new_field_value)
        instance = await self._attach_to_session(existing, strategy="merge")
        await self._flush_or_commit(auto_commit=auto_commit)
        await self._refresh(
            instance,
            attribute_names=attribute_names,
            with_for_update=with_for_update,
            auto_refresh=auto_refresh,
        )
        self._expunge(instance, auto_expunge=auto_expunge)
        # Queue cache invalidation (processed on commit)
        self._queue_cache_invalidation(self.get_id_attribute_value(instance), bind_group)
        return instance
[docs]
async def upsert_many(
    self,
    data: list[ModelT],
    *,
    auto_expunge: Optional[bool] = None,
    auto_commit: Optional[bool] = None,
    no_merge: bool = False,
    match_fields: Optional[Union[list[str], str]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    bind_group: Optional[str] = None,
) -> list[ModelT]:
    """Modify or create multiple instances.

    Update instances with the attribute values present on `data`, or create a new instance if
    one doesn't exist.

    !!! tip
        In most cases, you will want to set `match_fields` to the combination of attributes, excluding the primary key, that define uniqueness for a row.

    Existing rows are found with one query. For every match field, it filters on the
    non-``None`` values found in ``data``, and the filters are AND-ed. If some match
    field has no such values, no row can match and the query is skipped. Ids of
    matching existing rows are copied onto the incoming items. Items whose id then
    belongs to an existing row are updated via ``update_many``; all others are
    inserted via ``add_many``.

    Args:
        data: Instance to update existing, or be created. Identifier used to determine if an
            existing instance exists is the value of an attribute on ``data`` named as value of
            :attr:`id_attribute`.
        auto_expunge: Remove object from session before returning.
        auto_commit: Commit objects before returning.
        no_merge: Skip the usage of optimized Merge statements. Not referenced by this
            implementation; kept for API compatibility.
        match_fields: a list of keys to use to match the existing model. When
            empty, automatically uses ``self.id_attribute`` (`id` by default) to match.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set default relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        bind_group: Optional routing group to use for the operation.

    Returns:
        The updated or created instances (inserted ones first, then updated ones).
    """
    self._uniquify = self._get_uniquify(uniquify)
    error_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    instances: list[ModelT] = []
    data_to_update: list[ModelT] = []
    data_to_insert: list[ModelT] = []
    # Fall back to the id attribute for both ``None`` and an empty list, as documented.
    # An empty list must never turn into an unfiltered query over the whole table.
    resolved_match_fields = self._get_match_fields(match_fields=match_fields) or [self.id_attribute]
    match_filter: list[Union[StatementFilter, ColumnElement[bool]]] = []
    can_match_existing = True
    for field_name in resolved_match_fields:
        field = get_instrumented_attr(self.model_type, field_name)
        matched_values = [
            field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
        ]
        if not matched_values:
            # The filters are AND-ed, so a field without candidate values can never
            # match an existing row: there is no need to query the database.
            can_match_existing = False
            break
        # Use field.in_() if types are incompatible with ANY() or if dialect doesn't prefer ANY()
        use_in = not self._prefer_any or self._type_must_use_in_instead_of_any(matched_values, field.type)
        match_filter.append(field.in_(matched_values) if use_in else any_(matched_values) == field)  # type: ignore[arg-type]
    with wrap_sqlalchemy_exception(
        error_messages=error_messages, dialect_name=self._dialect.name, wrap_exceptions=self.wrap_exceptions
    ):
        existing_objs: list[ModelT] = []
        if can_match_existing:
            existing_objs = await self.list(
                *match_filter,
                load=load,
                execution_options=execution_options,
                auto_expunge=False,
                bind_group=bind_group,
            )
        existing_ids = self._get_object_ids(existing_objs=existing_objs)
        # Copy ids of matching existing rows onto the incoming items.
        data = self._merge_on_match_fields(data, existing_objs, resolved_match_fields)
        for datum in data:
            if getattr(datum, self.id_attribute, None) in existing_ids:
                data_to_update.append(datum)
            else:
                data_to_insert.append(datum)
        if data_to_insert:
            instances.extend(
                await self.add_many(data_to_insert, auto_commit=False, auto_expunge=False, bind_group=bind_group),
            )
        if data_to_update:
            instances.extend(
                await self.update_many(
                    data_to_update,
                    auto_commit=False,
                    auto_expunge=False,
                    load=load,
                    execution_options=execution_options,
                    bind_group=bind_group,
                ),
            )
        await self._flush_or_commit(auto_commit=auto_commit)
        for instance in instances:
            self._expunge(instance, auto_expunge=auto_expunge)
    return instances
def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
    """Return the non-``None`` ``id_attribute`` values of ``existing_objs``, in input order."""
    id_attr = self.id_attribute
    ids: list[Any] = []
    for obj in existing_objs:
        obj_id = getattr(obj, id_attr)
        if obj_id is not None:
            ids.append(obj_id)
    return ids
def _get_match_fields(
    self,
    match_fields: Optional[Union[list[str], str]] = None,
    id_attribute: Optional[str] = None,  # noqa: ARG002
) -> Optional[list[str]]:
    """Resolve the field names used to match incoming data against existing rows.

    Args:
        match_fields: Explicit field name(s). A falsy value (``None``, ``""`` or
            ``[]``) falls back to the repository's ``match_fields`` setting.
        id_attribute: Accepted for backward compatibility. It does not affect the result.

    Returns:
        A list of field names (a single string is wrapped in a list), or ``None``
        when neither the argument nor the repository setting provides any.
    """
    resolved = match_fields or self.match_fields
    if isinstance(resolved, str):
        return [resolved]
    return resolved
def _merge_on_match_fields(
    self,
    data: list[ModelT],
    existing_data: list[ModelT],
    match_fields: Optional[Union[list[str], str]] = None,
) -> list[ModelT]:
    """Copy ids from existing rows onto incoming items that match them.

    An incoming item matches an existing row when every match field compares
    equal. The existing row's ``id_attribute`` value is copied only when it is
    not ``None``. If several existing rows match one item, the last one wins.
    Items are modified in place, and the same list object is returned.
    """
    fields = self._get_match_fields(match_fields=match_fields)
    if fields is None:
        fields = [self.id_attribute]
    id_attr = self.id_attribute
    for existing_datum in existing_data:
        for datum in data:
            is_match = all(getattr(datum, name) == getattr(existing_datum, name) for name in fields)
            if not is_match:
                continue
            existing_id = getattr(existing_datum, id_attr)
            if existing_id is not None:
                setattr(datum, id_attr, existing_id)
    return data
[docs]
async def list(
    self,
    *filters: Union[StatementFilter, ColumnElement[bool]],
    auto_expunge: Optional[bool] = None,
    statement: Optional[Select[tuple[ModelT]]] = None,
    order_by: Optional[Union[list[OrderingPair], OrderingPair]] = None,
    error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    load: Optional[LoadSpec] = None,
    execution_options: Optional[dict[str, Any]] = None,
    uniquify: Optional[bool] = None,
    use_cache: bool = True,
    bind_group: Optional[str] = None,
    **kwargs: Any,
) -> list[ModelT]:
    """Get a list of instances, optionally filtered.

    The cache manager is consulted only when all of these hold: ``use_cache`` is
    true, the resolved auto-expunge setting is truthy, a cache manager is configured,
    no custom ``statement`` or ``load`` is given, and there are no default loader
    options. In every other case, and whenever no cache key can be built, the query
    goes straight to the database.

    Args:
        *filters: Types for specific filtering operations.
        auto_expunge: Remove object from session before returning.
        statement: To facilitate customization of the underlying select query.
        order_by: Set default order options for queries.
        error_messages: An optional dictionary of templates to use
            for friendlier error messages to clients
        load: Set relationships to be loaded
        execution_options: Set default execution options
        uniquify: Optionally apply the ``unique()`` method to results before returning.
        use_cache: Whether to use the cache for this query. Defaults to ``True``.
        bind_group: Optional routing group to use for the operation.
        **kwargs: Instance attribute value filters.

    Returns:
        The list of instances, after filtering applied.
    """
    self._uniquify = self._get_uniquify(uniquify)
    resolved_messages = self._get_error_messages(
        error_messages=error_messages,
        default_messages=self.error_messages,
    )
    expunge = self.auto_expunge if auto_expunge is None else auto_expunge
    resolved_options = self._get_execution_options(execution_options)
    if order_by is not None:
        effective_order_by = order_by
    elif self.order_by is not None:
        effective_order_by = self.order_by
    else:
        effective_order_by = []

    # The database path and the cached-creator path take exactly the same arguments.
    query_args: dict[str, Any] = {
        "filters": filters,
        "auto_expunge": auto_expunge,
        "statement": statement,
        "order_by": order_by,
        "error_messages": resolved_messages,
        "load": load,
        "execution_options": execution_options,
        "kwargs": kwargs,
        "uniquify": uniquify,
        "bind_group": bind_group,
    }

    manager = self._cache_manager
    if not (
        use_cache
        and bool(expunge)
        and manager is not None
        and statement is None
        and load is None
        and not self._default_loader_options
    ):
        return await self._list_from_db(**query_args)

    model_name = cast("str", self.model_type.__tablename__)  # type: ignore[attr-defined]
    version_token = await manager.get_model_version_async(model_name)
    cache_key = _build_list_cache_key(
        model_name=model_name,
        version_token=version_token,
        method="list",
        filters=filters,
        kwargs=kwargs,
        order_by=effective_order_by,
        execution_options=resolved_options,
        uniquify=self._uniquify,
    )
    if cache_key is None:
        # No cache key could be built for these arguments; query the database directly.
        return await self._list_from_db(**query_args)

    cached = await manager.get_list_async(cache_key, self.model_type)
    if cached is not None:
        return cached
    creator = partial(self._list_cached_creator, cache_key, **query_args)
    return await manager.singleflight_async(cache_key, creator)
[docs]
@classmethod
async def check_health(cls, session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> bool:
    """Perform a health check on the database.

    Executes the dialect-specific probe from ``_get_health_check_statement`` and
    compares its single scalar result with ``1``. Exceptions are passed through
    ``wrap_sqlalchemy_exception``.

    Args:
        session: Session through which the check statement is executed.

    Returns:
        ``True`` if the probe's scalar result equals ``1``.
    """
    with wrap_sqlalchemy_exception():
        probe = cls._get_health_check_statement(session)
        outcome = await session.execute(probe)
        return outcome.scalar_one() == 1  # type: ignore[no-any-return]
@staticmethod
def _get_health_check_statement(session: Union[AsyncSession, async_scoped_session[AsyncSession]]) -> TextClause:
    """Build the trivial probe query used by ``check_health``.

    Oracle gets ``SELECT 1 FROM DUAL``. Every other dialect, and a session with no
    direct ``bind``, gets plain ``SELECT 1``.

    Args:
        session: Session whose ``bind`` dialect selects the query form.

    Returns:
        The probe statement as a ``TextClause``.
    """
    bind = session.bind
    is_oracle = bool(bind) and bind.dialect.name == "oracle"
    return text("SELECT 1 FROM DUAL" if is_oracle else "SELECT 1")
async def _attach_to_session(
self, model: ModelT, strategy: Literal["add", "merge"] = "add", load: bool = True
) -> ModelT:
"""Attach detached instance to the session.
Args:
model: The instance to be attached to the session.
strategy: How the instance should be attached.
- "add": New instance added to session
- "merge": Instance merged with existing, or new one added.
load: Boolean, when False, merge switches into
a "high performance" mode which causes it to forego emitting history
events as well as all database access. This flag is used for
cases such as transferring graphs of objects into a session
from a second level cache, or to transfer just-loaded objects
into the session owned by a worker thread or process
without re-querying the database.
Raises:
ValueError: If `strategy` is not one of the expected values.
Returns:
Instance attached to the session - if `"merge"` strategy, may not be same instance
that was provided.
"""
if strategy == "add":
self.session.add(model)
return model
if strategy == "merge":
return await self.session.merge(model, load=load)
msg = "Unexpected value for `strategy`, must be `'add'` or `'merge'`" # type: ignore[unreachable]
raise ValueError(msg)
async def _execute(
self,
statement: Select[Any],
uniquify: bool = False,
) -> Result[Any]:
result = await self.session.execute(statement)
if uniquify or self._uniquify:
result = result.unique()
return result
[docs]
class SQLAlchemyAsyncSlugRepository(
    SQLAlchemyAsyncRepository[ModelT],
    SQLAlchemyAsyncSlugRepositoryProtocol[ModelT],
):
    """Extends the repository to include slug model features."""

    async def get_by_slug(
        self,
        slug: str,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Optional[ModelT]:
        """Select record by slug value.

        Args:
            slug: The slug value to look up.
            error_messages: Optional error message templates, forwarded to ``get_one_or_none``.
            load: Relationships to load, forwarded to ``get_one_or_none``.
            execution_options: Execution options, forwarded to ``get_one_or_none``.
            uniquify: Forwarded to ``get_one_or_none``.
            bind_group: Optional routing group, forwarded to ``get_one_or_none``.
            **kwargs: Accepted for interface compatibility; not used by this method.

        Returns:
            The model instance or None if not found.
        """
        return await self.get_one_or_none(
            slug=slug,
            load=load,
            execution_options=execution_options,
            error_messages=error_messages,
            uniquify=uniquify,
            bind_group=bind_group,
        )

    async def get_available_slug(
        self,
        value_to_slugify: str,
        **kwargs: Any,
    ) -> str:
        """Get a unique slug for the supplied value.

        The value is first slugified. If that slug is already taken, a random
        4-character suffix (lowercase letters and digits) is appended. Each suffixed
        candidate is also checked for uniqueness, and up to 5 suffixes are tried.
        Override this method to change the default behavior.

        Args:
            value_to_slugify: A string that should be converted to a unique slug.
            **kwargs: Accepted for interface compatibility; not used by this method.

        Returns:
            A slug for the supplied value that is safe for URLs. It is verified as
            unique unless every suffixed candidate was also taken; in that case the
            last candidate tried is returned.
        """
        slug = slugify(value_to_slugify)
        if await self._is_slug_unique(slug):
            return slug
        alphabet = string.ascii_lowercase + string.digits
        max_attempts = 5
        candidate = slug
        for _ in range(max_attempts):
            # Slugs are not secrets, so the non-cryptographic ``random`` module is adequate.
            suffix = "".join(random.choices(alphabet, k=4))  # noqa: S311
            candidate = f"{slug}-{suffix}"
            # Previously the suffixed slug was returned unchecked; verify it as well.
            if await self._is_slug_unique(candidate):
                break
        return candidate

    async def _is_slug_unique(
        self,
        slug: str,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` when ``self.exists`` reports no record with ``slug``.

        Args:
            slug: The slug value to check.
            load: Forwarded to ``exists``.
            execution_options: Forwarded to ``exists``.
            **kwargs: Additional attribute filters forwarded to ``exists``.

        Returns:
            ``True`` if ``exists`` returned ``False``.
        """
        return await self.exists(slug=slug, load=load, execution_options=execution_options, **kwargs) is False
[docs]
class SQLAlchemyAsyncQueryRepository:
    """SQLAlchemy Query Repository.

    A loosely typed helper for running queries that don't fit the normal
    model-centric repository pattern, such as arbitrary ``select`` statements.
    """

    error_messages: Optional[ErrorMessages] = None
    wrap_exceptions: bool = True

    def __init__(
        self,
        *,
        session: Union[AsyncSession, async_scoped_session[AsyncSession]],
        error_messages: Optional[ErrorMessages] = None,
        wrap_exceptions: bool = True,
        **kwargs: Any,
    ) -> None:
        """Repository pattern for SQLAlchemy models.

        Args:
            session: Session managing the unit-of-work for the operation.
            error_messages: A set of error messages to use for operations.
            wrap_exceptions: Passed to ``wrap_sqlalchemy_exception`` to control whether
                exceptions raised during operations are wrapped.
            **kwargs: Additional arguments (ignored).
        """
        self.session = session
        self.error_messages = error_messages
        self.wrap_exceptions = wrap_exceptions
        # Prefer the session's directly configured bind; otherwise let the session resolve one.
        self._dialect = self.session.bind.dialect if self.session.bind is not None else self.session.get_bind().dialect

    @staticmethod
    def _bind_group_options(bind_group: Optional[str]) -> Optional[dict[str, Any]]:
        """Return execution options routing to ``bind_group``, or ``None`` when it is empty."""
        return {"bind_group": bind_group} if bind_group else None

    async def get_one(
        self,
        statement: Select[tuple[Any]],
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Row[Any]:
        """Get instance identified by ``kwargs``.

        Args:
            statement: To facilitate customization of the underlying select query.
            bind_group: The bind group to use for the operation.
            **kwargs: Instance attribute value filters.

        Raises:
            NotFoundError: If the query yields no value.

        Returns:
            The retrieved instance (via ``Result.scalar_one_or_none()``).
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            result = await self.execute(statement, execution_options=self._bind_group_options(bind_group))
            return self.check_not_found(result.scalar_one_or_none())

    async def get_one_or_none(
        self,
        statement: Select[Any],
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Optional[Row[Any]]:
        """Get instance identified by ``kwargs`` or None if not found.

        Args:
            statement: To facilitate customization of the underlying select query.
            bind_group: The bind group to use for the operation.
            **kwargs: Instance attribute value filters.

        Returns:
            The retrieved instance (via ``Result.scalar_one_or_none()``) or None.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            result = await self.execute(statement, execution_options=self._bind_group_options(bind_group))
            # Return the value as-is: ``value or None`` would turn falsy scalars
            # (0, "", False) into None, indistinguishable from "not found".
            return result.scalar_one_or_none()

    async def count(self, statement: Select[Any], bind_group: Optional[str] = None, **kwargs: Any) -> int:
        """Get the count of records returned by a query.

        Args:
            statement: To facilitate customization of the underlying select query.
            bind_group: The bind group to use for the operation.
            **kwargs: Instance attribute value filters.

        Returns:
            Count of records returned by query, ignoring pagination.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            # Filter the original statement before collapsing its columns to COUNT(1),
            # as ``_list_and_count_basic`` does, so ``filter_by`` resolves names
            # against the statement's original entity.
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            results = await self.execute(
                self._get_count_stmt(statement),
                execution_options=self._bind_group_options(bind_group),
            )
            return results.scalar_one()  # type: ignore

    async def list_and_count(
        self,
        statement: Select[Any],
        count_with_window_function: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[list[Row[Any]], int]:
        """List records with total count.

        Args:
            statement: To facilitate customization of the underlying select query.
            count_with_window_function: When truthy, force two queries (a COUNT query
                plus the statement) instead of an analytical window function. Spanner
                dialects always use two queries.
            bind_group: The bind group to use for the operation.
            **kwargs: Instance attribute value filters.

        Returns:
            A tuple of the fetched records and the total record count, ignoring pagination.
        """
        if self._dialect.name in {"spanner", "spanner+spanner"} or count_with_window_function:
            return await self._list_and_count_basic(statement=statement, bind_group=bind_group, **kwargs)
        return await self._list_and_count_window(statement=statement, bind_group=bind_group, **kwargs)

    async def _list_and_count_window(
        self,
        statement: Select[Any],
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[list[Row[Any]], int]:
        """List records with total count using a ``COUNT(1) OVER ()`` column.

        Each result row is unpacked as ``(value, count)``, so the statement is
        expected to select exactly one column or entity.

        Args:
            statement: To facilitate customization of the underlying select query.
            bind_group: The bind group to use for the operation.
            **kwargs: Instance attribute value filters.

        Returns:
            The first column of every row, plus the window count read from the first
            row (0 when no rows are returned).
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = statement.add_columns(over(sql_func.count(text("1"))))
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            result = await self.execute(statement, execution_options=self._bind_group_options(bind_group))
            count: int = 0
            instances: list[Row[Any]] = []
            for i, (instance, count_value) in enumerate(result):
                instances.append(instance)
                if i == 0:
                    count = count_value
            return instances, count

    @staticmethod
    def _get_count_stmt(statement: Select[Any]) -> Select[Any]:
        """Replace the selected columns with ``COUNT(1)``, keeping FROMs and dropping ORDER BY."""
        return statement.with_only_columns(sql_func.count(text("1")), maintain_column_froms=True).order_by(None)  # pyright: ignore[reportUnknownVariable]

    async def _list_and_count_basic(
        self,
        statement: Select[Any],
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[list[Row[Any]], int]:
        """List records with total count using two queries.

        Each result row is unpacked as a single value, so the statement is expected
        to select exactly one column or entity.

        Args:
            statement: To facilitate customization of the underlying select query.
            bind_group: The bind group to use for the operation.
            **kwargs: Instance attribute value filters.

        Returns:
            The first column of every row, plus the count from a separate COUNT query
            (ignoring pagination).
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            execution_options = self._bind_group_options(bind_group)
            count_result = await self.execute(self._get_count_stmt(statement), execution_options=execution_options)
            count = count_result.scalar_one()
            result = await self.execute(statement, execution_options=execution_options)
            instances: list[Row[Any]] = [instance for (instance,) in result]
            return instances, count

    async def list(self, statement: Select[Any], bind_group: Optional[str] = None, **kwargs: Any) -> list[Row[Any]]:
        """Get a list of instances, optionally filtered.

        Args:
            statement: To facilitate customization of the underlying select query.
            bind_group: The bind group to use for the operation.
            **kwargs: Instance attribute value filters.

        Returns:
            The list of result rows, after filtering applied.
        """
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            statement = self._filter_statement_by_kwargs(statement, **kwargs)
            result = await self.execute(statement, execution_options=self._bind_group_options(bind_group))
            return list(result.all())

    def _filter_statement_by_kwargs(
        self,
        statement: Select[Any],
        /,
        **kwargs: Any,
    ) -> Select[Any]:
        """Filter the collection by kwargs.

        Args:
            statement: statement to filter
            **kwargs: key/value pairs such that objects remaining in the statement after filtering
                have the property that their attribute named `key` has value equal to `value`.

        Returns:
            The filtered statement.
        """
        # Honour ``wrap_exceptions`` like every other operation in this class; before,
        # errors here were wrapped even when the caller had disabled wrapping.
        with wrap_sqlalchemy_exception(error_messages=self.error_messages, wrap_exceptions=self.wrap_exceptions):
            return statement.filter_by(**kwargs)

    # the following is all sqlalchemy implementation detail, and shouldn't be directly accessed

    @staticmethod
    def check_not_found(item_or_none: Optional[T]) -> T:
        """Raise :class:`NotFoundError` if ``item_or_none`` is ``None``.

        Args:
            item_or_none: Item to be tested for existence.

        Raises:
            NotFoundError: If ``item_or_none`` is ``None``

        Returns:
            The item, if it exists.
        """
        if item_or_none is None:
            msg = "No item found when one was expected"
            raise NotFoundError(msg)
        return item_or_none

    async def execute(
        self,
        statement: Union[
            ReturningDelete[tuple[Any]], ReturningUpdate[tuple[Any]], Select[tuple[Any]], Update, Delete, Select[Any]
        ],
        execution_options: Optional[dict[str, Any]] = None,
    ) -> Result[Any]:
        """Execute ``statement`` on the session.

        Args:
            statement: The statement to execute.
            execution_options: Execution options; ``None`` is sent as an empty dict.

        Returns:
            The result returned by ``session.execute``.
        """
        return await self.session.execute(statement, execution_options=execution_options or {})