Source code for acp_sdk.client.client

import ssl
import typing
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import asynccontextmanager
from types import TracebackType
from typing import Self

import httpx
from httpx_sse import EventSource, aconnect_sse
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from pydantic import TypeAdapter

from acp_sdk.client.types import Input
from acp_sdk.client.utils import input_to_messages
from acp_sdk.instrumentation import get_tracer
from acp_sdk.models import (
    ACPError,
    Agent,
    AgentName,
    AgentReadResponse,
    AgentsListResponse,
    AwaitResume,
    Error,
    Event,
    PingResponse,
    Run,
    RunCancelResponse,
    RunCreateRequest,
    RunCreateResponse,
    RunEventsListResponse,
    RunId,
    RunMode,
    RunResumeRequest,
    RunResumeResponse,
    SessionId,
)


class Client:
    """Asynchronous HTTP client for an ACP server.

    Wraps an ``httpx.AsyncClient`` and exposes one coroutine (or async
    generator) per endpoint used here: ``/agents``, ``/ping`` and ``/runs``.
    Error responses are converted to ``ACPError`` by ``_raise_error``.
    """

    def __init__(
        self,
        *,
        session_id: SessionId | None = None,
        client: httpx.AsyncClient | None = None,
        instrument: bool = True,
        auth: httpx._types.AuthTypes | None = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        timeout: httpx._types.TimeoutTypes = None,
        verify: ssl.SSLContext | str | bool = True,
        cert: httpx._types.CertTypes | None = None,
        http1: bool = True,
        http2: bool = False,
        proxy: httpx._types.ProxyTypes | None = None,
        mounts: None | (typing.Mapping[str, httpx.AsyncBaseTransport | None]) = None,
        follow_redirects: bool = False,
        limits: httpx.Limits = httpx._config.DEFAULT_LIMITS,
        max_redirects: int = httpx._config.DEFAULT_MAX_REDIRECTS,
        event_hooks: None | (typing.Mapping[str, list[httpx._client.EventHook]]) = None,
        base_url: httpx.URL | str = "",
        transport: httpx.AsyncBaseTransport | None = None,
        trust_env: bool = True,
    ) -> None:
        """Create a client, reusing ``client`` or building a new ``httpx.AsyncClient``.

        Args:
            session_id: Session identifier sent with every run-creation request.
            client: Existing ``httpx.AsyncClient`` to use. When given (and truthy),
                all of the httpx options below are ignored.
            instrument: When true, the underlying httpx client is passed to
                ``HTTPXClientInstrumentor.instrument_client``.

        All remaining keyword arguments are forwarded unchanged to
        ``httpx.AsyncClient`` when a new client has to be constructed.
        """
        self._session_id = session_id
        self._client = client or httpx.AsyncClient(
            auth=auth,
            params=params,
            headers=headers,
            cookies=cookies,
            timeout=timeout,
            verify=verify,
            cert=cert,
            http1=http1,
            http2=http2,
            proxy=proxy,
            mounts=mounts,
            follow_redirects=follow_redirects,
            limits=limits,
            max_redirects=max_redirects,
            event_hooks=event_hooks,
            base_url=base_url,
            transport=transport,
            trust_env=trust_env,
        )
        if instrument:
            HTTPXClientInstrumentor.instrument_client(self._client)

    @property
    def client(self) -> httpx.AsyncClient:
        """The underlying ``httpx.AsyncClient``."""
        return self._client

    async def __aenter__(self) -> Self:
        """Enter the underlying httpx client's async context and return ``self``."""
        await self._client.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Exit the underlying httpx client's async context."""
        await self._client.__aexit__(exc_type, exc_value, traceback)

    @asynccontextmanager
    async def session(self, session_id: SessionId | None = None) -> AsyncGenerator[Self, None]:
        """Yield a client bound to a session, inside a ``"session"`` tracing span.

        Args:
            session_id: Session to bind to; a fresh ``uuid.uuid4()`` is used when
                omitted (or falsy).

        Yields:
            A new ``Client`` that shares this client's httpx connection and
            carries ``session_id``. It is created with ``instrument=False``
            because it reuses this client's httpx instance.
        """
        session_id = session_id or uuid.uuid4()
        with get_tracer().start_as_current_span("session", attributes={"acp.session": str(session_id)}):
            yield Client(client=self._client, session_id=session_id, instrument=False)

    async def agents(self) -> AsyncIterator[Agent]:
        """Yield every agent returned by ``GET /agents``.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.get("/agents")
        self._raise_error(response)
        for agent in AgentsListResponse.model_validate(response.json()).agents:
            yield agent

    async def agent(self, *, name: AgentName) -> Agent:
        """Fetch a single agent via ``GET /agents/{name}``.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.get(f"/agents/{name}")
        self._raise_error(response)
        parsed = AgentReadResponse.model_validate(response.json())
        return Agent(**parsed.model_dump())

    async def ping(self) -> bool:
        """Call ``GET /ping`` and validate the reply as a ``PingResponse``.

        Returns:
            ``True`` when the request succeeded and the body validated.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.get("/ping")
        self._raise_error(response)
        PingResponse.model_validate(response.json())
        # Failures surface as exceptions above; reaching this point means success.
        return True

    async def run_sync(self, input: Input, *, agent: AgentName) -> Run:
        """Create a run in ``RunMode.SYNC`` via ``POST /runs`` and return it.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.post(
            "/runs",
            content=self._run_create_content(input, agent, RunMode.SYNC),
        )
        self._raise_error(response)
        parsed = RunCreateResponse.model_validate(response.json())
        return Run(**parsed.model_dump())

    async def run_async(self, input: Input, *, agent: AgentName) -> Run:
        """Create a run in ``RunMode.ASYNC`` via ``POST /runs`` and return it.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.post(
            "/runs",
            content=self._run_create_content(input, agent, RunMode.ASYNC),
        )
        self._raise_error(response)
        parsed = RunCreateResponse.model_validate(response.json())
        return Run(**parsed.model_dump())

    async def run_stream(self, input: Input, *, agent: AgentName) -> AsyncIterator[Event]:
        """Create a run in ``RunMode.STREAM`` and yield its server-sent events.

        Raises:
            ACPError: If the server answers with an error status.
        """
        async with aconnect_sse(
            self._client,
            "POST",
            "/runs",
            content=self._run_create_content(input, agent, RunMode.STREAM),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    async def run_status(self, *, run_id: RunId) -> Run:
        """Fetch a run via ``GET /runs/{run_id}``.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.get(f"/runs/{run_id}")
        self._raise_error(response)
        return Run.model_validate(response.json())

    async def run_events(self, *, run_id: RunId) -> AsyncIterator[Event]:
        """Yield the events returned by ``GET /runs/{run_id}/events``.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.get(f"/runs/{run_id}/events")
        self._raise_error(response)
        parsed = RunEventsListResponse.model_validate(response.json())
        for event in parsed.events:
            yield event

    async def run_cancel(self, *, run_id: RunId) -> Run:
        """Request cancellation via ``POST /runs/{run_id}/cancel`` and return the run.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.post(f"/runs/{run_id}/cancel")
        self._raise_error(response)
        parsed = RunCancelResponse.model_validate(response.json())
        return Run(**parsed.model_dump())

    async def run_resume_sync(self, await_resume: AwaitResume, *, run_id: RunId) -> Run:
        """Resume a run in ``RunMode.SYNC`` via ``POST /runs/{run_id}`` and return it.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.post(
            f"/runs/{run_id}",
            content=self._run_resume_content(await_resume, RunMode.SYNC),
        )
        self._raise_error(response)
        parsed = RunResumeResponse.model_validate(response.json())
        return Run(**parsed.model_dump())

    async def run_resume_async(self, await_resume: AwaitResume, *, run_id: RunId) -> Run:
        """Resume a run in ``RunMode.ASYNC`` via ``POST /runs/{run_id}`` and return it.

        Raises:
            ACPError: If the server answers with an error status.
        """
        response = await self._client.post(
            f"/runs/{run_id}",
            content=self._run_resume_content(await_resume, RunMode.ASYNC),
        )
        self._raise_error(response)
        parsed = RunResumeResponse.model_validate(response.json())
        return Run(**parsed.model_dump())

    async def run_resume_stream(self, await_resume: AwaitResume, *, run_id: RunId) -> AsyncIterator[Event]:
        """Resume a run in ``RunMode.STREAM`` and yield its server-sent events.

        Raises:
            ACPError: If the server answers with an error status.
        """
        async with aconnect_sse(
            self._client,
            "POST",
            f"/runs/{run_id}",
            content=self._run_resume_content(await_resume, RunMode.STREAM),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    def _run_create_content(self, input: Input, agent: AgentName, mode: RunMode) -> str:
        """Serialize the JSON body of a run-creation request for ``mode``."""
        return RunCreateRequest(
            agent_name=agent,
            input=input_to_messages(input),
            mode=mode,
            session_id=self._session_id,
        ).model_dump_json()

    @staticmethod
    def _run_resume_content(await_resume: AwaitResume, mode: RunMode) -> str:
        """Serialize the JSON body of a run-resume request for ``mode``."""
        return RunResumeRequest(await_resume=await_resume, mode=mode).model_dump_json()

    async def _validate_stream(
        self,
        event_source: EventSource,
    ) -> AsyncIterator[Event]:
        """Yield validated ``Event`` objects from an SSE connection.

        Raises:
            ACPError: If the SSE response itself has an error status.
        """
        if event_source.response.is_error:
            # The body of a streamed response must be read before it can be
            # parsed into an error payload.
            await event_source.response.aread()
            self._raise_error(event_source.response)
        # Build the adapter once rather than once per received message.
        event_adapter = TypeAdapter(Event)
        async for sse in event_source.aiter_sse():
            yield event_adapter.validate_json(sse.data)

    def _raise_error(self, response: httpx.Response) -> None:
        """Raise if ``response`` has an error status; otherwise do nothing.

        Raises:
            ACPError: When the body parses as an ``Error`` model. The original
                httpx error is attached as ``__cause__``.
            httpx.HTTPError: When the body is not a valid ``Error`` payload,
                so the real HTTP failure is not masked by a parsing error.
        """
        try:
            response.raise_for_status()
        except httpx.HTTPError as http_error:
            try:
                error = Error.model_validate(response.json())
            except ValueError:
                # JSON decoding and pydantic validation failures are both
                # ValueError subclasses: the body is not an ACP error payload.
                error = None
            if error is None:
                # Re-raise the httpx error being handled by the outer clause.
                raise
            raise ACPError(error) from http_error