Source code for advanced_alchemy.types.totp

"""TOTP shared-secret column type with pyotp-backed helpers."""

from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

from advanced_alchemy.exceptions import MissingDependencyError
from advanced_alchemy.types.encrypted_string import EncryptedString, EncryptionBackend, FernetBackend
from advanced_alchemy.typing import PYOTP_INSTALLED

if TYPE_CHECKING:
    from sqlalchemy.engine import Dialect

__all__ = ("TOTPProvider", "TOTPSecret", "generate_totp_secret")


[docs] def generate_totp_secret(length: int = 32) -> str: """Generate a new random base32 TOTP secret. Args: length: Number of base32 characters. Defaults to 32. Raises: MissingDependencyError: If the ``pyotp`` package is not installed. Returns: A base32-encoded secret suitable for an authenticator app. """ if not PYOTP_INSTALLED: raise MissingDependencyError(package="pyotp") import pyotp return pyotp.random_base32(length=length)
[docs] class TOTPProvider: """Read-side wrapper exposing pyotp operations over a decrypted secret. The secret must be a valid base32 string. pyotp decodes it lazily, so a malformed secret raises :exc:`binascii.Error` on the first call that decodes it (``now``/``verify``). """ __slots__ = ("_digest", "_digits", "_interval", "_issuer", "_secret")
[docs] def __init__( self, secret: str, *, digits: int = 6, interval: int = 30, digest: "Optional[Any]" = None, issuer: "Optional[str]" = None, ) -> None: """Initialize the provider. Args: secret: The decrypted base32 secret. digits: Number of digits in generated codes. Defaults to 6. interval: Time step in seconds. Defaults to 30. digest: Hash algorithm passed to pyotp (None selects SHA1). issuer: Default issuer used for provisioning URIs. Raises: MissingDependencyError: If the ``pyotp`` package is not installed. """ if not PYOTP_INSTALLED: raise MissingDependencyError(package="pyotp") self._secret = secret self._digits = digits self._interval = interval self._digest = digest self._issuer = issuer
@property def secret(self) -> str: """The decrypted base32 secret.""" return self._secret def _totp(self) -> Any: import pyotp return pyotp.TOTP( self._secret, digits=self._digits, digest=self._digest, interval=self._interval, issuer=self._issuer, )
[docs] def now(self) -> str: """Return the current TOTP code. Returns: The zero-padded code for the current time step. """ return cast("str", self._totp().now())
[docs] def verify(self, code: str, valid_window: int = 1) -> bool: """Verify a code, tolerating ``valid_window`` ticks of clock drift each way. Args: code: The candidate code. valid_window: Number of time steps of drift accepted on each side. Defaults to 1. Returns: True if the code is valid within the window, False otherwise. """ return cast("bool", self._totp().verify(code, valid_window=valid_window))
[docs] def provisioning_uri(self, name: "Optional[str]" = None, issuer_name: "Optional[str]" = None) -> str: """Return an ``otpauth://`` provisioning URI for QR enrollment. Args: name: The account name (e.g. the user's email). issuer_name: The issuer label; falls back to the type's configured issuer. Returns: The provisioning URI string. """ return cast("str", self._totp().provisioning_uri(name=name, issuer_name=issuer_name or self._issuer))
[docs] class TOTPSecret(EncryptedString): """Stores a TOTP shared secret encrypted at rest, returning a :class:`TOTPProvider` on read. Storage is identical to :class:`EncryptedString`; the read side wraps the decrypted secret in a :class:`TOTPProvider` so callers can generate and verify codes. Requires ``pyotp``. """ cache_ok = True
[docs] def __init__( self, key: "Union[str, bytes, Callable[[], Union[str, bytes]]]", backend: "type[EncryptionBackend]" = FernetBackend, digits: int = 6, interval: int = 30, digest: "Optional[Any]" = None, issuer: "Optional[str]" = None, length: "Optional[int]" = None, **kwargs: Any, ) -> None: """Initialize the TOTPSecret type. Args: key: The encryption key (str/bytes or a callable returning one). Required: unlike :class:`EncryptedString`, this new type has no deprecated random default. backend: The encryption backend class. Defaults to FernetBackend. digits: Number of digits in generated codes. Defaults to 6. interval: Time step in seconds. Defaults to 30. digest: Hash algorithm passed to pyotp (None selects SHA1). issuer: Default issuer used for provisioning URIs. length: The unencrypted string length (documentation only). **kwargs: Additional arguments passed to the underlying String type. Raises: MissingDependencyError: If the ``pyotp`` package is not installed. """ if not PYOTP_INSTALLED: raise MissingDependencyError(package="pyotp") super().__init__(key=key, backend=backend, length=length, **kwargs) self.digits = digits self.interval = interval self.digest = digest self.issuer = issuer
[docs] def __repr__(self) -> str: """Return a reconstructable representation including the TOTP parameters. Overrides :class:`EncryptedString` so Alembic autogenerate reconstructs the type with its configured ``digits``/``interval``/``digest``/``issuer`` rather than the defaults. """ key_repr = self.key.__name__ if callable(self.key) else repr(self.key) return ( f"{type(self).__name__}(key={key_repr}, backend={self.backend.__class__.__name__}, " f"digits={self.digits}, interval={self.interval}, digest={self.digest!r}, " f"issuer={self.issuer!r}, length={self.length})" )
[docs] def process_bind_param(self, value: Any, dialect: "Dialect") -> "Union[str, None]": """Accept a raw secret string or a :class:`TOTPProvider`, then encrypt as usual. Args: value: The secret string or a TOTPProvider. dialect: The SQLAlchemy dialect. Returns: The encrypted secret, or None. """ if isinstance(value, TOTPProvider): value = value.secret return super().process_bind_param(value, dialect)
[docs] def process_result_value(self, value: Any, dialect: "Dialect") -> "Optional[TOTPProvider]": # type: ignore[override] """Decrypt the stored secret and wrap it in a :class:`TOTPProvider`. Args: value: The stored encrypted value. dialect: The SQLAlchemy dialect. Returns: A TOTPProvider over the decrypted secret, or None. """ decrypted = super().process_result_value(value, dialect) if decrypted is None: return None return TOTPProvider( decrypted, digits=self.digits, interval=self.interval, digest=self.digest, issuer=self.issuer, )