# Do not edit this file directly. It has been autogenerated from
# advanced_alchemy/repository/memory/_async.py
import datetime
import random
import re
import string
from collections import abc
from collections.abc import Iterable
from typing import Any, List, Optional, Union, cast, overload
from unittest.mock import create_autospec
from sqlalchemy import (
ColumnElement,
Dialect,
Select,
StatementLambdaElement,
Update,
)
from sqlalchemy.orm import InstrumentedAttribute, Session, class_mapper
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.strategy_options import _AbstractLoad # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql.dml import ReturningUpdate
from sqlalchemy.sql.selectable import ForUpdateParameter
from typing_extensions import Self
from advanced_alchemy.exceptions import ErrorMessages, IntegrityError, NotFoundError, RepositoryError
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
LimitOffset,
NotInCollectionFilter,
NotInSearchFilter,
OnBeforeAfter,
OrderBy,
SearchFilter,
StatementFilter,
)
from advanced_alchemy.repository._sync import SQLAlchemySyncRepositoryProtocol, SQLAlchemySyncSlugRepositoryProtocol
from advanced_alchemy.repository._util import (
DEFAULT_ERROR_MESSAGE_TEMPLATES,
LoadSpec,
compare_values,
extract_pk_value_from_instance,
is_composite_pk,
normalize_pk_to_tuple,
pk_values_present,
)
from advanced_alchemy.repository.memory.base import (
AnyObject,
InMemoryStore,
SQLAlchemyInMemoryStore,
SQLAlchemyMultiStore,
)
from advanced_alchemy.repository.typing import MISSING, ModelT, OrderingPair, PrimaryKeyType
from advanced_alchemy.utils.dataclass import Empty, EmptyType
from advanced_alchemy.utils.deprecation import warn_deprecation
from advanced_alchemy.utils.text import slugify
class SQLAlchemySyncMockRepository(SQLAlchemySyncRepositoryProtocol[ModelT]):
    """In memory repository.

    Implements the synchronous repository protocol on top of an in-memory store, so code
    that depends on a repository can be exercised without a database. Objects live in the
    class-level ``__database__`` store. That store is shared by all instances, and by
    subclasses that do not override it. Most SQLAlchemy-specific arguments have no effect
    here and are accepted only for interface compatibility: loader/execution options, row
    locking, commit/expunge/refresh flags and bind groups.
    """

    __database__: SQLAlchemyMultiStore[ModelT] = SQLAlchemyMultiStore(SQLAlchemyInMemoryStore)
    __database_registry__: dict[type[Self], SQLAlchemyMultiStore[ModelT]] = {}
    loader_options: Optional[LoadSpec] = None
    """Default loader options for the repository."""
    execution_options: Optional[dict[str, Any]] = None
    """Default execution options for the repository."""
    model_type: type[ModelT]
    id_attribute: Any = "id"
    match_fields: Optional[Union[List[str], str]] = None
    uniquify: bool = False
    # Keyword arguments that configure a repository call rather than describe model
    # fields; they are stripped before kwargs are used as equality filters.
    _exclude_kwargs: set[str] = {
        "statement",
        "session",
        "auto_expunge",
        "auto_refresh",
        "auto_commit",
        "attribute_names",
        "with_for_update",
        "count_with_window_function",
        "loader_options",
        "execution_options",
        "order_by",
        "load",
        "error_messages",
        "wrap_exceptions",
        "uniquify",
        "bind_group",
    }

    def __init__(
        self,
        *,
        statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
        session: Union[Session, scoped_session[Session]],
        auto_expunge: bool = False,
        auto_refresh: bool = True,
        auto_commit: bool = False,
        order_by: Optional[Union[List[OrderingPair], OrderingPair]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        wrap_exceptions: bool = True,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the mock repository.

        Args:
            statement: Ignored; a mock ``Select`` is created instead.
            session: Stored on ``self.session``; the in-memory operations do not use it.
            auto_expunge: Stored on the instance.
            auto_refresh: Stored on the instance.
            auto_commit: Stored on the instance.
            order_by: Stored on the instance.
            error_messages: Overrides merged over the default error message templates.
            wrap_exceptions: Stored on the instance.
            load: Ignored.
            execution_options: Ignored.
            uniquify: Stored on the instance, coerced to ``bool``.
            **kwargs: Ignored.
        """
        self.session = session
        self.statement = create_autospec("Select[Tuple[ModelT]]", instance=True)
        self.auto_expunge = auto_expunge
        self.auto_refresh = auto_refresh
        self.auto_commit = auto_commit
        # NOTE(review): the ``error_messages`` class default is not declared in this class;
        # presumably it is inherited from the repository protocol base. Confirm.
        self.error_messages = self._get_error_messages(
            error_messages=error_messages, default_messages=self.error_messages
        )
        self.wrap_exceptions = wrap_exceptions
        self.order_by = order_by
        self._dialect: Dialect = create_autospec(Dialect, instance=True)
        self._dialect.name = "mock"
        self.__filtered_store__: InMemoryStore[ModelT] = self.__database__.store_type()
        self._default_options: Any = []
        self._default_execution_options: Any = {}
        self._loader_options: Any = []
        self._loader_options_have_wildcards = False
        self.uniquify = bool(uniquify)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register each subclass's store so ``__database_clear__`` can reset it.

        Calls ``super().__init_subclass__`` first so the ``Generic``/``Protocol``
        bases still perform their own per-subclass bookkeeping.
        """
        super().__init_subclass__(**kwargs)
        cls.__database_registry__[cls] = cls.__database__  # type: ignore[index]

    @property
    def _pk_columns(self) -> tuple[Any, ...]:
        """Get primary key columns from the model mapper.

        Returns:
            Tuple of Column objects representing the primary key.
        """
        mapper = class_mapper(self.model_type)
        return tuple(mapper.primary_key)

    @property
    def pk_attr_names(self) -> tuple[str, ...]:
        """Get primary key attribute names from the model mapper.

        Uses ``mapper.get_property_by_column()`` to get ORM attribute names. These may
        differ from column names when a column is declared as ``Column("sql_name")``.

        Returns:
            Tuple of ORM attribute names for primary key columns.
        """
        mapper = class_mapper(self.model_type)
        return tuple(mapper.get_property_by_column(col).key for col in self._pk_columns)

    @property
    def has_composite_pk(self) -> bool:
        """Check if model has a composite (multi-column) primary key.

        Returns:
            True if the model has 2 or more primary key columns, False otherwise.
        """
        return is_composite_pk(self._pk_columns)

    def get_primary_key_value(self, instance: ModelT) -> PrimaryKeyType:
        """Extract the primary key value(s) from a model instance.

        Args:
            instance: Model instance to extract primary key from.

        Returns:
            A scalar value for a single-column primary key, or a tuple of values in
            column order for a composite primary key.
        """
        return extract_pk_value_from_instance(instance, self.pk_attr_names)

    def has_primary_key_values(self, instance: ModelT) -> bool:
        """Check if all primary key values are set on an instance.

        Args:
            instance: Model instance to check.

        Returns:
            True if all PK values are non-None, False otherwise.
        """
        return pk_values_present(instance, self.pk_attr_names)

    def _normalize_pk_to_tuple(self, pk_value: PrimaryKeyType) -> tuple[Any, ...]:
        """Normalize a primary key value to a tuple for consistent storage key generation.

        Args:
            pk_value: Primary key value (scalar, tuple, or dict).

        Returns:
            Tuple representation of the primary key.
        """
        return normalize_pk_to_tuple(pk_value, self.pk_attr_names, self.model_type.__name__)

    def _get_store_key(self, pk_value: PrimaryKeyType) -> str:
        """Generate a store key from a primary key value.

        Args:
            pk_value: Primary key value (scalar, tuple, or dict).

        Returns:
            ``str(value)`` for a single-column key, ``str(tuple)`` for a composite key,
            and ``""`` for an empty key.
        """
        pk_tuple = self._normalize_pk_to_tuple(pk_value)
        if len(pk_tuple) > 1:
            return str(pk_tuple)
        return str(pk_tuple[0]) if pk_tuple else ""

    def _get_store_key_from_instance(self, instance: ModelT) -> str:
        """Generate a store key from a model instance.

        Delegates to ``_get_store_key`` so instance-derived keys are always built the
        same way as keys derived from an explicit primary key value.

        Args:
            instance: Model instance to generate key from.

        Returns:
            String key for the in-memory store.
        """
        return self._get_store_key(self.get_primary_key_value(instance))

    @staticmethod
    def _get_error_messages(
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        default_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
    ) -> Optional[ErrorMessages]:
        """Layer error message templates: built-in defaults, then class defaults, then overrides.

        Args:
            error_messages: Per-call/instance overrides; ``Empty`` or ``None`` means none.
            default_messages: Repository-level defaults; ``Empty`` or ``None`` means none.

        Returns:
            A new mapping; the inputs are not mutated.
        """
        if error_messages is Empty:
            error_messages = None
        if default_messages is Empty:
            default_messages = None
        messages = cast("ErrorMessages", dict(DEFAULT_ERROR_MESSAGE_TEMPLATES))
        if default_messages:
            messages.update(cast("ErrorMessages", default_messages))  # type: ignore[unused-ignore,redundant-cast]
        if error_messages:
            messages.update(cast("ErrorMessages", error_messages))  # type: ignore[unused-ignore,redundant-cast]
        return messages

    @classmethod
    def __database_add__(cls, identity: Any, data: ModelT) -> ModelT:
        """Add ``data`` to the class-level store under ``identity`` and return it."""
        return cast("ModelT", cls.__database__.add(identity, data))  # type: ignore[redundant-cast]

    @classmethod
    def __database_clear__(cls) -> None:
        """Call ``remove_all()`` on every store registered by repository subclasses."""
        for database in cls.__database_registry__.values():  # pyright: ignore[reportGeneralTypeIssues,reportUnknownMemberType]
            database.remove_all()

    @overload
    def __collection__(self) -> InMemoryStore[ModelT]: ...

    @overload
    def __collection__(self, identity: type[AnyObject]) -> InMemoryStore[AnyObject]: ...

    def __collection__(
        self,
        identity: Optional[type[AnyObject]] = None,
    ) -> Union[InMemoryStore[AnyObject], InMemoryStore[ModelT]]:
        """Return the store for ``identity``, or this repository's model store.

        When ``identity`` is omitted, ``__filtered_store__`` is used if it is truthy.
        Otherwise the class-level store for ``model_type`` is used.
        """
        if identity:
            return self.__database__.store(identity)
        return self.__filtered_store__ or self.__database__.store(self.model_type)

    @staticmethod
    def check_not_found(item_or_none: Union[ModelT, None]) -> ModelT:
        """Return ``item_or_none`` unchanged, raising if it is ``None``.

        Raises:
            NotFoundError: If ``item_or_none`` is ``None``.
        """
        if item_or_none is None:
            msg = "No item found when one was expected"
            raise NotFoundError(msg)
        return item_or_none

    @classmethod
    def get_id_attribute_value(
        cls,
        item: Union[ModelT, type[ModelT]],
        id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
    ) -> Any:
        """Get value of attribute named as :attr:`id_attribute` on ``item``.

        Args:
            item: Anything that should have an attribute named as :attr:`id_attribute` value.
            id_attribute: Allows customization of the unique identifier to use for model fetching.
                Defaults to `None`, but can reference any surrogate or candidate key for the table.

        Returns:
            The value of attribute on ``item`` named as :attr:`id_attribute`.
        """
        if isinstance(id_attribute, InstrumentedAttribute):
            id_attribute = id_attribute.key
        return getattr(item, id_attribute if id_attribute is not None else cls.id_attribute)

    @classmethod
    def set_id_attribute_value(
        cls,
        item_id: Any,
        item: ModelT,
        id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
    ) -> ModelT:
        """Return the ``item`` after the ID is set to the appropriate attribute.

        Args:
            item_id: Value of ID to be set on instance
            item: Anything that should have an attribute named as :attr:`id_attribute` value.
            id_attribute: Allows customization of the unique identifier to use for model fetching.
                Defaults to `None`, but can reference any surrogate or candidate key for the table.

        Returns:
            Item with ``item_id`` set to :attr:`id_attribute`
        """
        if isinstance(id_attribute, InstrumentedAttribute):
            id_attribute = id_attribute.key
        setattr(item, id_attribute if id_attribute is not None else cls.id_attribute, item_id)
        return item

    def _exclude_unused_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``kwargs`` without the keys listed in ``_exclude_kwargs``."""
        return {key: value for key, value in kwargs.items() if key not in self._exclude_kwargs}

    @staticmethod
    def _apply_limit_offset_pagination(result: List[ModelT], limit: int, offset: int) -> List[ModelT]:
        """Return at most ``limit`` items starting at index ``offset``.

        Mirrors SQL ``LIMIT limit OFFSET offset``: the slice end is ``offset + limit``,
        not ``limit``. The previous ``result[offset:limit]`` returned nothing for any
        page past the first.
        """
        return result[offset : offset + limit]

    def _extract_field_name(self, field: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]") -> str:
        """Extract string field name from various input types.

        Args:
            field: Field name, column element, or instrumented attribute

        Returns:
            str: String field name for use with getattr()

        Raises:
            RepositoryError: If a ColumnElement (func expression) is used with mock repository
        """
        if isinstance(field, str):
            return field
        if isinstance(field, InstrumentedAttribute):
            return field.key
        msg = f"{type(field)} columns are not supported in mock repositories (in-memory filtering)"
        raise RepositoryError(msg)

    def _filter_in_collection(
        self,
        result: List[ModelT],
        field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
        values: abc.Collection[Any],
    ) -> List[ModelT]:
        """Keep items whose ``field_name`` value is in ``values``; empty ``values`` keeps nothing."""
        field_str = self._extract_field_name(field_name)
        return [item for item in result if getattr(item, field_str) in values]

    def _filter_not_in_collection(
        self,
        result: List[ModelT],
        field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
        values: abc.Collection[Any],
    ) -> List[ModelT]:
        """Drop items whose ``field_name`` value is in ``values``; empty ``values`` keeps everything."""
        if not values:
            return result
        field_str = self._extract_field_name(field_name)
        return [item for item in result if getattr(item, field_str) not in values]

    def _filter_on_datetime_field(
        self,
        result: List[ModelT],
        field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
        before: Optional[datetime.datetime] = None,
        after: Optional[datetime.datetime] = None,
        on_or_before: Optional[datetime.datetime] = None,
        on_or_after: Optional[datetime.datetime] = None,
    ) -> List[ModelT]:
        """Keep items whose ``field_name`` value satisfies every bound that is set.

        Bounds are combined with AND, like the chained ``WHERE`` clauses the SQL
        repository emits. Each item is returned at most once, in its original order.
        With no bounds set, all items are kept. Items whose value is ``None`` are
        excluded whenever a bound is set, matching SQL ``NULL`` comparison semantics.

        Args:
            result: Items to filter.
            field_name: Field holding the datetime to compare.
            before: Exclusive upper bound.
            after: Exclusive lower bound.
            on_or_before: Inclusive upper bound.
            on_or_after: Inclusive lower bound.

        Returns:
            The matching items.
        """
        field_str = self._extract_field_name(field_name)
        if before is None and after is None and on_or_before is None and on_or_after is None:
            return list(result)

        def _within_bounds(value: Any) -> bool:
            if value is None:
                return False
            if before is not None and not value < before:
                return False
            if after is not None and not value > after:
                return False
            if on_or_before is not None and not value <= on_or_before:
                return False
            return on_or_after is None or value >= on_or_after

        return [item for item in result if _within_bounds(getattr(item, field_str))]

    @staticmethod
    def _filter_by_like(
        result: List[ModelT],
        field_name: Union[str, set[str]],
        value: str,
        ignore_case: bool,
    ) -> List[ModelT]:
        """Keep items where any of ``field_name``'s string values contains ``value``.

        ``value`` is matched literally; regex metacharacters are escaped. This mirrors
        the SQL ``LIKE '%value%'`` substring search rather than treating the term as a
        pattern. Non-string field values never match. Input order is preserved.
        """
        pattern = re.compile(re.escape(value), re.IGNORECASE if ignore_case else 0)
        fields = {field_name} if isinstance(field_name, str) else field_name

        def _contains(attr: Any) -> bool:
            return isinstance(attr, str) and pattern.search(attr) is not None

        return [item for item in result if any(_contains(getattr(item, field)) for field in fields)]

    @staticmethod
    def _filter_by_not_like(
        result: List[ModelT],
        field_name: Union[str, set[str]],
        value: str,
        ignore_case: bool,
    ) -> List[ModelT]:
        """Keep items where none of ``field_name``'s string values contains ``value``.

        This is the complement of ``_filter_by_like``: ``value`` is matched literally,
        and items whose field values are not strings are kept. Input order is preserved.
        """
        pattern = re.compile(re.escape(value), re.IGNORECASE if ignore_case else 0)
        fields = {field_name} if isinstance(field_name, str) else field_name

        def _contains(attr: Any) -> bool:
            return isinstance(attr, str) and pattern.search(attr) is not None

        return [item for item in result if not any(_contains(getattr(item, field)) for field in fields)]

    def _filter_result_by_kwargs(
        self,
        result: Iterable[ModelT],
        /,
        kwargs: Union[dict[Any, Any], Iterable[tuple[Any, Any]]],
    ) -> List[ModelT]:
        """Keep items whose attributes equal every remaining (non-excluded) keyword value.

        Args:
            result: Items to filter.
            kwargs: A mapping, or an iterable of ``(field, value)`` pairs.

        Returns:
            The matching items, in input order.

        Raises:
            RepositoryError: If an item lacks one of the requested attributes.
        """
        # ``dict(pairs)``, not ``dict(*pairs)``: unpacking passes every pair as a separate
        # positional argument, which ``dict`` rejects.
        kwargs_: dict[Any, Any] = kwargs if isinstance(kwargs, dict) else dict(kwargs)  # pyright: ignore
        kwargs_ = self._exclude_unused_kwargs(kwargs_)  # pyright: ignore
        try:
            return [item for item in result if all(getattr(item, field) == value for field, value in kwargs_.items())]  # pyright: ignore
        except AttributeError as error:
            raise RepositoryError from error

    def _order_by(
        self,
        result: List[ModelT],
        field_name: "Union[str, ColumnElement[Any], InstrumentedAttribute[Any]]",
        sort_desc: bool = False,
    ) -> List[ModelT]:
        """Return ``result`` sorted by ``field_name`` (stable; descending if ``sort_desc``)."""
        field_str = self._extract_field_name(field_name)
        return sorted(result, key=lambda item: getattr(item, field_str), reverse=sort_desc)

    def _apply_filters(
        self,
        result: List[ModelT],
        *filters: Union[StatementFilter, ColumnElement[bool]],
        apply_pagination: bool = True,
    ) -> List[ModelT]:
        """Apply statement filters to ``result`` in the order given.

        Raw ``ColumnElement`` expressions cannot be evaluated in memory and are skipped.
        ``LimitOffset`` is skipped when ``apply_pagination`` is false.

        Raises:
            RepositoryError: For any other unsupported filter type.
        """
        for filter_ in filters:
            # Subclasses (NotInCollectionFilter, NotInSearchFilter) are checked before
            # their presumed bases so the negated variant wins — order matters.
            if isinstance(filter_, LimitOffset):
                if apply_pagination:
                    result = self._apply_limit_offset_pagination(result, filter_.limit, filter_.offset)
            elif isinstance(filter_, BeforeAfter):
                result = self._filter_on_datetime_field(
                    result,
                    field_name=filter_.field_name,
                    before=filter_.before,
                    after=filter_.after,
                )
            elif isinstance(filter_, OnBeforeAfter):
                result = self._filter_on_datetime_field(
                    result,
                    field_name=filter_.field_name,
                    on_or_before=filter_.on_or_before,
                    on_or_after=filter_.on_or_after,
                )
            elif isinstance(filter_, NotInCollectionFilter):
                if filter_.values is not None:  # pyright: ignore
                    result = self._filter_not_in_collection(result, filter_.field_name, filter_.values)  # pyright: ignore
            elif isinstance(filter_, CollectionFilter):
                if filter_.values is not None:  # pyright: ignore
                    result = self._filter_in_collection(result, filter_.field_name, filter_.values)  # pyright: ignore
            elif isinstance(filter_, OrderBy):
                result = self._order_by(
                    result,
                    filter_.field_name,
                    sort_desc=filter_.sort_order == "desc",
                )
            elif isinstance(filter_, NotInSearchFilter):
                result = self._filter_by_not_like(
                    result,
                    filter_.field_name,
                    value=filter_.value,
                    ignore_case=bool(filter_.ignore_case),
                )
            elif isinstance(filter_, SearchFilter):
                result = self._filter_by_like(
                    result,
                    filter_.field_name,
                    value=filter_.value,
                    ignore_case=bool(filter_.ignore_case),
                )
            elif not isinstance(filter_, ColumnElement):
                msg = f"Unexpected filter: {filter_}"
                raise RepositoryError(msg)
        return result

    def _apply_pagination(
        self,
        result: List[ModelT],
        *filters: Union[StatementFilter, ColumnElement[bool]],
    ) -> List[ModelT]:
        """Apply only the ``LimitOffset`` filters in ``filters``, in the order given."""
        for filter_ in filters:
            if isinstance(filter_, LimitOffset):
                result = self._apply_limit_offset_pagination(result, filter_.limit, filter_.offset)
        return result

    def _select_items(
        self,
        filters: "tuple[Union[StatementFilter, ColumnElement[bool]], ...]",
        kwargs: "dict[str, Any]",
        *,
        apply_pagination: bool = True,
    ) -> List[ModelT]:
        """Run the in-memory equivalent of ``SELECT ... WHERE ... ORDER BY ... LIMIT/OFFSET``.

        Statement filters and keyword equality filters run first; pagination runs last.
        In SQL, ``LIMIT``/``OFFSET`` always apply to the already-filtered rows, however
        the clauses were chained.

        Args:
            filters: Statement filters.
            kwargs: Field equality filters (non-field keys are ignored).
            apply_pagination: Whether ``LimitOffset`` filters are applied.

        Returns:
            The selected items.
        """
        result = self._apply_filters(self.__collection__().get_all(), *filters, apply_pagination=False)
        result = self._filter_result_by_kwargs(result, kwargs)
        if apply_pagination:
            result = self._apply_pagination(result, *filters)
        return result

    def _get_match_fields(
        self,
        match_fields: Union[List[str], str, None],
        id_attribute: Optional[str] = None,
    ) -> Optional[List[str]]:
        """Resolve ``match_fields`` (falling back to the class default) to a list or ``None``.

        ``id_attribute`` is accepted for signature compatibility and is not used.
        """
        resolved = match_fields or self.match_fields
        if isinstance(resolved, str):
            return [resolved]
        return resolved

    def _build_match_filter(
        self,
        kwargs_: "dict[str, Any]",
        match_fields: Union[List[str], str, None],
    ) -> "dict[str, Any]":
        """Build the lookup used by ``get_or_upsert``/``get_and_update``.

        Returns:
            The non-``None`` values of the match fields when match fields are configured;
            otherwise ``kwargs_`` itself.
        """
        fields = self._get_match_fields(match_fields=match_fields)
        if not fields:
            return kwargs_
        return {name: kwargs_[name] for name in fields if kwargs_.get(name) is not None}

    @staticmethod
    def _apply_changes(instance: ModelT, values: "dict[str, Any]") -> bool:
        """Set every attribute in ``values`` that exists on ``instance`` and differs.

        Returns:
            ``True`` if at least one attribute was changed.
        """
        changed = False
        for field_name, new_value in values.items():
            current = getattr(instance, field_name, MISSING)
            if current is not MISSING and not compare_values(current, new_value):  # pragma: no cover
                setattr(instance, field_name, new_value)
                changed = True
        return changed

    def _get_many_and_count_basic(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        **kwargs: Any,
    ) -> tuple[List[ModelT], int]:
        """Return the requested page of items and the total (unpaginated) match count.

        The count is computed without ``LimitOffset``. Previously the length of the page
        was returned, so the total was wrong whenever pagination was requested.
        """
        return self.get_many(*filters, **kwargs), self.count(*filters, **kwargs)

    def _get_many_and_count_window(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        **kwargs: Any,
    ) -> tuple[List[ModelT], int]:
        """Same as ``_get_many_and_count_basic``; there are no window functions in memory."""
        return self._get_many_and_count_basic(*filters, **kwargs)

    def _find_or_raise_not_found(self, id_: PrimaryKeyType) -> ModelT:
        """Find an item by primary key or raise NotFoundError.

        Args:
            id_: Primary key value (scalar, tuple, or dict).

        Returns:
            The found model instance.

        Raises:
            NotFoundError: If no instance found with the given primary key.
        """
        store_key = self._get_store_key(id_)
        return self.check_not_found(self.__collection__().get_or_none(store_key))

    @staticmethod
    def _find_one_or_raise_error(result: List[ModelT]) -> ModelT:
        """Return the only element of ``result``.

        Raises:
            IntegrityError: If ``result`` is empty or has more than one element.
        """
        if not result:
            msg = "No item found when one was expected"
            raise IntegrityError(msg)
        if len(result) > 1:
            msg = "Multiple objects when one was expected"
            raise IntegrityError(msg)
        return result[0]  # pyright: ignore

    def _get_update_many_statement(
        self,
        model_type: type[ModelT],
        supports_returning: bool,
        loader_options: Optional[List[_AbstractLoad]],
        execution_options: Optional[dict[str, Any]],
    ) -> Union[Update, ReturningUpdate[tuple[ModelT]]]:
        """Return the mock statement created in ``__init__``; all arguments are ignored."""
        return self.statement  # type: ignore[no-any-return] # pyright: ignore[reportReturnType]

    @classmethod
    def check_health(cls, session: Union[Session, scoped_session[Session]]) -> bool:
        """Always return ``True``; there is no database connection to check."""
        return True

    def get(
        self,
        item_id: PrimaryKeyType,
        *,
        auto_expunge: Optional[bool] = None,
        statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
        id_attribute: Union[str, InstrumentedAttribute[Any], None] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        with_for_update: ForUpdateParameter = None,
        use_cache: bool = True,
        bind_group: Optional[str] = None,
    ) -> ModelT:
        """Return the instance stored under primary key ``item_id``.

        Lookup always uses the primary key; ``id_attribute`` and the other keyword
        arguments are accepted for protocol compatibility and ignored.

        Raises:
            NotFoundError: If no instance has that primary key.
        """
        return self._find_or_raise_not_found(item_id)

    def get_one(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> ModelT:
        """Return the single instance matching ``filters`` and field ``kwargs``.

        Raises:
            NotFoundError: If nothing matches.
            IntegrityError: If more than one instance matches.
        """
        return self.check_not_found(self.get_one_or_none(*filters, **kwargs))

    def get_one_or_none(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_expunge: Optional[bool] = None,
        statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ModelT, None]:
        """Return the single instance matching ``filters`` and field ``kwargs``, or ``None``.

        Raises:
            IntegrityError: If more than one instance matches.
        """
        result = self._select_items(filters, kwargs)
        if len(result) > 1:
            msg = "Multiple objects when one was expected"
            raise IntegrityError(msg)
        return result[0] if result else None

    def get_or_upsert(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        match_fields: Union[List[str], str, None] = None,
        upsert: bool = True,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[ModelT, bool]:
        """Fetch an instance matching the given fields, creating it if missing.

        The lookup uses ``filters`` plus ``match_fields`` (or all field ``kwargs`` when no
        match fields are configured). When found and ``upsert`` is true, differing fields
        are copied from ``kwargs`` and the instance is updated.

        Returns:
            ``(instance, created)``.

        Raises:
            IntegrityError: If the lookup matches more than one instance.
        """
        kwargs_ = self._exclude_unused_kwargs(kwargs)
        match_filter = self._build_match_filter(kwargs_, match_fields)
        existing = self.get_one_or_none(*filters, **match_filter)
        if existing is None:
            return self.add(self.model_type(**kwargs_)), True
        if upsert:
            self._apply_changes(existing, kwargs_)
            existing = self.update(existing)
        return existing, False

    def get_and_update(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        match_fields: Union[List[str], str, None] = None,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[ModelT, bool]:
        """Fetch an existing instance, copy differing fields from ``kwargs`` and update it.

        Returns:
            ``(instance, updated)`` where ``updated`` is ``True`` if any field changed.

        Raises:
            NotFoundError: If nothing matches.
            IntegrityError: If more than one instance matches.
        """
        kwargs_ = self._exclude_unused_kwargs(kwargs)
        match_filter = self._build_match_filter(kwargs_, match_fields)
        existing = self.get_one(*filters, **match_filter)
        updated = self._apply_changes(existing, kwargs_)
        return self.update(existing), updated

    def exists(
        self,
        *filters: "Union[StatementFilter, ColumnElement[bool]]",
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` if at least one instance matches ``filters`` and ``kwargs``."""
        return self.count(*filters, **kwargs) > 0

    def count(
        self,
        *filters: "Union[StatementFilter, ColumnElement[bool]]",
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> int:
        """Count instances matching ``filters`` and field ``kwargs``.

        ``LimitOffset`` filters are ignored. In SQL, ``SELECT count(*)`` counts all
        matching rows, so pagination does not shrink the count.
        """
        return len(self._select_items(filters, kwargs, apply_pagination=False))

    def add(
        self,
        data: ModelT,
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        bind_group: Optional[str] = None,
    ) -> ModelT:
        """Add ``data`` to the store and return it.

        Raises:
            IntegrityError: If the store rejects the item with ``KeyError`` (already present).
        """
        try:
            self.__database__.add(self.model_type, data)
        except KeyError as exc:
            msg = "Item already exist in collection"
            raise IntegrityError(msg) from exc
        return data

    def add_many(
        self,
        data: List[ModelT],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        bind_group: Optional[str] = None,
    ) -> List[ModelT]:
        """Add each item in ``data`` in order and return ``data``.

        Raises:
            IntegrityError: On the first duplicate; earlier items remain added.
        """
        for obj in data:
            self.add(obj)  # pyright: ignore[reportCallIssue]
        return data

    def update(
        self,
        data: ModelT,
        *,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
    ) -> ModelT:
        """Replace the stored instance that has ``data``'s primary key with ``data``.

        Raises:
            NotFoundError: If no instance with that primary key is stored.
        """
        pk_value = self.get_primary_key_value(data)
        self._find_or_raise_not_found(pk_value)
        return self.__collection__().update(data)

    def update_many(
        self,
        data: List[ModelT],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
    ) -> List[ModelT]:
        """Update the items of ``data`` that are already in the store; others are skipped."""
        collection = self.__collection__()
        return [collection.update(obj) for obj in data if obj in collection]

    def delete(
        self,
        item_id: PrimaryKeyType,
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
    ) -> ModelT:
        """Remove and return the instance stored under primary key ``item_id``.

        The lookup happens before removal, so a missing key raises ``NotFoundError``.
        Previously, ``remove`` ran in a ``finally`` even for a missing key, which could
        replace ``NotFoundError`` with a store error.

        Raises:
            NotFoundError: If no instance has that primary key.
        """
        instance = self._find_or_raise_not_found(item_id)
        self.__collection__().remove(self._get_store_key(item_id))
        return instance

    def delete_many(
        self,
        item_ids: List[PrimaryKeyType],
        *,
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        id_attribute: Optional[Union[str, InstrumentedAttribute[Any]]] = None,
        chunk_size: Optional[int] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
    ) -> List[ModelT]:
        """Remove every stored instance whose primary key is in ``item_ids``.

        Unknown ids are skipped silently.

        Returns:
            The removed instances, in ``item_ids`` order.
        """
        collection = self.__collection__()
        deleted: List[ModelT] = []
        for id_ in item_ids:
            store_key = self._get_store_key(id_)
            if obj := collection.get_or_none(store_key):
                deleted.append(obj)
                collection.remove(store_key)
        return deleted

    def delete_where(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        auto_commit: Optional[bool] = None,
        auto_expunge: Optional[bool] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        sanity_check: bool = True,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> List[ModelT]:
        """Remove and return every instance matching ``filters`` and field ``kwargs``."""
        models = self._select_items(filters, kwargs)
        item_ids: List[PrimaryKeyType] = [self.get_primary_key_value(model) for model in models]
        return self.delete_many(item_ids=item_ids)

    def upsert(
        self,
        data: ModelT,
        *,
        attribute_names: Optional[Iterable[str]] = None,
        with_for_update: ForUpdateParameter = None,
        auto_expunge: Optional[bool] = None,
        auto_commit: Optional[bool] = None,
        auto_refresh: Optional[bool] = None,
        match_fields: Optional[Union[List[str], str]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
    ) -> ModelT:
        """Update ``data`` if it is already in the store, otherwise add it."""
        if data in self.__collection__():
            return self.update(data)
        return self.add(data)

    def upsert_many(
        self,
        data: List[ModelT],
        *,
        auto_expunge: Optional[bool] = None,
        auto_commit: Optional[bool] = None,
        no_merge: bool = False,
        match_fields: Optional[Union[List[str], str]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
    ) -> List[ModelT]:
        """Call ``upsert`` on each item of ``data`` and return the results in order."""
        return [self.upsert(item) for item in data]

    def get_many_and_count(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
        auto_expunge: Optional[bool] = None,
        count_with_window_function: Optional[bool] = None,
        order_by: Optional[Union[List[OrderingPair], OrderingPair]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        use_cache: bool = True,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[List[ModelT], int]:
        """Return the requested page of matching instances and the total match count."""
        return self._get_many_and_count_basic(*filters, **kwargs)

    def list_and_count(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        statement: Union[Select[tuple[ModelT]], StatementLambdaElement, None] = None,
        auto_expunge: Optional[bool] = None,
        count_with_window_function: Optional[bool] = None,
        order_by: Optional[Union[List[OrderingPair], OrderingPair]] = None,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        use_cache: bool = True,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> tuple[List[ModelT], int]:
        """List of records and total count returned by query.

        .. deprecated:: 1.10.0
            Use :meth:`get_many_and_count` instead.
        """
        warn_deprecation(
            version="1.10.0",
            deprecated_name="list_and_count",
            kind="method",
            removal_in="2.0.0",
            alternative="get_many_and_count",
        )
        return self.get_many_and_count(*filters, **kwargs)

    def get_many(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        uniquify: Optional[bool] = None,
        use_cache: bool = True,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> List[ModelT]:
        """Return instances matching ``filters`` and field ``kwargs``, paginated last."""
        return self._select_items(filters, kwargs)

    # NOTE: keep ``list`` as the last definition in the class body. After this point
    # the name ``list`` refers to this method, not the builtin, inside class-level annotations.
    def list(
        self,
        *filters: Union[StatementFilter, ColumnElement[bool]],
        uniquify: Optional[bool] = None,
        use_cache: bool = True,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> List[ModelT]:
        """List of records returned by query.

        .. deprecated:: 1.10.0
            Use :meth:`get_many` instead.
        """
        warn_deprecation(
            version="1.10.0",
            deprecated_name="list",
            kind="method",
            removal_in="2.0.0",
            alternative="get_many",
        )
        return self.get_many(*filters, **kwargs)
class SQLAlchemySyncMockSlugRepository(
    SQLAlchemySyncMockRepository[ModelT],
    SQLAlchemySyncSlugRepositoryProtocol[ModelT],
):
    """In-memory mock repository for models that expose a ``slug`` attribute."""

    def get_by_slug(
        self,
        slug: str,
        error_messages: Optional[Union[ErrorMessages, EmptyType]] = Empty,
        load: Optional[LoadSpec] = None,
        execution_options: Optional[dict[str, Any]] = None,
        uniquify: Optional[bool] = None,
        bind_group: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ModelT, None]:
        """Return the instance whose ``slug`` equals ``slug``, or ``None`` if there is none.

        Only ``slug`` is used for the lookup. The remaining arguments are accepted for
        protocol compatibility and ignored.

        Args:
            slug: Slug value to look up.
            error_messages: Ignored.
            load: Ignored.
            execution_options: Ignored.
            uniquify: Ignored.
            bind_group: Ignored.
            **kwargs: Ignored.

        Returns:
            The matching instance or ``None``.
        """
        return self.get_one_or_none(slug=slug)

    def get_available_slug(
        self,
        value_to_slugify: str,
        **kwargs: Any,
    ) -> str:
        """Get a unique slug for the supplied value.

        The value is slugified first. If that slug is already taken, a random suffix of
        4 lowercase letters/digits is appended as ``"{slug}-{suffix}"``. The suffixed
        slug is not re-checked for uniqueness. Override this method to change the
        default behavior.

        Args:
            value_to_slugify (str): A string that should be converted to a unique slug.
            **kwargs: Ignored.

        Returns:
            str: A slug for the supplied value. This is safe for URLs and other unique identifiers.
        """
        slug = slugify(value_to_slugify)
        if self._is_slug_unique(slug):
            return slug
        # Non-cryptographic randomness is fine here: the suffix only needs to avoid collisions.
        random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))  # noqa: S311
        return f"{slug}-{random_string}"

    def _is_slug_unique(
        self,
        slug: str,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` when no stored instance has this ``slug``."""
        return not self.exists(slug=slug)