# Source code for acp_sdk.server.agent

import abc
import asyncio
import inspect
from collections.abc import AsyncGenerator, Coroutine, Generator
from concurrent.futures import ThreadPoolExecutor
from typing import Callable

import janus

from acp_sdk.models import (
    AgentName,
    Message,
    SessionId,
)
from acp_sdk.models.models import Metadata
from acp_sdk.server.context import Context
from acp_sdk.server.types import RunYield, RunYieldResume


class Agent(abc.ABC):
    """Base class for agents.

    Subclasses implement :meth:`run`. :meth:`execute` adapts whichever flavour
    ``run`` is written in (async generator, coroutine function, sync generator
    or plain function) to a single async-generator interface.
    """

    @property
    def name(self) -> AgentName:
        """Agent name; defaults to the concrete class name."""
        return self.__class__.__name__

    @property
    def description(self) -> str:
        """Agent description; empty by default."""
        return ""

    @property
    def metadata(self) -> Metadata:
        """Agent metadata; a fresh ``Metadata()`` by default."""
        return Metadata()

    @abc.abstractmethod
    def run(
        self, input: list[Message], context: Context
    ) -> (
        AsyncGenerator[RunYield, RunYieldResume]
        | Generator[RunYield, RunYieldResume]
        | Coroutine[RunYield]
        | RunYield
    ):
        """Run the agent on ``input``.

        May be an async generator, a coroutine function, a sync generator or a
        plain function. Generator flavours receive, via ``send``/``asend``, the
        resume value that the consumer of :meth:`execute` sent back for the
        previous item.
        """
        pass

    async def execute(
        self, input: list[Message], session_id: SessionId | None, executor: ThreadPoolExecutor
    ) -> AsyncGenerator[RunYield, RunYieldResume]:
        """Drive :meth:`run` and re-yield everything it produces.

        Async flavours of ``run`` are scheduled as a task on the running loop;
        sync flavours run in ``executor``. Items travel through two janus
        queues held by the ``Context``. Each item the agent produces is yielded
        here. The value the consumer sends back is forwarded to the agent as
        its resume value.

        An exception raised inside the agent is *yielded* as an item, not
        raised. If the consumer sends an ``Exception`` instance back, it is
        raised here.

        Args:
            input: Messages passed to ``run``.
            session_id: Session identifier stored on the ``Context``.
            executor: Thread pool used for sync flavours of ``run``.
        """
        yield_queue: janus.Queue[RunYield] = janus.Queue()
        yield_resume_queue: janus.Queue[RunYieldResume] = janus.Queue()
        context = Context(
            session_id=session_id, executor=executor, yield_queue=yield_queue, yield_resume_queue=yield_resume_queue
        )

        loop = asyncio.get_running_loop()
        if inspect.isasyncgenfunction(self.run):
            run = asyncio.create_task(self._run_async_gen(input, context))
        elif inspect.iscoroutinefunction(self.run):
            run = asyncio.create_task(self._run_coro(input, context))
        elif inspect.isgeneratorfunction(self.run):
            run = loop.run_in_executor(executor, self._run_gen, input, context)
        else:
            run = loop.run_in_executor(executor, self._run_func, input, context)

        try:
            # Keep draining while the runner is alive or items are still queued.
            while not run.done() or yield_queue.async_q.qsize() > 0:
                value = yield await yield_queue.async_q.get()
                if isinstance(value, Exception):
                    raise value
                await yield_resume_queue.async_q.put(value)
        except janus.AsyncQueueShutDown:
            # The runner's helper shut the context down: normal end of stream.
            pass
        finally:
            # If we leave early (consumer sent an exception, closed us, or was
            # cancelled), the runner may be blocked waiting for a resume value
            # that will never come. Release it so it does not leak.
            if not run.done():
                if isinstance(run, asyncio.Task):
                    # Cancellation unwinds the helper; its ``finally`` shuts
                    # the context down.
                    run.cancel()
                else:
                    # A worker thread cannot be cancelled. Shutting the queues
                    # down wakes its blocked ``yield_sync`` call, just as it
                    # wakes the ``async_q.get()`` above on the normal path.
                    context.shutdown()

    async def _run_async_gen(self, input: list[Message], context: Context) -> None:
        """Drive an async-generator ``run``, relaying items through ``context``.

        Each item is passed to ``context.yield_async`` and the value it returns
        is sent back into the generator. An exception raised by the agent is
        passed to ``context.yield_async`` as an item. The context is always
        shut down on exit.
        """
        try:
            gen: AsyncGenerator[RunYield, RunYieldResume] = self.run(input, context)
            value = None
            while True:
                value = await context.yield_async(await gen.asend(value))
        except StopAsyncIteration:
            pass
        except Exception as e:
            await context.yield_async(e)
        finally:
            context.shutdown()

    async def _run_coro(self, input: list[Message], context: Context) -> None:
        """Await a coroutine ``run`` and relay its single result (or exception)."""
        try:
            await context.yield_async(await self.run(input, context))
        except Exception as e:
            await context.yield_async(e)
        finally:
            context.shutdown()

    def _run_gen(self, input: list[Message], context: Context) -> None:
        """Drive a sync-generator ``run`` (in a worker thread) through ``context``.

        Same protocol as :meth:`_run_async_gen`, but uses ``context.yield_sync``.
        """
        try:
            gen: Generator[RunYield, RunYieldResume] = self.run(input, context)
            value = None
            while True:
                value = context.yield_sync(gen.send(value))
        except StopIteration:
            pass
        except Exception as e:
            context.yield_sync(e)
        finally:
            context.shutdown()

    def _run_func(self, input: list[Message], context: Context) -> None:
        """Call a plain-function ``run`` (in a worker thread) and relay its result or exception."""
        try:
            context.yield_sync(self.run(input, context))
        except Exception as e:
            context.yield_sync(e)
        finally:
            context.shutdown()
