"""Agent: the core agentic loop orchestrating transport, tools, and context."""
from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import time
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Self
from .blocks import AudioBlock, ContentBlock, ImageBlock, TextBlock, ToolResultBlock, ToolUseBlock, VideoBlock
from .context import ContextStore
from .events import (
AudioOutput,
Error,
ImageOutput,
IterationEnd,
SessionEndEvent,
StreamEvent,
TextDelta,
ToolInputDelta,
ToolOutputDelta,
ToolResult,
ToolUseStart,
VideoOutput,
)
from .messages import Message
from .models import Capability
from .selector import ToolSelector
from .stream import AgentStream
from .tool import Tool
from .transport import CompletionTransport
from .types import StopReason, Usage
# Module-level logger namespaced to this module's import path (standard pattern).
logger = logging.getLogger(__name__)
class _RepetitionDetector:
"""Detects when model output is stuck in a repetitive loop.
Two complementary checks run periodically on accumulated text:
1. **Short-period**: counts trailing consecutive repetitions of
patterns from 1 to ``max_period`` chars. Triggers when repetitions
span >= ``min_repeat_span`` chars. Catches single-token and
short-phrase loops quickly.
2. **Long-period**: checks whether the last ``long_window`` chars
appear verbatim earlier in the output. Catches paragraph-level
repetition that the short-period check would miss.
"""
__slots__ = (
"_parts",
"_total_len",
"_last_check",
"_interval",
"_min_len",
"_max_period",
"_min_repeat_span",
"_long_window",
)
def __init__(
self,
interval: int = 200,
min_len: int = 800,
max_period: int = 150,
min_repeat_span: int = 200,
long_window: int = 500,
) -> None:
self._parts: list[str] = []
self._total_len = 0
self._last_check = 0
self._interval = interval
self._min_len = min_len
self._max_period = max_period
self._min_repeat_span = min_repeat_span
self._long_window = long_window
def feed(self, delta: str) -> bool:
"""Feed a text delta. Returns ``True`` when a loop is detected."""
self._parts.append(delta)
self._total_len += len(delta)
if self._total_len < self._min_len:
return False
if self._total_len - self._last_check < self._interval:
return False
self._last_check = self._total_len
full = "".join(self._parts)
self._parts = [full]
n = len(full)
# --- Short-period: trailing repetition of a small pattern ---
max_p = min(self._max_period, n // 3)
for p in range(1, max_p + 1):
chunk = full[n - p : n]
count = 1
pos = n - 2 * p
while pos >= 0 and full[pos : pos + p] == chunk:
count += 1
pos -= p
if count >= 3 and count * p >= self._min_repeat_span:
return True
# --- Long-period: trailing window found earlier verbatim ---
w = min(self._long_window, n // 2)
if w >= self._min_repeat_span:
window = full[-w:]
if full.find(window, 0, n - w) >= 0:
return True
return False
[docs]
@dataclass(slots=True)
class Agent:
system: str
transport: CompletionTransport
tools: list[Tool[Any]] = field(default_factory=list)
selector: ToolSelector | None = field(default=None)
max_iterations: int = field(default=50)
last_iteration_message: Message | None = field(default=None)
[docs]
def copy(self, **overrides: Any) -> Self:
"""Return a new Agent with *overrides* applied."""
return dataclasses.replace(self, **overrides)
[docs]
def run_stream(self, user_message: str, context: ContextStore) -> AgentStream:
return AgentStream(self._run_loop(user_message, context))
[docs]
async def run(self, user_message: str, context: ContextStore) -> str:
return await self.run_stream(user_message, context).get_final_text()
async def _dispatch_tools_streaming(
self,
blocks: list[ToolUseBlock],
iteration: int,
output_queue: asyncio.Queue[ToolOutputDelta | None],
) -> list[ToolResultBlock]:
"""Like dispatch_tools but pushes ToolOutputDelta events for streaming tools."""
tool_names = [b.name for b in blocks]
logger.info("Dispatching %d tool(s) with streaming: %s", len(blocks), tool_names)
async def _run_one(block: ToolUseBlock) -> ToolResultBlock:
tool = self._find_tool(block.name)
if tool is None:
logger.warning("Unknown tool requested: %s", block.name)
return ToolResultBlock(tool_use_id=block.id, content=f"Unknown tool: {block.name}", is_error=True)
logger.debug("Tool %s (id=%s) args=%s", block.name, block.id, json.dumps(block.input)[:200])
if tool.supports_streaming:
chunks: list[tuple[float, str, str]] = []
t0 = time.monotonic()
try:
async for key, text in tool.call_streaming(**block.input):
chunks.append((time.monotonic() - t0, key, text))
await output_queue.put(
ToolOutputDelta(tool_use_id=block.id, name=block.name, key=key, delta=text)
)
except Exception as exc:
logger.error("Tool %s raised %s: %s", block.name, type(exc).__name__, exc, exc_info=True)
return ToolResultBlock(tool_use_id=block.id, content=str(exc), is_error=True)
return ToolResultBlock(tool_use_id=block.id, content=tool.format_stream_result(chunks))
else:
try:
result = await tool(**block.input)
if isinstance(result, str):
content: str | list[TextBlock | ImageBlock | AudioBlock | VideoBlock] = result
elif isinstance(result, list) and all(isinstance(b, ContentBlock) for b in result):
content = result
else:
content = str(result)
except Exception as exc:
logger.error("Tool %s raised %s: %s", block.name, type(exc).__name__, exc, exc_info=True)
return ToolResultBlock(tool_use_id=block.id, content=str(exc), is_error=True)
return ToolResultBlock(tool_use_id=block.id, content=content)
results = list(await asyncio.gather(*[_run_one(b) for b in blocks]))
error_count = sum(1 for r in results if r.is_error)
logger.info("Tools complete: %d total, %d errors", len(results), error_count)
return results
def _find_tool(self, name: str) -> Tool[Any] | None:
for tool in self.tools:
if tool.name == name:
return tool
return None
async def _append(self, context: ContextStore, message: Message) -> None:
await context.append(message)
@staticmethod
def _accumulate_text(
content: list[TextBlock | ImageBlock | AudioBlock | VideoBlock | ToolUseBlock],
delta: str,
) -> None:
"""Append text delta — merge into last TextBlock or start a new one."""
if content and isinstance(content[-1], TextBlock):
content[-1] = TextBlock(text=content[-1].text + delta)
else:
content.append(TextBlock(text=delta))
@staticmethod
def _finalize_pending_tools(
pending: dict[str, dict[str, Any]],
usage: Usage,
) -> tuple[list[ToolUseBlock], set[str]]:
"""Convert streamed tool-call fragments into ToolUseBlocks.
Returns (blocks, malformed_ids). Malformed IDs arise when
max_tokens truncates the response mid-tool-call, producing
incomplete JSON (expected with eager_input_streaming). The
caller is responsible for not executing malformed tools.
"""
blocks: list[ToolUseBlock] = []
malformed: set[str] = set()
for tid, info in pending.items():
raw = "".join(info["json_parts"])
if not raw:
logger.warning(
"Tool %s (id=%s) received empty arguments (output may be truncated, output_tokens=%d)",
info["name"],
tid,
usage.output_tokens,
)
inp: dict[str, Any] = {}
else:
try:
inp = json.loads(raw)
except json.JSONDecodeError as exc:
logger.warning(
"Tool %s (id=%s) has malformed JSON arguments: %s\nRaw: %s",
info["name"],
tid,
exc,
raw,
)
malformed.add(tid)
inp = {}
blocks.append(ToolUseBlock(id=tid, name=info["name"], input=inp))
return blocks, malformed
async def _select_tools(self, history: list[Message], tools: list[Tool[Any]]) -> Iterable[Tool[Any]]:
if not tools:
return []
if not self.selector:
return tools
return await self.selector.select(history, tools)
async def _run_loop(self, user_message: str, context: ContextStore) -> AsyncGenerator[StreamEvent, None]:
total_usage = Usage(0, 0)
session_end_emitted = False
ts = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
await self._append(context, Message(role="user", content=[TextBlock(text=f"[{ts}] {user_message}")]))
try:
for iteration in range(1, self.max_iterations + 1):
history = await context.get_history()
logger.info("Iteration %d, history length=%d", iteration, len(history))
effective_history = (
[*history, self.last_iteration_message]
if self.last_iteration_message and iteration == self.max_iterations
else history
)
model = getattr(self.transport, "model", None)
model_caps = getattr(model, "capabilities", None)
if model_caps is not None and Capability.tool_use not in model_caps:
active_tools: list[Tool[Any]] = []
else:
active_tools = list(await self._select_tools(effective_history, self.tools))
content: list[TextBlock | ImageBlock | AudioBlock | VideoBlock | ToolUseBlock] = []
pending: dict[str, dict[str, Any]] = {}
stop_reason = StopReason.end_turn
malformed: set[str] = set()
repetition_detected = False
rep_detector = _RepetitionDetector()
try:
async for event in self.transport.stream(effective_history, active_tools, self.system):
yield event
match event:
case TextDelta(delta=delta):
self._accumulate_text(content, delta)
if rep_detector.feed(delta):
note = "\n\n[Output truncated: repetitive content detected]"
self._accumulate_text(content, note)
yield TextDelta(index=0, delta=note)
repetition_detected = True
break
case ImageOutput(data=data, media_type=mt):
content.append(ImageBlock(media_type=mt, data=data))
case VideoOutput(data=data, media_type=mt):
content.append(VideoBlock(media_type=mt, data=data))
case ToolUseStart(tool_use_id=tid, name=name):
pending[tid] = {"name": name, "json_parts": []}
case ToolInputDelta(tool_use_id=tid, partial_json=pj):
if tid in pending:
pending[tid]["json_parts"].append(pj)
case IterationEnd(usage=usage, stop_reason=sr):
blocks, malformed = self._finalize_pending_tools(pending, usage)
content.extend(blocks)
pending.clear()
total_usage = total_usage + usage
await context.add_context_tokens(usage.input_tokens, usage.output_tokens)
stop_reason = sr
except Exception as exc:
logger.error("Transport error: %s", exc, exc_info=True)
yield Error(exception=exc)
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
session_end_emitted = True
return
if repetition_detected:
await self._append(context, Message(role="assistant", content=list(content)))
partial = getattr(self.transport, "last_usage", None)
if partial:
total_usage = total_usage + partial
yield SessionEndEvent(stop_reason=StopReason.end_turn, total_usage=total_usage)
session_end_emitted = True
return
tool_blocks = [b for b in content if isinstance(b, ToolUseBlock)]
if tool_blocks:
if stop_reason != StopReason.tool_use:
logger.warning(
"Dispatching %d tool(s) despite stop_reason=%s",
len(tool_blocks),
stop_reason,
)
# Dispatch tools BEFORE appending to context - cancellation
# between here and the two appends below cannot leave orphan
# ToolUseBlocks in the persistent context store.
valid = [b for b in tool_blocks if b.id not in malformed]
error_results = [
ToolResultBlock(
tool_use_id=b.id,
content=(
f"Malformed JSON arguments for tool {b.name}."
f" Raw input could not be parsed. Please retry the tool call"
f" with valid JSON arguments."
),
is_error=True,
)
for b in tool_blocks
if b.id in malformed
]
partial_output: dict[str, list[tuple[float, str, str]]] = {}
t0_map: dict[str, float] = {}
dispatch_task: asyncio.Task[list[ToolResultBlock]] | None = None
try:
if valid:
has_streaming = any(
(t := self._find_tool(b.name)) is not None and t.supports_streaming for b in valid
)
if has_streaming:
output_queue: asyncio.Queue[ToolOutputDelta | None] = asyncio.Queue()
async def _dispatch_and_signal() -> list[ToolResultBlock]:
result = await self._dispatch_tools_streaming(valid, iteration, output_queue)
await output_queue.put(None)
return result
dispatch_task = asyncio.create_task(_dispatch_and_signal())
while True:
ev = await output_queue.get()
if ev is None:
break
if ev.tool_use_id not in t0_map:
t0_map[ev.tool_use_id] = time.monotonic()
partial_output.setdefault(ev.tool_use_id, []).append(
(time.monotonic() - t0_map[ev.tool_use_id], ev.key, ev.delta)
)
yield ev
dispatched = await dispatch_task
else:
dispatched = await self.dispatch_tools(valid, iteration)
else:
dispatched = []
results = dispatched + error_results
except asyncio.CancelledError:
if dispatch_task is not None and not dispatch_task.done():
dispatch_task.cancel()
try:
await dispatch_task
except (asyncio.CancelledError, Exception):
pass
interrupted_results: list[ToolResultBlock] = []
for b in tool_blocks:
chunks = partial_output.get(b.id, [])
tool = self._find_tool(b.name)
if chunks and tool:
msg = tool.format_stream_result(chunks) + "\n[interrupted by user]"
elif chunks:
msg = "".join(text for _, _, text in chunks) + "\n[interrupted by user]"
else:
msg = "[interrupted by user]"
interrupted_results.append(ToolResultBlock(tool_use_id=b.id, content=msg, is_error=True))
await self._append(context, Message(role="assistant", content=list(content)))
await self._append(context, Message(role="user", content=list(interrupted_results)))
raise
# Append both messages atomically (assistant + tool results)
await self._append(context, Message(role="assistant", content=list(content)))
await self._append(context, Message(role="user", content=list(results)))
# Gemini stops generating (~20 tokens, end_turn) after receiving
# media as sibling inlineData parts alongside functionResponse.
# A "Proceed." user message nudges it to actually analyze the content.
if getattr(self.transport, "nudge_on_media_tool_result", False) and any(
not isinstance(r.content, str)
and any(isinstance(b, (AudioBlock, ImageBlock, VideoBlock)) for b in r.content)
for r in results
):
await self._append(
context,
Message(
role="user",
content=[
TextBlock(
text="You now have the media file above in your context. Proceed.",
)
],
),
)
# Yield ToolResult events + media output events.
# Non-streaming tools return full content (str or list of
# TextBlock/ImageBlock/VideoBlock) — no information is lost.
# Images/videos are yielded as separate ImageOutput/VideoOutput
# events so the REPL can save them to disk; the model sees the
# actual pixel data via ImageBlock/VideoBlock in the tool result.
by_id = {b.id: b for b in tool_blocks}
for r in results:
block = by_id.get(r.tool_use_id)
if isinstance(r.content, str):
result_content = r.content
else:
result_content = "\n".join(b.text for b in r.content if isinstance(b, TextBlock))
for media_block in r.content:
if isinstance(media_block, ImageBlock):
yield ImageOutput(
index=0, data=media_block.data, media_type=media_block.media_type
)
elif isinstance(media_block, AudioBlock):
yield AudioOutput(
index=0, data=media_block.data, media_type=media_block.media_type
)
elif isinstance(media_block, VideoBlock):
yield VideoOutput(
index=0, data=media_block.data, media_type=media_block.media_type
)
yield ToolResult(
tool_use_id=r.tool_use_id,
name=block.name if block else "",
is_error=r.is_error,
content=result_content,
input=block.input if block else {},
)
continue
await self._append(context, Message(role="assistant", content=list(content)))
match stop_reason:
case StopReason.end_turn:
logger.debug("End turn: total_usage=%s", total_usage)
yield SessionEndEvent(stop_reason=StopReason.end_turn, total_usage=total_usage)
session_end_emitted = True
return
case StopReason.max_tokens | StopReason.error:
yield Error(exception=RuntimeError(f"Transport stopped with: {stop_reason}"))
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
session_end_emitted = True
return
logger.warning("Max iterations (%d) reached", self.max_iterations)
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
session_end_emitted = True
except GeneratorExit:
return
except BaseException:
if not session_end_emitted:
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
raise