mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-30 01:03:03 +00:00
778a7f752d
* Add OAuth M2M support for A2A agents targeting Databricks Apps Databricks App endpoints reject static bearer tokens and require a short-lived OAuth token minted via the workspace OIDC token endpoint. A2A agents could previously only authenticate outbound with static_headers or client header passthrough, so Databricks App agents could not be registered. Agents configured with a databricks_oauth block in litellm_params now mint and cache a client_credentials token and attach it as the outbound Authorization header on both message/send and message/stream calls, overriding any statically configured Authorization. * Add tests covering Databricks App OAuth token error paths Cover the HTTP status error, transport error, non-object JSON body, and invalid expires_in fallback branches in the token cache so the failure handling is locked in by regression tests. * Harden Databricks App OAuth token cache Cap the cache TTL at the token's own lifetime so a token whose validity is shorter than the refresh buffer is never cached and served stale; include a digest of client_secret in the cache key so a rotated secret mints a fresh token instead of reusing the old one; and prune the per-key lock when its cached token is evicted so the lock map stays bounded by the live key set. * Clear per-key locks on Databricks OAuth cache flush * fix(a2a/databricks): mint OAuth token via Basic auth header, not unsupported auth= kwarg litellm's AsyncHTTPHandler.post (what get_async_httpx_client returns) has no auth parameter, so minting a Databricks App OAuth token raised "AsyncHTTPHandler.post() got an unexpected keyword argument 'auth'" before any network call ever left the proxy, breaking the feature end to end. The handler also calls raise_for_status() internally and re-raises a MaskedHTTPStatusError (a subclass of httpx.HTTPStatusError), so the explicit raise_for_status() after post() was dead code. 
Build the HTTP Basic Authorization header by hand and pass it via headers, which is what the Databricks workspace OIDC token endpoint documents for client authentication. The token-cache tests now model the real handler contract with create_autospec so the rejected auth= signature is enforced; the previous mocks accepted any kwargs and silently hid the bug. Co-authored-by: Mateo Wang <mateo-berri@users.noreply.github.com> * Prune Databricks OAuth lock on the short-lived-token path When expires_in is below the refresh buffer the token is intentionally not cached, so _remove_key never runs for that key and the per-key lock created by _get_lock leaked permanently. Drop the lock in that branch so _locks stays bounded by the live key set, and assert the cleanup in the short-lived-token test * Gate A2A Databricks OAuth on the databricks_oauth block at the call site Make the gating explicit where the header is applied so it is clear that only agents configured with a databricks_oauth block enter the OAuth path; every other agent is left untouched. Add a regression test asserting a non-Databricks agent never invokes the token resolver. --------- Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: Mateo Wang <mateo-berri@users.noreply.github.com>
441 lines
15 KiB
Python
441 lines
15 KiB
Python
"""
|
|
Unit tests for A2A agent custom header forwarding.
|
|
|
|
Tests cover:
|
|
- Static headers forwarded to backend agent
|
|
- Dynamic headers extracted from client request and forwarded
|
|
- Static headers win over dynamic on conflict
|
|
- No headers configured — existing behavior unchanged
|
|
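- merge_agent_headers utility
- Convention-based x-a2a-{agent_id or agent_name}-{header_name} forwarding
- Databricks App OAuth M2M Authorization header injection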
|
|
"""
|
|
|
|
import sys
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper: build a minimal mock agent
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_mock_agent(
|
|
static_headers=None,
|
|
extra_headers=None,
|
|
url="http://backend-agent:10001",
|
|
):
|
|
mock_agent = MagicMock()
|
|
mock_agent.agent_id = "agent-123"
|
|
mock_agent.agent_card_params = {"url": url, "name": "Test Agent"}
|
|
mock_agent.litellm_params = {}
|
|
mock_agent.static_headers = static_headers or {}
|
|
mock_agent.extra_headers = extra_headers or []
|
|
return mock_agent
|
|
|
|
|
|
def _make_mock_request(extra_headers=None, method="message/send"):
|
|
"""Build a mock FastAPI Request with configurable headers."""
|
|
mock_request = MagicMock()
|
|
headers = {"content-type": "application/json"}
|
|
if extra_headers:
|
|
headers.update(extra_headers)
|
|
mock_request.headers = headers
|
|
mock_request.json = AsyncMock(
|
|
return_value={
|
|
"jsonrpc": "2.0",
|
|
"id": "test-id",
|
|
"method": method,
|
|
"params": {
|
|
"message": {
|
|
"role": "user",
|
|
"parts": [{"kind": "text", "text": "Hello"}],
|
|
"messageId": "msg-123",
|
|
}
|
|
},
|
|
}
|
|
)
|
|
return mock_request
|
|
|
|
|
|
def _make_a2a_types_module():
|
|
"""Return (module, MessageSendParams, SendMessageRequest, SendStreamingMessageRequest)."""
|
|
try:
|
|
from a2a.types import (
|
|
MessageSendParams,
|
|
SendMessageRequest,
|
|
SendStreamingMessageRequest,
|
|
)
|
|
|
|
mock_a2a_types = MagicMock()
|
|
mock_a2a_types.MessageSendParams = MessageSendParams
|
|
mock_a2a_types.SendMessageRequest = SendMessageRequest
|
|
mock_a2a_types.SendStreamingMessageRequest = SendStreamingMessageRequest
|
|
return mock_a2a_types
|
|
except ImportError:
|
|
pass
|
|
|
|
def _make_cls(name):
|
|
class MockCls:
|
|
def __init__(self, **kwargs):
|
|
self.__dict__.update(kwargs)
|
|
self._kwargs = kwargs
|
|
|
|
def model_dump(self, mode="json", exclude_none=False):
|
|
result = dict(self._kwargs)
|
|
if exclude_none:
|
|
result = {k: v for k, v in result.items() if v is not None}
|
|
return result
|
|
|
|
MockCls.__name__ = name
|
|
return MockCls
|
|
|
|
mock_a2a_types = MagicMock()
|
|
mock_a2a_types.MessageSendParams = _make_cls("MessageSendParams")
|
|
mock_a2a_types.SendMessageRequest = _make_cls("SendMessageRequest")
|
|
mock_a2a_types.SendStreamingMessageRequest = _make_cls(
|
|
"SendStreamingMessageRequest"
|
|
)
|
|
return mock_a2a_types
|
|
|
|
|
|
async def _invoke(mock_agent, mock_request, mock_asend_message=None):
    """Run ``invoke_agent_a2a`` for ``mock_agent`` with standard patches applied.

    The patches do the following:
    - ``_get_agent`` returns ``mock_agent``.
    - The agent permission check returns True.
    - ``add_litellm_data_to_request`` becomes an identity function.
    - ``create_a2a_client`` and the proxy-server globals (``general_settings``,
      ``proxy_config``, ``version``) are stubbed.
    - A mock ``a2a.types`` module is installed and the A2A SDK is marked
      available.
    - ``litellm.a2a_protocol.asend_message`` is replaced with an ``AsyncMock``
      so callers can inspect the kwargs it received (notably
      ``agent_extra_headers``).

    Args:
        mock_agent: Agent object handed back by the patched ``_get_agent``.
        mock_request: Mock FastAPI request passed to the endpoint.
        mock_asend_message: Optional object for the patched ``asend_message``
            to return. When ``None`` (the default), a MagicMock whose
            ``model_dump()`` yields a minimal JSON-RPC success payload is used.

    Returns:
        The ``AsyncMock`` that stood in for ``asend_message``.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    mock_user_api_key_dict = UserAPIKeyAuth(api_key="sk-test", user_id="u1")
    mock_fastapi_response = MagicMock()
    mock_a2a_types = _make_a2a_types_module()

    if mock_asend_message is not None:
        mock_response = mock_asend_message
    else:
        mock_response = MagicMock()
        mock_response.model_dump.return_value = {
            "jsonrpc": "2.0",
            "id": "test-id",
            "result": {"status": "success"},
        }

    with (
        patch(
            "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent",
            return_value=mock_agent,
        ),
        patch(
            "litellm.proxy.agent_endpoints.auth.agent_permission_handler.AgentRequestHandler.is_agent_allowed",
            new_callable=AsyncMock,
            return_value=True,
        ),
        patch(
            "litellm.proxy.common_request_processing.add_litellm_data_to_request",
            side_effect=lambda data, **kw: data,
        ),
        patch(
            "litellm.a2a_protocol.asend_message",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_asend,
        patch(
            "litellm.a2a_protocol.create_a2a_client",
            new_callable=AsyncMock,
        ),
        patch(
            "litellm.proxy.proxy_server.general_settings",
            {},
        ),
        patch(
            "litellm.proxy.proxy_server.proxy_config",
            MagicMock(),
        ),
        patch(
            "litellm.proxy.proxy_server.version",
            "1.0.0",
        ),
        # Installed as a module so imports of a2a.types inside the endpoint
        # resolve to the stand-in regardless of whether the SDK is installed.
        patch.dict(
            sys.modules,
            {"a2a": MagicMock(), "a2a.types": mock_a2a_types},
        ),
        patch(
            "litellm.a2a_protocol.main.A2A_SDK_AVAILABLE",
            True,
        ),
    ):
        # Imported inside the patch context so the endpoint module is loaded
        # while the patches above are active.
        from litellm.proxy.agent_endpoints.a2a_endpoints import invoke_agent_a2a

        await invoke_agent_a2a(
            agent_id="test-agent",
            request=mock_request,
            fastapi_response=mock_fastapi_response,
            user_api_key_dict=mock_user_api_key_dict,
        )
    return mock_asend
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_static_headers_forwarded():
    """Headers from the agent's static_headers reach asend_message unchanged."""
    agent = _make_mock_agent(static_headers={"Authorization": "Bearer token123"})
    request = _make_mock_request()

    asend = await _invoke(agent, request, None)

    forwarded = asend.call_args.kwargs.get("agent_extra_headers")
    assert forwarded is not None, "agent_extra_headers should not be None"
    assert forwarded.get("Authorization") == "Bearer token123"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dynamic_headers_forwarded():
    """Client request headers named in the agent's extra_headers are passed through."""
    agent = _make_mock_agent(extra_headers=["x-api-key"])
    request = _make_mock_request(extra_headers={"x-api-key": "secret"})

    asend = await _invoke(agent, request, None)

    forwarded = asend.call_args.kwargs.get("agent_extra_headers")
    assert forwarded is not None
    assert forwarded.get("x-api-key") == "secret"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_static_overrides_dynamic():
    """A header set both statically and by the client keeps its static value."""
    agent = _make_mock_agent(
        static_headers={"Authorization": "Bearer static-token"},
        extra_headers=["Authorization"],
    )
    # The client supplies a conflicting Authorization value, which must lose.
    request = _make_mock_request(
        extra_headers={"Authorization": "Bearer dynamic-token"}
    )

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("Authorization") == "Bearer static-token"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_headers():
    """With no header configuration, asend_message gets agent_extra_headers=None."""
    # Agent built with neither static_headers nor extra_headers.
    asend = await _invoke(_make_mock_agent(), _make_mock_request(), None)

    assert asend.call_args.kwargs.get("agent_extra_headers") is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Convention-based x-a2a-{agent_id/name}-{header_name} tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convention_header_by_agent_name():
    """A client header x-a2a-<agent_name>-<header> is forwarded as <header>."""
    agent = _make_mock_agent()
    agent.agent_name = "my-agent"
    request = _make_mock_request(
        extra_headers={"x-a2a-my-agent-authorization": "Bearer conv-token"}
    )

    forwarded = (await _invoke(agent, request, None)).call_args.kwargs.get(
        "agent_extra_headers"
    )

    assert forwarded is not None
    assert forwarded.get("authorization") == "Bearer conv-token"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convention_header_by_agent_id():
    """A client header x-a2a-<agent_id>-<header> is forwarded even when the name differs."""
    agent = _make_mock_agent()
    agent.agent_id = "abc-123"
    agent.agent_name = "other-name"
    request = _make_mock_request(
        extra_headers={"x-a2a-abc-123-x-api-key": "id-secret"}
    )

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("x-api-key") == "id-secret"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convention_header_static_still_wins():
    """A static header beats the same header supplied via the x-a2a- convention."""
    agent = _make_mock_agent(static_headers={"authorization": "Bearer static-wins"})
    agent.agent_name = "my-agent"
    request = _make_mock_request(
        extra_headers={"x-a2a-my-agent-authorization": "Bearer conv-value"}
    )

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("authorization") == "Bearer static-wins"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_convention_unrelated_prefix_not_forwarded():
    """An x-a2a- header addressed to some other agent is not forwarded."""
    agent = _make_mock_agent()
    agent.agent_id = "agent-abc"
    agent.agent_name = "my-agent"
    # "other-agent" matches neither this agent's id nor its name.
    request = _make_mock_request(
        extra_headers={"x-a2a-other-agent-authorization": "Bearer wrong"}
    )

    asend = await _invoke(agent, request, None)

    assert asend.call_args.kwargs.get("agent_extra_headers") is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Databricks App OAuth M2M injection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _mock_databricks_token_client(access_token="dbx-oauth-token"):
    """Build a stand-in for the async HTTP client used to mint Databricks OAuth tokens.

    The client is autospecced from litellm's ``AsyncHTTPHandler``, the type
    ``get_async_httpx_client`` returns. As a result, ``post`` enforces the real
    call signature: an unsupported keyword such as ``auth=`` raises
    ``TypeError``, whereas a bare ``MagicMock`` would silently accept it.
    Awaiting ``post`` yields a response whose ``json()`` returns
    ``{"access_token": access_token, "expires_in": 3600}``.
    """
    from unittest.mock import create_autospec

    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    response = MagicMock()
    # No explicit raise_for_status stub: AsyncHTTPHandler.post raises on HTTP
    # errors itself, so code under test should not need to call it.
    response.json = MagicMock(
        return_value={"access_token": access_token, "expires_in": 3600}
    )

    client = create_autospec(AsyncHTTPHandler, instance=True)
    # post is an async method, so autospec makes it an AsyncMock; its
    # return_value is what `await client.post(...)` evaluates to.
    client.post.return_value = response
    return client
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_databricks_oauth_header_injected():
    """A databricks_oauth block mints an outbound Bearer Authorization header."""
    from litellm.proxy.agent_endpoints import databricks_oauth

    token_cache = databricks_oauth.databricks_app_oauth_token_cache
    # Start empty so the token must be minted through the mocked client.
    token_cache.flush_cache()

    mock_agent = _make_mock_agent()
    mock_agent.litellm_params = {
        "databricks_oauth": {
            "client_id": "cid",
            "client_secret": "secret",
            "workspace_url": "https://dbc.cloud.databricks.com",
        }
    }
    mock_request = _make_mock_request()
    token_client = _mock_databricks_token_client("minted-token")

    try:
        with patch(
            "litellm.proxy.agent_endpoints.databricks_oauth.get_async_httpx_client",
            return_value=token_client,
        ):
            mock_asend = await _invoke(mock_agent, mock_request, None)
    finally:
        # The token cache is module-level. Drop the minted token so it cannot
        # be served to later tests that reuse the same client credentials.
        token_cache.flush_cache()

    # The cache was empty, so exactly one mint must hit the token endpoint.
    token_client.post.assert_awaited_once()
    headers = mock_asend.call_args.kwargs.get("agent_extra_headers")
    assert headers is not None
    assert headers.get("Authorization") == "Bearer minted-token"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_databricks_oauth_overrides_static_authorization():
    """The minted OAuth token wins over a statically configured Authorization."""
    from litellm.proxy.agent_endpoints import databricks_oauth

    token_cache = databricks_oauth.databricks_app_oauth_token_cache
    # Start empty so the token must be minted through the mocked client.
    token_cache.flush_cache()

    mock_agent = _make_mock_agent(static_headers={"Authorization": "Bearer static-pat"})
    mock_agent.litellm_params = {
        "databricks_oauth": {
            "client_id": "cid",
            "client_secret": "secret",
            "workspace_url": "https://dbc.cloud.databricks.com",
        }
    }
    mock_request = _make_mock_request()
    token_client = _mock_databricks_token_client("oauth-wins")

    try:
        with patch(
            "litellm.proxy.agent_endpoints.databricks_oauth.get_async_httpx_client",
            return_value=token_client,
        ):
            mock_asend = await _invoke(mock_agent, mock_request, None)
    finally:
        # The token cache is module-level. Drop the minted token so it cannot
        # be served to later tests that reuse the same client credentials.
        token_cache.flush_cache()

    # The cache was empty, so exactly one mint must hit the token endpoint.
    token_client.post.assert_awaited_once()
    headers = mock_asend.call_args.kwargs.get("agent_extra_headers")
    assert headers is not None
    assert headers.get("Authorization") == "Bearer oauth-wins"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_non_databricks_agent_skips_oauth_resolution():
    """Without a databricks_oauth block the Databricks token resolver is never called."""
    agent = _make_mock_agent(static_headers={"x-custom": "v"})
    agent.litellm_params = {"require_trace_id_on_calls_to_agent": False}
    request = _make_mock_request()

    resolver_target = (
        "litellm.proxy.agent_endpoints.a2a_endpoints.resolve_databricks_app_auth_header"
    )
    with patch(resolver_target, new_callable=AsyncMock) as resolver:
        asend = await _invoke(agent, request, None)

    resolver.assert_not_called()
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")
    # Only the static header is forwarded; no Authorization header is minted.
    assert forwarded == {"x-custom": "v"}
    assert "Authorization" not in forwarded
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Direct unit test for the merge utility
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_merge_agent_headers_util_dynamic_only():
    """Dynamic headers alone are returned as-is."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    expected = {"x-key": "val"}
    assert merge_agent_headers(dynamic_headers={"x-key": "val"}) == expected
|
|
|
|
|
|
def test_merge_agent_headers_util_static_only():
    """Static headers alone are returned as-is."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    expected = {"Authorization": "Bearer tok"}
    assert merge_agent_headers(static_headers={"Authorization": "Bearer tok"}) == expected
|
|
|
|
|
|
def test_merge_agent_headers_util_static_wins():
    """On a key collision the static value is kept; non-colliding dynamic keys survive."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    dynamic = {"Authorization": "dynamic", "x-extra": "d"}
    static = {"Authorization": "static"}
    merged = merge_agent_headers(dynamic_headers=dynamic, static_headers=static)

    assert merged == {"Authorization": "static", "x-extra": "d"}
|
|
|
|
|
|
def test_merge_agent_headers_util_none_returns_none():
    """Calling with no headers at all yields None."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    assert merge_agent_headers() is None
|
|
|
|
|
|
def test_merge_agent_headers_util_empty_dicts_returns_none():
    """Empty static and dynamic dicts are treated the same as no headers: None."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    assert merge_agent_headers(dynamic_headers={}, static_headers={}) is None
|