From 85cc7bc433b9dec89bbc4df876293bd7c5157fa7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 28 Mar 2026 21:44:46 -0700 Subject: [PATCH] refactor: make unit test --- tests/agent_tests/test_a2a_agent.py | 82 ++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/tests/agent_tests/test_a2a_agent.py b/tests/agent_tests/test_a2a_agent.py index ace1f8c54c..ed7a5ab982 100644 --- a/tests/agent_tests/test_a2a_agent.py +++ b/tests/agent_tests/test_a2a_agent.py @@ -1,38 +1,72 @@ """ Simple A2A agent tests - non-streaming and streaming. -These tests validate the localhost URL retry logic: if an A2A agent's card -contains a localhost/internal URL (e.g., http://0.0.0.0:8001/), the request -will fail with a connection error. LiteLLM detects this and automatically -retries using the original api_base URL instead. - -Requires A2A_AGENT_URL environment variable to be set. - -Run with: - A2A_AGENT_URL=https://your-agent.example.com pytest tests/agent_tests/test_a2a_agent.py -v -s +These tests use a mocked A2A client to avoid network/env dependencies. """ -import os - -import pytest +from types import SimpleNamespace from uuid import uuid4 +import pytest -def get_a2a_agent_url(): - """Get A2A agent URL from environment, skip test if not set.""" - url = os.environ.get("A2A_AGENT_URL") - return url + +class MockA2AResponse: + def __init__(self, text: str): + self._payload = { + "id": str(uuid4()), + "jsonrpc": "2.0", + "result": { + "message": { + "role": "agent", + "parts": [{"kind": "text", "text": text}], + "messageId": uuid4().hex, + } + }, + } + + def model_dump(self, mode="json", exclude_none=True): + return self._payload + + +class MockA2AStreamingChunk(MockA2AResponse): + def __init__(self, text: str, state: str): + super().__init__(text=text) + self._payload["result"]["status"] = {"state": state} + + +class MockA2AClient: + def __init__(self): + self._litellm_agent_card = SimpleNamespace( + name="mock-agent", url="http://mock-agent.local" + ) + + async def send_message(self, request): + return MockA2AResponse(text="hello") + + def send_message_streaming(self, request): + async def _stream(): + yield MockA2AStreamingChunk(text="hel", state="in_progress") + yield MockA2AStreamingChunk(text="hello", state="completed") + + return _stream() + + +@pytest.fixture +def mock_a2a_client(monkeypatch): + import litellm.a2a_protocol.main as a2a_main + + async def _fake_create_a2a_client(base_url, timeout=60.0, extra_headers=None): + return MockA2AClient() + + monkeypatch.setattr(a2a_main, "create_a2a_client", _fake_create_a2a_client) @pytest.mark.asyncio -@pytest.mark.flaky(retries=3, delay=5) -async def test_a2a_non_streaming(): +async def test_a2a_non_streaming(mock_a2a_client): """Test non-streaming A2A request.""" from a2a.types import MessageSendParams, SendMessageRequest from litellm.a2a_protocol import asend_message - api_base = get_a2a_agent_url() - request = SendMessageRequest( id=str(uuid4()), params=MessageSendParams( @@ -46,7 +80,7 @@ async def test_a2a_non_streaming(): response = await asend_message( request=request, - api_base=api_base, + api_base="http://mock", ) assert response is not None @@ -54,13 +88,11 @@ async def test_a2a_non_streaming(): @pytest.mark.asyncio -async def test_a2a_streaming(): +async def test_a2a_streaming(mock_a2a_client): """Test streaming A2A request.""" from a2a.types import MessageSendParams, SendStreamingMessageRequest from litellm.a2a_protocol import asend_message_streaming - api_base = get_a2a_agent_url() - request = SendStreamingMessageRequest( id=str(uuid4()), params=MessageSendParams( @@ -75,7 +107,7 @@ async def test_a2a_streaming(): chunks = [] async for chunk in asend_message_streaming( request=request, - api_base=api_base, + api_base="http://mock", ): chunks.append(chunk) print(f"\nStreaming chunk: {chunk}")