def agent(
    name: str | None = None,
    description: str | None = None,
    *,
    metadata: Metadata | None = None,
) -> Callable[[Callable], Agent]:
    """Decorator to create an agent.

    Wraps a plain function, coroutine function, generator function or async
    generator function in an :class:`Agent` instance. The function must accept
    one or two parameters; when it accepts two, the second must be named
    ``context``, and the run ``Context`` is passed to it.

    Args:
        name: Agent name; falls back to the function's ``__name__`` when empty.
        description: Agent description; falls back to the function's docstring,
            then to ``""``.
        metadata: Agent metadata; falls back to a fresh ``Metadata()``.

    Returns:
        A decorator that turns the function into an ``Agent`` instance.

    Raises:
        TypeError: raised by the returned decorator when the function's
            parameter list does not match the rules above.
    """

    def decorator(fn: Callable) -> Agent:
        signature = inspect.signature(fn)
        parameters = list(signature.parameters.values())

        if len(parameters) == 0:
            raise TypeError("The agent function must have at least 'input' argument")
        if len(parameters) > 2:
            raise TypeError("The agent function must have only 'input' and 'context' arguments")
        if len(parameters) == 2 and parameters[1].name != "context":
            raise TypeError("The second argument of the agent function must be 'context'")

        has_context_param = len(parameters) == 2

        class DecoratorAgentBase(Agent):
            @property
            def name(self) -> str:
                return name or fn.__name__

            @property
            def description(self) -> str:
                return description or inspect.getdoc(fn) or ""

            @property
            def metadata(self) -> Metadata:
                return metadata or Metadata()

        agent: Agent
        if inspect.isasyncgenfunction(fn):

            class AsyncGenDecoratorAgent(DecoratorAgentBase):
                async def run(self, input: list[Message], context: Context) -> AsyncGenerator[RunYield, RunYieldResume]:
                    gen: AsyncGenerator[RunYield, RunYieldResume] = (
                        fn(input, context) if has_context_param else fn(input)
                    )
                    # Async equivalent of ``yield from``: forward sent values,
                    # thrown exceptions and close requests to the wrapped
                    # generator, as the sync wrapper below does.
                    try:
                        item = await gen.asend(None)
                        while True:
                            try:
                                sent = yield item
                            except GeneratorExit:
                                await gen.aclose()
                                raise
                            except BaseException as exc:
                                item = await gen.athrow(exc)
                            else:
                                item = await gen.asend(sent)
                    except StopAsyncIteration:
                        pass

            agent = AsyncGenDecoratorAgent()
        elif inspect.iscoroutinefunction(fn):

            class CoroDecoratorAgent(DecoratorAgentBase):
                async def run(self, input: list[Message], context: Context) -> RunYield:
                    return await (fn(input, context) if has_context_param else fn(input))

            agent = CoroDecoratorAgent()
        elif inspect.isgeneratorfunction(fn):

            class GenDecoratorAgent(DecoratorAgentBase):
                def run(self, input: list[Message], context: Context) -> Generator[RunYield, RunYieldResume]:
                    yield from (fn(input, context) if has_context_param else fn(input))

            agent = GenDecoratorAgent()
        else:

            class FuncDecoratorAgent(DecoratorAgentBase):
                def run(self, input: list[Message], context: Context) -> RunYield:
                    return fn(input, context) if has_context_param else fn(input)

            agent = FuncDecoratorAgent()

        return agent

    return decorator