Source code for axio.tool

"""Tool: frozen dataclass binding a handler callable to a name, guards, and concurrency."""

from __future__ import annotations

import asyncio
import copy
import inspect
import logging
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from dataclasses import replace as dc_replace
from types import MappingProxyType
from typing import Any, get_type_hints

from .exceptions import GuardError, HandlerError
from .field import MISSING, FieldInfo, get_field_info
from .permission import PermissionGuard
from .schema import build_tool_schema
from .types import ToolName

type JSONSchema = dict[str, Any]

logger = logging.getLogger(__name__)

# Maps JSON Schema primitive type names to Python types used for validation.
SCHEMA_JSON_TYPE_MAP: dict[str, type] = {
    "string": str,
    "integer": int,
    "number": float,
    "boolean": bool,
    "array": list,
    "object": dict,
}


[docs] def hint_from_json_schema(prop_schema: dict[str, Any]) -> Any: """Return the Python type hint for a single JSON Schema property definition.""" t = prop_schema.get("type") if t is not None: return SCHEMA_JSON_TYPE_MAP.get(t, object) any_of = prop_schema.get("anyOf") if any_of is not None: non_null = [s for s in any_of if s.get("type") != "null"] has_null = len(non_null) < len(any_of) if len(non_null) == 1: inner = hint_from_json_schema(non_null[0]) return (inner | None) if has_null else inner return object
# Set to the tool's ``context`` value before each handler invocation. # Handlers that cannot receive context as a parameter retrieve it via ``CONTEXT.get()``. CONTEXT: ContextVar[Any] = ContextVar("CONTEXT") def _default_format_stream_result(chunks: list[tuple[float, str, str]]) -> str: """Default streaming-aggregator: join all text, discard keys/timestamps.""" return "".join(text for _, _, text in chunks)
[docs] @dataclass(frozen=True, slots=True) class Tool[T]: name: ToolName handler: Callable[..., Awaitable[Any]] description: str = "" guards: tuple[PermissionGuard, ...] = () concurrency: int | None = None context: T = field(default=MappingProxyType({}), compare=False) # type: ignore[assignment] schema: MappingProxyType[str, Any] = field(default=MappingProxyType({}), repr=False, compare=False) _semaphore: asyncio.Semaphore | None = field(init=False, default=None, repr=False, compare=False) _fields: Mapping[str, tuple[Any, FieldInfo]] = field( init=False, repr=False, compare=False, default_factory=lambda: MappingProxyType({}) ) _accepts_var_kwargs: bool = field(init=False, default=False, repr=False, compare=False) _schema_explicit: bool = field(init=False, default=False, repr=False, compare=False) def __post_init__(self) -> None: if not inspect.iscoroutinefunction(self.handler): raise TypeError( f"Tool {self.name!r} handler {self.handler!r} must be an async function (coroutinefunction)." ) if not self.description: object.__setattr__(self, "description", self.handler.__doc__ or "") hints = get_type_hints(self.handler, include_extras=True) param_hints = {k: v for k, v in hints.items() if k != "return"} try: sig = inspect.signature(self.handler) except (ValueError, TypeError): sig = None fields: dict[str, tuple[Any, FieldInfo]] = {} for name, hint in param_hints.items(): fi = get_field_info(hint) or FieldInfo() if sig is not None and name in sig.parameters: param = sig.parameters[name] if param.default is not inspect.Parameter.empty and fi.default is MISSING: # Merge sig default into FieldInfo (covers StrictStr and plain defaults). fi = dc_replace(fi, default=param.default) fields[name] = (hint, fi) param_fields = MappingProxyType(fields) accepts_var_kwargs = sig is not None and any( p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() ) schema_explicit = bool(self.schema) if not self.schema: object.__setattr__(self, "schema", MappingProxyType(build_tool_schema(self.handler, hints=param_hints))) # For handlers with an explicit schema: synthesise _fields from schema properties # so type validation, default injection, and kwarg filtering all use the schema. if schema_explicit: schema_props: dict[str, Any] = dict(self.schema).get("properties") or {} schema_fields: dict[str, tuple[Any, FieldInfo]] = {} for prop_name, prop_schema in schema_props.items(): hint = hint_from_json_schema(prop_schema) default = prop_schema.get("default", MISSING) schema_fields[prop_name] = (hint, FieldInfo(default=default)) if schema_fields: param_fields = MappingProxyType(schema_fields) object.__setattr__(self, "_fields", param_fields) object.__setattr__(self, "_accepts_var_kwargs", accepts_var_kwargs) object.__setattr__(self, "_schema_explicit", schema_explicit) if self.concurrency is not None: object.__setattr__(self, "_semaphore", asyncio.Semaphore(self.concurrency)) @asynccontextmanager async def _acquire(self) -> AsyncGenerator[None, None]: if self._semaphore is None: yield return async with self._semaphore: yield @property def input_schema(self) -> JSONSchema: return copy.deepcopy(dict(self.schema)) @property def supports_streaming(self) -> bool: """Handler supports streaming if it exposes a ``.stream`` async-generator attribute.""" return callable(getattr(self.handler, "stream", None))
[docs] def format_stream_result(self, chunks: list[tuple[float, str, str]]) -> str: """Aggregate streamed chunks into the final tool result string. Handlers may attach a ``format_stream_result`` callable for structured output (e.g. shell log records). Defaults to text concatenation. """ fn = getattr(self.handler, "format_stream_result", None) if callable(fn): result = fn(chunks) return result if isinstance(result, str) else str(result) return _default_format_stream_result(chunks)
def _prepare_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: """Inject defaults, validate types, and strip extras per schema.""" required_set: set[str] = set(self.schema.get("required", [])) for name, (hint, fi) in self._fields.items(): if name not in kwargs: if fi.default is not MISSING and name not in required_set: kwargs[name] = fi.default else: fi.validate(kwargs[name], name, hint) missing = [name for name in required_set if name not in kwargs] if missing: raise HandlerError(f"Missing required field(s): {', '.join(missing)}") if self._schema_explicit: schema_props = self.schema.get("properties") if schema_props is not None: kwargs = {k: v for k, v in kwargs.items() if k in schema_props} elif not self._accepts_var_kwargs: kwargs = {k: v for k, v in kwargs.items() if k in self._fields} return kwargs async def __call__(self, **kwargs: Any) -> Any: async with self._acquire(): try: kwargs = self._prepare_kwargs(kwargs) except HandlerError: raise except Exception as exc: raise HandlerError(str(exc)) from exc for guard in self.guards: try: kwargs = await guard(self, **kwargs) except GuardError: raise except Exception as exc: raise GuardError(str(exc)) from exc try: if self._schema_explicit: schema_props = self.schema.get("properties") if schema_props is not None: kwargs = {k: v for k, v in kwargs.items() if k in schema_props} elif not self._accepts_var_kwargs: kwargs = {k: v for k, v in kwargs.items() if k in self._fields} token = CONTEXT.set(self.context) try: return await self.handler(**kwargs) finally: CONTEXT.reset(token) except HandlerError: raise except Exception as exc: raise HandlerError(str(exc)) from exc
[docs] async def call_streaming(self, **kwargs: Any) -> AsyncGenerator[tuple[str, str], None]: """Execute handler, yielding ``(key, text)`` chunks for streaming output. Uses ``handler.stream(**kwargs)`` if the handler exposes one. Otherwise falls back to ``__call__()`` and yields the full result as a single ``("output", ...)`` chunk. Semaphore is held for the entire iteration. """ async with self._acquire(): try: kwargs = self._prepare_kwargs(kwargs) except HandlerError: raise except Exception as exc: raise HandlerError(str(exc)) from exc for guard in self.guards: try: kwargs = await guard(self, **kwargs) except GuardError: raise except Exception as exc: raise GuardError(str(exc)) from exc if self._schema_explicit: schema_props = self.schema.get("properties") if schema_props is not None: kwargs = {k: v for k, v in kwargs.items() if k in schema_props} elif not self._accepts_var_kwargs: kwargs = {k: v for k, v in kwargs.items() if k in self._fields} stream_fn = getattr(self.handler, "stream", None) token = CONTEXT.set(self.context) try: if callable(stream_fn): async for chunk in stream_fn(**kwargs): yield chunk else: result = await self.handler(**kwargs) if isinstance(result, str): yield ("output", result) else: yield ("output", str(result)) except HandlerError: raise except Exception as exc: raise HandlerError(str(exc)) from exc finally: CONTEXT.reset(token)