Source code for acp_sdk.client.client

import asyncio
import logging
import ssl
import typing
from collections.abc import AsyncIterator
from types import TracebackType
from typing import Self

import httpx
from httpx_sse import EventSource, aconnect_sse
from pydantic import TypeAdapter

from acp_sdk.client.types import Input
from acp_sdk.client.utils import input_to_messages
from acp_sdk.instrumentation import get_tracer
from acp_sdk.models import (
    ACPError,
    AgentManifest,
    AgentName,
    AgentReadResponse,
    AgentsListResponse,
    AwaitResume,
    Error,
    ErrorCode,
    ErrorEvent,
    Event,
    PingResponse,
    Run,
    RunCancelResponse,
    RunCreateRequest,
    RunCreateResponse,
    RunEventsListResponse,
    RunId,
    RunMode,
    RunResumeRequest,
    RunResumeResponse,
    Session,
    SessionReadResponse,
)

logger = logging.getLogger(__name__)


class Client:
    """Asynchronous client for an ACP server, layered on ``httpx.AsyncClient``.

    The client can be used standalone or bound to a ``Session``. When bound to
    a session, run-creating calls attach session information to the request
    (see ``_prepare_session_for_run``).

    Use it as an async context manager so the underlying HTTP client (when
    managed by this object) is opened and closed properly.
    """

    def __init__(
        self,
        *,
        session: Session | None = None,
        client: httpx.AsyncClient | None = None,
        manage_client: bool = True,
        auth: httpx._types.AuthTypes | None = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        timeout: httpx._types.TimeoutTypes = None,
        verify: ssl.SSLContext | str | bool = True,
        cert: httpx._types.CertTypes | None = None,
        http1: bool = True,
        http2: bool = False,
        proxy: httpx._types.ProxyTypes | None = None,
        mounts: None | (typing.Mapping[str, httpx.AsyncBaseTransport | None]) = None,
        follow_redirects: bool = False,
        limits: httpx.Limits = httpx._config.DEFAULT_LIMITS,
        max_redirects: int = httpx._config.DEFAULT_MAX_REDIRECTS,
        event_hooks: None | (typing.Mapping[str, list[httpx._client.EventHook]]) = None,
        base_url: httpx.URL | str = "",
        transport: httpx.AsyncBaseTransport | None = None,
        trust_env: bool = True,
    ) -> None:
        """Create a client.

        Args:
            session: Session this client is bound to, or ``None`` for no session.
            client: Pre-built ``httpx.AsyncClient`` to use. When given, the
                remaining httpx options below are ignored.
            manage_client: Whether entering/exiting this object also
                enters/exits the underlying ``httpx.AsyncClient``.
            auth, params, headers, cookies, timeout, verify, cert, http1, http2,
            proxy, mounts, follow_redirects, limits, max_redirects, event_hooks,
            base_url, transport, trust_env: Forwarded unchanged to
                ``httpx.AsyncClient`` when ``client`` is not supplied.
        """
        self._session = session
        self._session_last_refresh_base_url: httpx.URL | None = None
        self._session_refresh_lock = asyncio.Lock()
        # Initialised here so that __aexit__ is safe even if __aenter__ never ran.
        self._session_span_manager = None

        self._client = client or httpx.AsyncClient(
            auth=auth,
            params=params,
            headers=headers,
            cookies=cookies,
            timeout=timeout,
            verify=verify,
            cert=cert,
            http1=http1,
            http2=http2,
            proxy=proxy,
            mounts=mounts,
            follow_redirects=follow_redirects,
            limits=limits,
            max_redirects=max_redirects,
            event_hooks=event_hooks,
            base_url=base_url,
            transport=transport,
            trust_env=trust_env,
        )
        self._manage_client = manage_client

    @property
    def client(self) -> httpx.AsyncClient:
        """The underlying ``httpx.AsyncClient``."""
        return self._client

    async def __aenter__(self) -> Self:
        """Enter the managed HTTP client (if managed) and open a session span."""
        if self._manage_client:
            await self._client.__aenter__()
        # NOTE(review): this stores the value returned by ``__enter__()``, not the
        # context manager returned by ``start_as_current_span``. Confirm that the
        # stored object's ``__exit__`` performs all intended cleanup.
        self._session_span_manager = (
            (
                get_tracer()
                .start_as_current_span("session", attributes={"acp.session": str(self._session.id)})
                .__enter__()
            )
            if self._session
            else None
        )
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Close the session span (if any) and exit the managed HTTP client."""
        if self._session_span_manager:
            self._session_span_manager.__exit__(exc_type, exc_value, traceback)
        if self._manage_client:
            await self._client.__aexit__(exc_type, exc_value, traceback)

    def session(self, session: Session | None = None) -> Self:
        """Return a new client bound to ``session`` that shares this HTTP client.

        A fresh ``Session()`` is created when ``session`` is not given. The
        returned client does not manage (open/close) the shared HTTP client.
        """
        return Client(client=self._client, manage_client=False, session=session or Session())

    async def agents(self, *, base_url: httpx.URL | str | None = None) -> AsyncIterator[AgentManifest]:
        """Yield the manifests returned by ``GET /agents``.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.get(self._create_url("/agents", base_url=base_url))
        self._raise_error(response)
        for agent in AgentsListResponse.model_validate(response.json()).agents:
            yield agent

    async def agent(self, *, name: AgentName, base_url: httpx.URL | str | None = None) -> AgentManifest:
        """Return the manifest returned by ``GET /agents/{name}``.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.get(self._create_url(f"/agents/{name}", base_url=base_url))
        self._raise_error(response)
        response = AgentReadResponse.model_validate(response.json())
        return AgentManifest(**response.model_dump())

    async def ping(self, *, base_url: httpx.URL | str | None = None) -> bool:
        """Call ``GET /ping`` and return ``True`` when the response is valid.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.get(self._create_url("/ping", base_url=base_url))
        self._raise_error(response)
        PingResponse.model_validate(response.json())
        # Reaching this point means the request succeeded and the payload validated.
        return True

    async def run_sync(self, input: Input, *, agent: AgentName, base_url: httpx.URL | str | None = None) -> Run:
        """Create a run in ``RunMode.SYNC`` via ``POST /runs`` and return it.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.post(
            self._create_url("/runs", base_url=base_url),
            content=RunCreateRequest(
                agent_name=agent,
                input=input_to_messages(input),
                mode=RunMode.SYNC,
                **(await self._prepare_session_for_run(base_url=base_url)),
            ).model_dump_json(),
        )
        self._raise_error(response)
        response = RunCreateResponse.model_validate(response.json())
        return Run(**response.model_dump())

    async def run_async(self, input: Input, *, agent: AgentName, base_url: httpx.URL | str | None = None) -> Run:
        """Create a run in ``RunMode.ASYNC`` via ``POST /runs`` and return it.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.post(
            self._create_url("/runs", base_url=base_url),
            content=RunCreateRequest(
                agent_name=agent,
                input=input_to_messages(input),
                mode=RunMode.ASYNC,
                **(await self._prepare_session_for_run(base_url=base_url)),
            ).model_dump_json(),
        )
        self._raise_error(response)
        response = RunCreateResponse.model_validate(response.json())
        return Run(**response.model_dump())

    async def run_stream(
        self, input: Input, *, agent: AgentName, base_url: httpx.URL | str | None = None
    ) -> AsyncIterator[Event]:
        """Create a run in ``RunMode.STREAM`` via ``POST /runs`` and yield its events.

        Raises:
            ACPError: If the server responds with an error status or the stream
                contains an ``ErrorEvent``.
        """
        async with aconnect_sse(
            self._client,
            "POST",
            self._create_url("/runs", base_url=base_url),
            content=RunCreateRequest(
                agent_name=agent,
                input=input_to_messages(input),
                mode=RunMode.STREAM,
                # The helper returns keyword arguments (``session_id`` or
                # ``session``); unpack them exactly as run_sync/run_async do.
                **(await self._prepare_session_for_run(base_url=base_url)),
            ).model_dump_json(),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    async def run_status(self, *, run_id: RunId, base_url: httpx.URL | str | None = None) -> Run:
        """Return the run returned by ``GET /runs/{run_id}``.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.get(self._create_url(f"/runs/{run_id}", base_url=base_url))
        self._raise_error(response)
        return Run.model_validate(response.json())

    async def run_events(self, *, run_id: RunId, base_url: httpx.URL | str | None = None) -> AsyncIterator[Event]:
        """Yield the events returned by ``GET /runs/{run_id}/events``.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.get(self._create_url(f"/runs/{run_id}/events", base_url=base_url))
        self._raise_error(response)
        response = RunEventsListResponse.model_validate(response.json())
        for event in response.events:
            yield event

    async def run_cancel(self, *, run_id: RunId, base_url: httpx.URL | str | None = None) -> Run:
        """Request cancellation via ``POST /runs/{run_id}/cancel`` and return the run.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.post(self._create_url(f"/runs/{run_id}/cancel", base_url=base_url))
        self._raise_error(response)
        response = RunCancelResponse.model_validate(response.json())
        return Run(**response.model_dump())

    async def run_resume_sync(
        self, await_resume: AwaitResume, *, run_id: RunId, base_url: httpx.URL | str | None = None
    ) -> Run:
        """Resume a run in ``RunMode.SYNC`` via ``POST /runs/{run_id}`` and return it.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.post(
            self._create_url(f"/runs/{run_id}", base_url=base_url),
            content=RunResumeRequest(await_resume=await_resume, mode=RunMode.SYNC).model_dump_json(),
        )
        self._raise_error(response)
        response = RunResumeResponse.model_validate(response.json())
        return Run(**response.model_dump())

    async def run_resume_async(
        self, await_resume: AwaitResume, *, run_id: RunId, base_url: httpx.URL | str | None = None
    ) -> Run:
        """Resume a run in ``RunMode.ASYNC`` via ``POST /runs/{run_id}`` and return it.

        Raises:
            ACPError: If the server responds with an error status.
        """
        response = await self._client.post(
            self._create_url(f"/runs/{run_id}", base_url=base_url),
            content=RunResumeRequest(await_resume=await_resume, mode=RunMode.ASYNC).model_dump_json(),
        )
        self._raise_error(response)
        response = RunResumeResponse.model_validate(response.json())
        return Run(**response.model_dump())

    async def run_resume_stream(
        self, await_resume: AwaitResume, *, run_id: RunId, base_url: httpx.URL | str | None = None
    ) -> AsyncIterator[Event]:
        """Resume a run in ``RunMode.STREAM`` via ``POST /runs/{run_id}`` and yield its events.

        Raises:
            ACPError: If the server responds with an error status or the stream
                contains an ``ErrorEvent``.
        """
        async with aconnect_sse(
            self._client,
            "POST",
            self._create_url(f"/runs/{run_id}", base_url=base_url),
            content=RunResumeRequest(await_resume=await_resume, mode=RunMode.STREAM).model_dump_json(),
        ) as event_source:
            async for event in self._validate_stream(event_source):
                yield event

    async def refresh_session(
        self, *, base_url: httpx.URL | str | None = None, timeout: httpx._types.TimeoutTypes = 5000
    ) -> Session:
        """Re-read the bound session from ``GET /sessions/{id}`` and store it.

        The request targets ``base_url`` if given, otherwise the base URL
        recorded by the last run preparation, otherwise the client's base URL.

        NOTE(review): httpx interprets numeric timeouts as seconds, so the
        default of 5000 is roughly 83 minutes. Confirm whether 5 seconds was
        intended before changing it.

        Raises:
            RuntimeError: If this client is not bound to a session.
            ACPError: If the server responds with an error status.
        """
        if not self._session:
            raise RuntimeError("Client is not in a session")

        async with self._session_refresh_lock:
            url = self._create_url(
                f"/sessions/{self._session.id}",
                base_url=base_url or self._session_last_refresh_base_url,
            )
            response = await self._client.get(url, timeout=timeout)
            self._raise_error(response)
            response = SessionReadResponse.model_validate(response.json())
            self._session = Session(**response.model_dump())
            return self._session

    async def _validate_stream(
        self,
        event_source: EventSource,
    ) -> AsyncIterator[Event]:
        """Yield validated events from an SSE stream.

        Raises:
            ACPError: If the HTTP response is an error, or an ``ErrorEvent`` is
                received on the stream.
        """
        if event_source.response.is_error:
            # The body of a streamed response must be read before it can be parsed.
            await event_source.response.aread()
            self._raise_error(event_source.response)

        # Build the adapter once instead of once per received message.
        event_adapter = TypeAdapter(Event)
        async for sse in event_source.aiter_sse():
            event: Event = event_adapter.validate_json(sse.data)
            if isinstance(event, ErrorEvent):
                raise ACPError(error=event.error)
            yield event

    def _raise_error(self, response: httpx.Response) -> None:
        """Raise for an error response; do nothing for a successful one.

        Raises:
            ACPError: If the response has an error status and its body parses
                as an ``Error``; the HTTP error is attached as the cause.
            httpx.HTTPError: If the response has an error status but its body
                is not a valid ``Error`` payload (for example a non-JSON body).
        """
        try:
            response.raise_for_status()
        except httpx.HTTPError as http_error:
            try:
                error = Error.model_validate(response.json())
            except ValueError:
                # JSON decoding and model validation failures are ValueErrors;
                # surface the HTTP failure rather than the parsing failure.
                raise http_error
            raise ACPError(error) from http_error

    def _create_base_url(self, base_url: httpx.URL | str | None) -> httpx.URL:
        """Return ``base_url`` (or the client's base URL) with a trailing slash on its path."""
        base_url = httpx.URL(base_url or self._client.base_url)
        if not base_url.raw_path.endswith(b"/"):
            base_url = base_url.copy_with(raw_path=base_url.raw_path + b"/")
        return base_url

    def _create_url(self, endpoint: str, base_url: httpx.URL | str | None) -> httpx.URL:
        """Append the relative ``endpoint`` to the normalised base URL.

        Raises:
            ValueError: If ``endpoint`` is not a relative URL.
        """
        merge_url = httpx.URL(endpoint)
        if not merge_url.is_relative_url:
            raise ValueError("Endpoint must be a relative URL")
        base_url = self._create_base_url(base_url)
        # The base path already ends with "/", so drop the endpoint's leading "/".
        merge_raw_path = base_url.raw_path + merge_url.raw_path.lstrip(b"/")
        return base_url.copy_with(raw_path=merge_raw_path)

    async def _prepare_session_for_run(self, *, base_url: httpx.URL | str | None) -> dict:
        """Return the session keyword arguments for a run-create request.

        Returns:
            ``{}`` when no session is bound; ``{"session_id": ...}`` when the
            target base URL equals the one recorded by the previous call;
            otherwise ``{"session": ...}`` holding the refreshed session, or
            the locally held session if the refresh reports ``NOT_FOUND``.

        Raises:
            ACPError: If refreshing the session fails with any other error code.
        """
        if not self._session:
            return {}

        target_base_url = self._create_base_url(base_url=base_url)

        try:
            if self._session_last_refresh_base_url == target_base_url:
                # Same server, no need to forward session
                return {"session_id": self._session.id}

            session = await self.refresh_session(base_url=self._session_last_refresh_base_url or target_base_url)
            return {"session": session}
        except ACPError as e:
            if e.error.code == ErrorCode.NOT_FOUND:
                return {"session": self._session}
            raise
        finally:
            # Record the target on every path, including failures.
            await self._update_session_refresh_url(target_base_url)

    async def _update_session_refresh_url(self, url: httpx.URL) -> None:
        """Record ``url`` as the last session base URL while holding the refresh lock."""
        async with self._session_refresh_lock:
            self._session_last_refresh_base_url = url