# Source code for axio.models

"""Transport-agnostic model types: Capability, ModelSpec, ModelRegistry."""

from __future__ import annotations

from collections.abc import ItemsView, Iterable, Iterator, KeysView, MutableMapping, ValuesView
from dataclasses import dataclass
from enum import StrEnum


[docs] class Capability(StrEnum): # Input modalities text = "text" vision = "vision" video = "video" audio = "audio" # Output modalities image_generation = "image_generation" video_generation = "video_generation" # Processing capabilities reasoning = "reasoning" tool_use = "tool_use" json_mode = "json_mode" structured_outputs = "structured_outputs" embedding = "embedding"
[docs] @dataclass(frozen=True, slots=True) class ModelSpec: id: str capabilities: frozenset[Capability] = frozenset() max_output_tokens: int = 8192 context_window: int = 128000 input_cost: float = 0.0 output_cost: float = 0.0
class ModelRegistry(MutableMapping[str, ModelSpec]):
    """Mutable ``str -> ModelSpec`` mapping with query helpers.

    Unlike a plain ``dict``, iterating a registry yields its ``ModelSpec``
    values, not its keys; use :meth:`keys` or :meth:`ids` to get the keys.
    The query helpers (:meth:`by_prefix`, :meth:`by_capability`,
    :meth:`search`, :meth:`by_cost`) return new registries built from the
    matching specs. Entries in those results are keyed by each spec's ``id``.
    """

    __slots__ = ("_models",)

    def __init__(self, models: Iterable[ModelSpec] | None = None) -> None:
        """Build a registry keyed by each spec's ``id``.

        If two specs share an id, the later one wins.
        """
        self._models: dict[str, ModelSpec] = {m.id: m for m in (models or [])}

    def __setitem__(self, key: str, value: ModelSpec, /) -> None:
        """Store *value* under *key*.

        Raises:
            ValueError: if *value* is not a ``ModelSpec``.
        """
        if not isinstance(value, ModelSpec):
            raise ValueError("ModelRegistry values must be ModelSpec instances")
        self._models[key] = value

    def __delitem__(self, key: str, /) -> None:
        del self._models[key]

    def __getitem__(self, key: str, /) -> ModelSpec:
        return self._models[key]

    def __len__(self) -> int:
        return len(self._models)

    def __iter__(self) -> Iterator[ModelSpec]:  # type: ignore[override]
        # Deliberately yields specs rather than keys. The Mapping mixins that
        # assume key iteration (keys/values/items/popitem/update) are
        # overridden below.
        return iter(self._models.values())

    def __eq__(self, other: object) -> bool:
        """Compare contents with another registry or a plain ``dict``."""
        if isinstance(other, ModelRegistry):
            return self._models == other._models
        if isinstance(other, dict):
            return self._models == other
        return NotImplemented

    def __repr__(self) -> str:
        return f"ModelRegistry({self._models!r})"

    def clear(self) -> None:
        """Remove every entry."""
        self._models.clear()

    # The inherited KeysView/ValuesView/ItemsView would iterate ``self``
    # (which yields specs), so expose the underlying dict's views directly.
    def keys(self) -> KeysView[str]:
        return self._models.keys()

    def values(self) -> ValuesView[ModelSpec]:
        return self._models.values()

    def items(self) -> ItemsView[str, ModelSpec]:
        return self._models.items()

    def popitem(self) -> tuple[str, ModelSpec]:
        """Remove and return the first-inserted ``(key, spec)`` pair.

        Raises:
            KeyError: if the registry is empty.
        """
        # The inherited MutableMapping.popitem indexes ``self`` by
        # ``next(iter(self))``. Here that is a spec, not a key, so it always
        # raised KeyError. Take the first key from the backing dict instead.
        try:
            key = next(iter(self._models))
        except StopIteration:
            raise KeyError("popitem(): ModelRegistry is empty") from None
        return key, self._models.pop(key)

    def update(self, other: object = (), /, **kwds: ModelSpec) -> None:
        """Insert entries from a mapping or ``(key, spec)`` pairs, then *kwds*.

        Every value goes through ``__setitem__``, so non-``ModelSpec`` values
        are rejected with ``ValueError``.
        """
        # The inherited update() iterates a Mapping argument and indexes it by
        # each yielded item. Another ModelRegistry yields specs, not keys, so
        # that failed with KeyError. Feed its (key, spec) pairs instead.
        if isinstance(other, ModelRegistry):
            other = other.items()
        super().update(other, **kwds)  # type: ignore[arg-type]

    def by_prefix(self, prefix: str) -> ModelRegistry:
        """Return the models whose key starts with *prefix*."""
        return ModelRegistry(v for k, v in self._models.items() if k.startswith(prefix))

    def by_capability(self, *caps: Capability) -> ModelRegistry:
        """Return the models whose capabilities include every one of *caps*.

        With no *caps*, every model matches.
        """
        required = frozenset(caps)
        return ModelRegistry(v for v in self._models.values() if required <= v.capabilities)

    def search(self, *q: str) -> ModelRegistry:
        """Search by parts of the id.

        A single argument that exactly equals a key returns only that model.
        Otherwise, the result holds every model whose key contains all of the
        given substrings. With no arguments, every model matches.
        """
        if len(q) == 1 and q[0] in self._models:
            return ModelRegistry([self._models[q[0]]])
        return ModelRegistry(v for k, v in self._models.items() if all(part in k for part in q))

    def by_cost(self, *, output: bool = False, desc: bool = False) -> ModelRegistry:
        """Return registry ordered by cost (input by default, output if *output=True*).

        The sort is stable, so models with equal cost keep their current order.
        """
        attr = "output_cost" if output else "input_cost"
        items = sorted(self._models.values(), key=lambda v: getattr(v, attr), reverse=desc)
        return ModelRegistry(items)

    def ids(self) -> list[str]:
        """Return the keys as a list, in insertion order."""
        return list(self._models)

    def _get_model_by_index(self, index: int) -> ModelSpec:
        """Return the spec at *index* in insertion order.

        Raises:
            IndexError: if the registry is empty or *index* is out of range.
        """
        if not self._models:
            raise IndexError("ModelRegistry is empty")
        return list(self._models.values())[index]

    def first(self) -> ModelSpec:
        """Return the first-inserted spec; raise IndexError when empty."""
        return self._get_model_by_index(0)

    def last(self) -> ModelSpec:
        """Return the last-inserted spec; raise IndexError when empty."""
        return self._get_model_by_index(-1)