Source code for axio.tool_args

"""Incremental streaming parser for tool call JSON arguments.

Feeds partial JSON chunks (from ``ToolInputDelta.partial_json``) and emits
structured ``ToolField*`` events as top-level object fields are discovered.

Top-level *string* values are decoded (escape sequences resolved, quotes
stripped).  All other top-level values are emitted as raw JSON fragments.
"""

from __future__ import annotations

from collections.abc import Mapping
from enum import IntEnum
from types import MappingProxyType

from .events import ToolFieldDelta, ToolFieldEnd, ToolFieldStart

type ToolFieldEvent = ToolFieldStart | ToolFieldDelta | ToolFieldEnd


class State(IntEnum):
    """Parser states of the tool-argument state machine."""

    # Waiting for the opening "{" of the top-level object.
    INIT = 0
    # Inside the object, waiting for a key's opening quote or the closing "}".
    OBJ = 1
    # Reading the characters of a key.
    KEY = 2
    # Key finished, waiting for ":".
    COLON = 3
    # Waiting for the first character of a value.
    VAL = 4
    # Reading a top-level string value (decoded on the fly).
    STR = 5
    # Reading any other value, forwarded as raw JSON text.
    RAW = 6
    # Value finished, waiting for "," or the closing "}".
    AFTER = 7
    # Just saw a backslash inside a key or a top-level string value.
    ESC = 8
    # Collecting the four hex digits that follow a backslash-u.
    UESC = 9
# Read-only table mapping the character that follows a backslash in a JSON
# string to the character it stands for.  The backslash-u form is not listed
# here; the parser handles it separately.
ESCAPES: Mapping[str, str] = MappingProxyType(
    {
        # Control characters.
        "n": "\n",
        "t": "\t",
        "r": "\r",
        "b": "\b",
        "f": "\f",
        # Characters that escape to themselves.
        '"': '"',
        "\\": "\\",
        "/": "/",
    }
)
[docs] class ToolArgStream: """O(1)-per-character streaming parser for tool argument JSON. Usage:: stream = ToolArgStream("call_1") events = stream.feed('{"path":"/tmp/f') # [ToolFieldStart(0, "call_1", "path"), ToolFieldDelta(0, "call_1", "path", "/tmp/f")] events = stream.feed('oo.py"}') # [ToolFieldDelta(0, "call_1", "path", "oo.py"), ToolFieldEnd(0, "call_1", "path")] """ __slots__ = ( "_id", "_idx", "_st", "_key_chars", "_key", "_buf", "_u", "_high", "_depth", "_raw_str", "_raw_esc", "_esc_key", "_esc_ret", "_events", "_done", ) def __init__(self, tool_use_id: str, index: int = 0) -> None: self._id = tool_use_id self._idx = index self._st = State.INIT self._key_chars: list[str] = [] self._key = "" self._buf: list[str] = [] self._u: list[str] = [] self._high = 0 self._depth = 0 self._raw_str = False self._raw_esc = False self._esc_key = False self._esc_ret = State.KEY self._events: list[ToolFieldEvent] = [] self._done = False @property def current_key(self) -> str: """The field currently being streamed, or ``""``.""" return self._key @property def done(self) -> bool: """Whether the top-level JSON object has been fully parsed.""" return self._done
[docs] def feed(self, chunk: str) -> list[ToolFieldEvent]: """Process a partial JSON chunk and return any field events produced.""" self._events = [] for ch in chunk: self._step(ch) self._flush() return self._events
def _flush(self) -> None: if self._buf: self._events.append(ToolFieldDelta(self._idx, self._id, self._key, "".join(self._buf))) self._buf.clear() def _start(self) -> None: self._flush() self._events.append(ToolFieldStart(self._idx, self._id, self._key)) def _end(self) -> None: self._flush() self._events.append(ToolFieldEnd(self._idx, self._id, self._key)) def _step(self, ch: str) -> None: # noqa: PLR0912 match self._st: case State.INIT: if ch == "{": self._st = State.OBJ case State.OBJ: if ch == '"': self._key_chars.clear() self._st = State.KEY elif ch == "}": self._done = True self._st = State.INIT case State.KEY: if ch == "\\": self._esc_key = True self._esc_ret = State.KEY self._st = State.ESC elif ch == '"': self._key = "".join(self._key_chars) self._st = State.COLON else: self._key_chars.append(ch) case State.COLON: if ch == ":": self._start() self._st = State.VAL case State.VAL: if ch in " \t\r\n": pass elif ch == '"': self._st = State.STR else: self._buf.append(ch) self._depth = 1 if ch in "{[" else 0 self._raw_str = False self._raw_esc = False self._st = State.RAW case State.STR: if ch == "\\": self._esc_key = False self._esc_ret = State.STR self._st = State.ESC elif ch == '"': if self._high: self._buf.append("\ufffd") self._high = 0 self._end() self._st = State.AFTER else: if self._high: self._buf.append("\ufffd") self._high = 0 self._buf.append(ch) case State.RAW: if self._raw_str: self._buf.append(ch) if self._raw_esc: self._raw_esc = False elif ch == "\\": self._raw_esc = True elif ch == '"': self._raw_str = False elif self._depth == 0 and ch in " \t\r\n,}": # simple value (number/bool/null) ends on whitespace or delimiter self._end() self._st = State.AFTER if ch in ",}": self._step(ch) # reprocess delimiter elif ch == '"': self._buf.append(ch) self._raw_str = True elif ch in "{[": self._buf.append(ch) self._depth += 1 elif ch in "}]": self._buf.append(ch) self._depth -= 1 if self._depth == 0: self._end() self._st = State.AFTER else: 
self._buf.append(ch) case State.AFTER: if ch == ",": self._st = State.OBJ elif ch == "}": self._done = True self._st = State.INIT case State.ESC: if ch == "u": self._u.clear() self._st = State.UESC else: dec = ESCAPES.get(ch, ch) if self._esc_key: self._key_chars.append(dec) else: self._buf.append(dec) self._st = self._esc_ret case State.UESC: self._u.append(ch) if len(self._u) == 4: code = int("".join(self._u), 16) if self._esc_key: self._key_chars.append(chr(code)) elif self._high: if 0xDC00 <= code <= 0xDFFF: full = 0x10000 + (self._high - 0xD800) * 0x400 + (code - 0xDC00) self._buf.append(chr(full)) else: self._buf.append("\ufffd") self._buf.append(chr(code)) self._high = 0 elif 0xD800 <= code <= 0xDBFF: self._high = code else: self._buf.append(chr(code)) self._st = self._esc_ret