Source code for axio.models
"""Transport-agnostic model types: Capability, ModelSpec, ModelRegistry."""
from __future__ import annotations
from collections.abc import ItemsView, Iterable, Iterator, KeysView, MutableMapping, ValuesView
from dataclasses import dataclass
from enum import StrEnum
[docs]
class Capability(StrEnum):
# Input modalities
text = "text"
vision = "vision"
video = "video"
audio = "audio"
# Output modalities
image_generation = "image_generation"
video_generation = "video_generation"
# Processing capabilities
reasoning = "reasoning"
tool_use = "tool_use"
json_mode = "json_mode"
structured_outputs = "structured_outputs"
embedding = "embedding"
[docs]
@dataclass(frozen=True, slots=True)
class ModelSpec:
id: str
capabilities: frozenset[Capability] = frozenset()
max_output_tokens: int = 8192
context_window: int = 128000
input_cost: float = 0.0
output_cost: float = 0.0
[docs]
class ModelRegistry(MutableMapping[str, ModelSpec]):
    """Mutable ``id -> ModelSpec`` mapping with filtering and ordering helpers.

    Differences from a plain ``dict`` worth knowing:

    * Iterating the registry (``for spec in registry``) yields the
      :class:`ModelSpec` *values*, not the keys. Use :meth:`keys` or
      :meth:`ids` to get the ids.
    * Membership (``"some-id" in registry``) tests keys, as for a dict.
    * The query helpers (:meth:`by_prefix`, :meth:`by_capability`,
      :meth:`search`, :meth:`by_cost`) return new registries keyed by each
      spec's ``id``.
    """

    __slots__ = ("_models",)

    def __init__(self, models: Iterable[ModelSpec] | None = None) -> None:
        """Create a registry from *models*, keyed by each spec's ``id``.

        A later spec with a duplicate ``id`` replaces an earlier one. Unlike
        item assignment, values are not type-checked here.
        """
        self._models: dict[str, ModelSpec] = {m.id: m for m in (models or [])}

    def __setitem__(self, key: str, value: ModelSpec, /) -> None:
        """Store *value* under *key*.

        Raises:
            ValueError: if *value* is not a ModelSpec.
        """
        # NOTE(review): TypeError would be more idiomatic. ValueError is kept
        # so that existing ``except ValueError`` handlers still match.
        if not isinstance(value, ModelSpec):
            raise ValueError("ModelRegistry values must be ModelSpec instances")
        self._models[key] = value

    def __delitem__(self, key: str, /) -> None:
        del self._models[key]

    def __getitem__(self, key: str, /) -> ModelSpec:
        return self._models[key]

    def __contains__(self, key: object, /) -> bool:
        # Key membership, matching the inherited Mapping behavior, but without
        # raising and catching KeyError on a miss.
        return key in self._models

    def __len__(self) -> int:
        return len(self._models)

    def __iter__(self) -> Iterator[ModelSpec]:  # type: ignore[override]
        # Deliberately yields specs, not ids. Because of this, the mixin
        # methods that expect iter(self) to produce keys (popitem, update)
        # are overridden below.
        return iter(self._models.values())

    def __eq__(self, other: object) -> bool:
        """Compare contents with another registry or with a plain ``dict``."""
        if isinstance(other, ModelRegistry):
            return self._models == other._models
        if isinstance(other, dict):
            return self._models == other
        return NotImplemented

    def __repr__(self) -> str:
        return f"ModelRegistry({self._models!r})"

    def clear(self) -> None:
        """Remove every entry."""
        self._models.clear()

    def popitem(self) -> tuple[str, ModelSpec]:
        """Remove and return the most recently inserted ``(id, spec)`` pair.

        The inherited mixin uses ``next(iter(self))`` as the key. Here that
        yields a ModelSpec, so the mixin version always raised KeyError.

        Raises:
            KeyError: if the registry is empty.
        """
        return self._models.popitem()

    def update(self, other: object = (), /, **kwargs: ModelSpec) -> None:
        """Insert entries from *other* and *kwargs*, validating each value.

        Accepts the same inputs as ``MutableMapping.update``:

        * a mapping of id -> ModelSpec,
        * an iterable of ``(id, spec)`` pairs,
        * keyword arguments.

        It also accepts another :class:`ModelRegistry`, whose entries are
        merged under their existing keys.
        """
        if isinstance(other, ModelRegistry):
            # The mixin iterates a Mapping to get its keys, but iterating a
            # registry yields specs. Feed it (key, value) pairs instead.
            other = other.items()
        super().update(other, **kwargs)  # type: ignore[call-overload]

    def keys(self) -> KeysView[str]:
        """Return a live view of the ids."""
        return self._models.keys()

    def values(self) -> ValuesView[ModelSpec]:
        """Return a live view of the specs."""
        return self._models.values()

    def items(self) -> ItemsView[str, ModelSpec]:
        """Return a live view of ``(id, spec)`` pairs."""
        return self._models.items()

    def by_prefix(self, prefix: str) -> ModelRegistry:
        """Return a new registry of entries whose key starts with *prefix*."""
        return ModelRegistry(v for k, v in self._models.items() if k.startswith(prefix))

    def by_capability(self, *caps: Capability) -> ModelRegistry:
        """Return a new registry of specs that support *all* of *caps*.

        With no arguments, every spec matches.
        """
        required = frozenset(caps)
        return ModelRegistry(v for v in self._models.values() if required <= v.capabilities)

    def search(self, *q: str) -> ModelRegistry:
        """Search by parts of the id.

        If a single argument is given and it is an exact key, only that model
        is returned. Otherwise the result holds every model whose key
        contains *all* of the given substrings (case-sensitive). With no
        arguments, every model matches.
        """
        if len(q) == 1 and q[0] in self._models:
            return ModelRegistry([self._models[q[0]]])
        return ModelRegistry(v for k, v in self._models.items() if all(part in k for part in q))

    def by_cost(self, *, output: bool = False, desc: bool = False) -> ModelRegistry:
        """Return registry ordered by cost (input by default, output if *output=True*).

        The sort is stable, so models with equal cost keep their relative
        insertion order. The returned registry's insertion order is the
        sorted order.
        """
        attr = "output_cost" if output else "input_cost"
        items = sorted(self._models.values(), key=lambda v: getattr(v, attr), reverse=desc)
        return ModelRegistry(items)

    def ids(self) -> list[str]:
        """Return the ids as a list, in insertion order."""
        return list(self._models)

    def _get_model_by_index(self, index: int) -> ModelSpec:
        """Return the spec at position *index* in insertion order.

        Indices 0 and -1 are served in O(1) without copying the values.

        Raises:
            IndexError: if the registry is empty or *index* is out of range.
        """
        if not self._models:
            raise IndexError("ModelRegistry is empty")
        specs = self._models.values()
        if index == 0:
            return next(iter(specs))
        if index == -1:
            return next(reversed(specs))
        return list(specs)[index]

    def first(self) -> ModelSpec:
        """Return the first-inserted spec; raise IndexError if empty."""
        return self._get_model_by_index(0)

    def last(self) -> ModelSpec:
        """Return the last-inserted spec; raise IndexError if empty."""
        return self._get_model_by_index(-1)