"""Dialect-aware ``Vector`` column type.
Provides a single user-facing :class:`Vector` class that resolves to the most
appropriate backend representation per dialect:
- Oracle 23ai → ``sqlalchemy.dialects.oracle.VECTOR(dim, storage_format)``
- PostgreSQL / CockroachDB → ``pgvector.sqlalchemy.Vector(dim)`` when
``pgvector`` is importable, otherwise the cross-dialect JSON fallback.
- All other dialects → ``sqlalchemy.types.JSON`` round-trip as a JSON array.
The pattern mirrors :class:`advanced_alchemy.types.guid.GUID`: one
``TypeDecorator`` with a single constructor, and ``load_dialect_impl`` selects
the backend impl at DDL/compile time. Backend libraries are imported lazily
inside ``load_dialect_impl`` so ``from advanced_alchemy.types import Vector``
never requires ``pgvector`` or ``oracledb`` to be installed.
"""
from typing import TYPE_CHECKING, Any, Optional, Union

from sqlalchemy import Float, literal
from sqlalchemy.engine import Dialect
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.types import JSON, TypeDecorator, TypeEngine

if TYPE_CHECKING:
    from collections.abc import Sequence

    from sqlalchemy.sql.compiler import SQLCompiler
__all__ = ("Vector",)
# Metric name -> infix distance operator emitted between the two operands when
# a vector distance is compiled for PostgreSQL / CockroachDB.
_PGVECTOR_OPERATORS = {"cosine": "<=>", "l2": "<->", "l1": "<+>", "inner_product": "<#>"}
# Metric name -> keyword emitted as the third argument of ``VECTOR_DISTANCE(...)``
# when a vector distance is compiled for Oracle.
_ORACLE_METRICS = {"cosine": "COSINE", "l2": "EUCLIDEAN", "l1": "MANHATTAN", "inner_product": "DOT"}
class Vector(TypeDecorator[list[float]]):
    """Dialect-aware fixed-dimension vector column.

    The concrete column type is selected per dialect by
    :meth:`load_dialect_impl`: Oracle uses its native ``VECTOR`` type,
    PostgreSQL / CockroachDB use ``pgvector``'s type when that package is
    importable, and every other case (including PostgreSQL without
    ``pgvector``) stores the vector as a JSON array.

    Args:
        dim: The vector dimension. Required by every supported backend.
        storage_format: Oracle 23ai storage format name (matched
            case-insensitively against
            :class:`sqlalchemy.dialects.oracle.VectorStorageFormat`).
            Defaults to ``"FLOAT32"``. Ignored on non-Oracle backends.
    """

    impl = JSON
    cache_ok = True

    @property
    def python_type(self) -> type[list[float]]:
        """Python type of loaded values: a plain ``list``."""
        return list

    def __init__(self, dim: int, *, storage_format: str = "FLOAT32") -> None:
        """Store the dimension and Oracle storage format for later impl selection."""
        super().__init__()
        self.dim = dim
        self.storage_format = storage_format

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
        """Return the backend column type to use for ``dialect``.

        Backend-specific modules are imported here, inside the branch that
        needs them, so that merely importing this module does not require them.

        Args:
            dialect: The dialect the type is being resolved for.

        Returns:
            The dialect's descriptor for Oracle ``VECTOR``, pgvector ``Vector``
            or ``JSON``.

        Raises:
            KeyError: On Oracle, if ``storage_format`` does not name a
                ``VectorStorageFormat`` member.
        """
        if dialect.name == "oracle":
            from sqlalchemy.dialects.oracle import VECTOR, VectorStorageFormat

            # Enum members are looked up by (upper-case) name; normalise the
            # configured value so that e.g. "float32" resolves as well.
            fmt = VectorStorageFormat[self.storage_format.upper()]
            return dialect.type_descriptor(VECTOR(dim=self.dim, storage_format=fmt))  # type: ignore[no-untyped-call]
        if dialect.name in {"postgresql", "cockroachdb"}:
            try:
                from pgvector.sqlalchemy import Vector as PgVector  # pyright: ignore[reportMissingTypeStubs]
            except ImportError:
                # pgvector is optional: without it the column degrades to JSON.
                return dialect.type_descriptor(JSON())
            return dialect.type_descriptor(PgVector(self.dim))
        return dialect.type_descriptor(JSON())

    def process_result_value(self, value: Any, dialect: Dialect) -> Optional[list[float]]:
        """Normalise a loaded value to a plain ``list`` (``None`` stays ``None``).

        Values exposing ``tolist()`` (array-like objects) are converted through
        it; anything else is passed to ``list()``.
        """
        if value is None:
            return None
        if hasattr(value, "tolist"):
            return list(value.tolist())
        return list(value)

    class Comparator(TypeDecorator.Comparator[list[float]]):
        """Dialect-aware vector distance operators for similarity search.

        Method names and semantics mirror ``pgvector``'s SQLAlchemy comparator so
        existing pgvector-based queries can swap to :class:`Vector` unchanged.
        Each method accepts either a concrete sequence of numbers, which is
        bound as a parameter of this column's type, or an existing SQL
        expression (another column, an ORM attribute, a ``bindparam()``), which
        is used as-is.
        """

        def _distance(self, other: Any, metric: str) -> "_VectorDistance":
            """Build the distance expression between this column and ``other``.

            Args:
                other: A sequence of numbers, an array-like exposing
                    ``tolist()``, or a SQL expression.
                metric: Metric key (``"cosine"``, ``"l2"``, ``"l1"`` or
                    ``"inner_product"``) resolved to SQL at compile time.
            """
            if isinstance(other, ColumnElement):
                # Already SQL: wrapping it in a literal would try to iterate it.
                operand: Any = other
            elif hasattr(other, "__clause_element__"):
                # Objects that wrap a SQL expression (e.g. ORM attributes)
                # expose it through this hook.
                operand = other.__clause_element__()
            else:
                # Same normalisation as process_result_value so array-likes
                # are bound as plain Python lists.
                values = other.tolist() if hasattr(other, "tolist") else list(other)
                operand = literal(values, type_=self.type)
            return _VectorDistance(self.expr, operand, metric)

        def cosine_distance(self, other: "Union[Sequence[float], ColumnElement[Any]]") -> "_VectorDistance":
            """Return the cosine distance between this column and ``other``."""
            return self._distance(other, "cosine")

        def l2_distance(self, other: "Union[Sequence[float], ColumnElement[Any]]") -> "_VectorDistance":
            """Return the L2 (Euclidean) distance between this column and ``other``."""
            return self._distance(other, "l2")

        def l1_distance(self, other: "Union[Sequence[float], ColumnElement[Any]]") -> "_VectorDistance":
            """Return the L1 (Manhattan) distance between this column and ``other``."""
            return self._distance(other, "l1")

        def max_inner_product(self, other: "Union[Sequence[float], ColumnElement[Any]]") -> "_VectorDistance":
            """Return the inner-product distance between this column and ``other``.

            Compiles to pgvector ``<#>`` or Oracle ``DOT``.
            """
            return self._distance(other, "inner_product")

    comparator_factory = Comparator  # pyright: ignore[reportIncompatibleMethodOverride,reportAssignmentType]
class _VectorDistance(ColumnElement[float]):
    """Vector distance expression whose SQL is decided per dialect.

    Rendering is deferred to compile time, so one expression object yields the
    native form on each backend:

    - PostgreSQL / CockroachDB → ``<=>`` / ``<->`` / ``<+>`` / ``<#>`` (pgvector)
    - Oracle 23ai → ``VECTOR_DISTANCE(col, :vec, COSINE | EUCLIDEAN | MANHATTAN | DOT)``
    - other dialects → :exc:`NotImplementedError` (no native vector backend)

    ``max_inner_product`` (pgvector ``<#>``, Oracle ``DOT``) returns the *negated*
    inner product on both backends, so ascending order yields the nearest match.

    ``_traverse_internals`` lists both operands and the metric so that
    statements containing a vector distance can take part in SQLAlchemy's
    compiled-statement cache, keyed on those three fields.

    Args:
        left: Left-hand operand (the vector column expression).
        right: Right-hand operand (the vector being compared against).
        metric: Metric key used to pick the operator / keyword at compile time.
    """

    _traverse_internals = [
        ("left", InternalTraversal.dp_clauseelement),
        ("right", InternalTraversal.dp_clauseelement),
        ("metric", InternalTraversal.dp_string),
    ]

    def __init__(self, left: "ColumnElement[Any]", right: "ColumnElement[Any]", metric: str) -> None:
        # The expression evaluates to a floating-point distance.
        self.type = Float()
        self.left, self.right = left, right
        self.metric = metric
@compiles(_VectorDistance)
def compile_vector_distance(element: _VectorDistance, compiler: "SQLCompiler", **kw: Any) -> str:
    """Default compiler hook: reject dialects without a dedicated implementation.

    Raises:
        NotImplementedError: Always; only the dialect-specific hooks registered
            for this element produce SQL.
    """
    raise NotImplementedError(
        "Vector distance operations require a native vector backend "
        "(PostgreSQL with pgvector, or Oracle 23ai). The current dialect stores "
        "vectors as JSON and has no distance operator."
    )
@compiles(_VectorDistance, "postgresql")
@compiles(_VectorDistance, "cockroachdb")
def compile_vector_distance_pgvector(element: _VectorDistance, compiler: "SQLCompiler", **kw: Any) -> str:
    """Render the distance as a parenthesised pgvector infix operator expression."""
    # Resolve the operator first so an unknown metric fails before any operand
    # is compiled.
    symbol = _PGVECTOR_OPERATORS[element.metric]
    lhs = compiler.process(element.left, **kw)
    rhs = compiler.process(element.right, **kw)
    return "(" + " ".join((lhs, symbol, rhs)) + ")"
@compiles(_VectorDistance, "oracle")
def compile_vector_distance_oracle(element: _VectorDistance, compiler: "SQLCompiler", **kw: Any) -> str:
    """Render the distance as an Oracle ``VECTOR_DISTANCE(left, right, METRIC)`` call."""
    # Resolve the metric keyword first so an unknown metric fails before any
    # operand is compiled.
    keyword = _ORACLE_METRICS[element.metric]
    lhs = compiler.process(element.left, **kw)
    rhs = compiler.process(element.right, **kw)
    return f"VECTOR_DISTANCE({', '.join((lhs, rhs, keyword))})"