Files
litellm/tests/test_litellm/proxy/agent_endpoints/test_agent_headers.py
T
Mateo Wang 778a7f752d Support OAuth M2M for Databricks Apps A2A agents (#29586)
* Add OAuth M2M support for A2A agents targeting Databricks Apps

Databricks App endpoints reject static bearer tokens and require a
short-lived OAuth token minted via the workspace OIDC token endpoint.
A2A agents could previously only authenticate outbound with static_headers
or client header passthrough, so Databricks App agents could not be
registered.

Agents configured with a databricks_oauth block in litellm_params now mint
and cache a client_credentials token and attach it as the outbound
Authorization header on both message/send and message/stream calls,
overriding any statically configured Authorization.

* Add tests covering Databricks App OAuth token error paths

Cover the HTTP status error, transport error, non-object JSON body, and
invalid expires_in fallback branches in the token cache so the failure
handling is locked in by regression tests.

* Harden Databricks App OAuth token cache

Cap the cache TTL at the token's own lifetime so a token whose validity is
shorter than the refresh buffer is never cached and served stale; include a
digest of client_secret in the cache key so a rotated secret mints a fresh
token instead of reusing the old one; and prune the per-key lock when its
cached token is evicted so the lock map stays bounded by the live key set.

* Clear per-key locks on Databricks OAuth cache flush

* fix(a2a/databricks): mint OAuth token via Basic auth header, not unsupported auth= kwarg

litellm's AsyncHTTPHandler.post (what get_async_httpx_client returns) has no
auth parameter, so minting a Databricks App OAuth token raised
"AsyncHTTPHandler.post() got an unexpected keyword argument 'auth'" before any
network call ever left the proxy, breaking the feature end to end. The handler
also calls raise_for_status() internally and re-raises a MaskedHTTPStatusError
(a subclass of httpx.HTTPStatusError), so the explicit raise_for_status() after
post() was dead code.

Build the HTTP Basic Authorization header by hand and pass it via headers, which
is what the Databricks workspace OIDC token endpoint documents for client
authentication. The token-cache tests now model the real handler contract with
create_autospec so the rejected auth= signature is enforced; the previous mocks
accepted any kwargs and silently hid the bug.

Co-authored-by: Mateo Wang <mateo-berri@users.noreply.github.com>

* Prune Databricks OAuth lock on the short-lived-token path

When expires_in is below the refresh buffer the token is intentionally
not cached, so _remove_key never runs for that key and the per-key lock
created by _get_lock leaked permanently. Drop the lock in that branch so
_locks stays bounded by the live key set, and assert the cleanup in the
short-lived-token test.

* Gate A2A Databricks OAuth on the databricks_oauth block at the call site

Make the gating explicit where the header is applied so it is clear that only
agents configured with a databricks_oauth block enter the OAuth path; every
other agent is left untouched. Add a regression test asserting a non-Databricks
agent never invokes the token resolver.

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Mateo Wang <mateo-berri@users.noreply.github.com>
2026-06-04 23:03:37 -07:00

441 lines
15 KiB
Python

"""
Unit tests for A2A agent custom header forwarding.
Tests cover:
- Static headers forwarded to backend agent
- Dynamic headers extracted from client request and forwarded
- Static headers win over dynamic on conflict
- No headers configured — existing behavior unchanged
- Convention-based x-a2a-{agent_id|agent_name}-{header_name} forwarding
- Databricks App OAuth M2M token injected as the outbound Authorization header
- merge_agent_headers utility
"""
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Helper: build a minimal mock agent
# ---------------------------------------------------------------------------
def _make_mock_agent(
static_headers=None,
extra_headers=None,
url="http://backend-agent:10001",
):
mock_agent = MagicMock()
mock_agent.agent_id = "agent-123"
mock_agent.agent_card_params = {"url": url, "name": "Test Agent"}
mock_agent.litellm_params = {}
mock_agent.static_headers = static_headers or {}
mock_agent.extra_headers = extra_headers or []
return mock_agent
def _make_mock_request(extra_headers=None, method="message/send"):
"""Build a mock FastAPI Request with configurable headers."""
mock_request = MagicMock()
headers = {"content-type": "application/json"}
if extra_headers:
headers.update(extra_headers)
mock_request.headers = headers
mock_request.json = AsyncMock(
return_value={
"jsonrpc": "2.0",
"id": "test-id",
"method": method,
"params": {
"message": {
"role": "user",
"parts": [{"kind": "text", "text": "Hello"}],
"messageId": "msg-123",
}
},
}
)
return mock_request
def _make_a2a_types_module():
"""Return (module, MessageSendParams, SendMessageRequest, SendStreamingMessageRequest)."""
try:
from a2a.types import (
MessageSendParams,
SendMessageRequest,
SendStreamingMessageRequest,
)
mock_a2a_types = MagicMock()
mock_a2a_types.MessageSendParams = MessageSendParams
mock_a2a_types.SendMessageRequest = SendMessageRequest
mock_a2a_types.SendStreamingMessageRequest = SendStreamingMessageRequest
return mock_a2a_types
except ImportError:
pass
def _make_cls(name):
class MockCls:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
self._kwargs = kwargs
def model_dump(self, mode="json", exclude_none=False):
result = dict(self._kwargs)
if exclude_none:
result = {k: v for k, v in result.items() if v is not None}
return result
MockCls.__name__ = name
return MockCls
mock_a2a_types = MagicMock()
mock_a2a_types.MessageSendParams = _make_cls("MessageSendParams")
mock_a2a_types.SendMessageRequest = _make_cls("SendMessageRequest")
mock_a2a_types.SendStreamingMessageRequest = _make_cls(
"SendStreamingMessageRequest"
)
return mock_a2a_types
async def _invoke(mock_agent, mock_request, mock_asend_message=None):
    """Run ``invoke_agent_a2a`` for one request with the standard patches applied.

    Agent lookup is patched to return ``mock_agent``. The permission check is
    patched to allow the call. ``litellm.a2a_protocol.asend_message`` is
    replaced by an ``AsyncMock`` so the outbound call arguments can be
    inspected.

    Args:
        mock_agent: Agent object returned by the patched ``_get_agent``.
        mock_request: Mock FastAPI request passed to the endpoint.
        mock_asend_message: Unused. It is kept, now defaulting to ``None``, so
            existing positional callers (which all pass ``None``) still work.

    Returns:
        The ``AsyncMock`` that replaced ``asend_message``. Inspect its
        ``call_args.kwargs`` (for example ``agent_extra_headers``).
    """
    from litellm.proxy._types import UserAPIKeyAuth

    mock_user_api_key_dict = UserAPIKeyAuth(api_key="sk-test", user_id="u1")
    mock_fastapi_response = MagicMock()
    mock_a2a_types = _make_a2a_types_module()
    # Canned JSON-RPC success payload returned by the patched asend_message.
    mock_response = MagicMock()
    mock_response.model_dump.return_value = {
        "jsonrpc": "2.0",
        "id": "test-id",
        "result": {"status": "success"},
    }
    with (
        patch(
            "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent",
            return_value=mock_agent,
        ),
        patch(
            "litellm.proxy.agent_endpoints.auth.agent_permission_handler.AgentRequestHandler.is_agent_allowed",
            new_callable=AsyncMock,
            return_value=True,
        ),
        patch(
            "litellm.proxy.common_request_processing.add_litellm_data_to_request",
            side_effect=lambda data, **kw: data,
        ),
        patch(
            "litellm.a2a_protocol.asend_message",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_asend,
        patch(
            "litellm.a2a_protocol.create_a2a_client",
            new_callable=AsyncMock,
        ),
        patch(
            "litellm.proxy.proxy_server.general_settings",
            {},
        ),
        patch(
            "litellm.proxy.proxy_server.proxy_config",
            MagicMock(),
        ),
        patch(
            "litellm.proxy.proxy_server.version",
            "1.0.0",
        ),
        # Make ``a2a`` / ``a2a.types`` resolve to the stand-ins built above.
        patch.dict(
            sys.modules,
            {"a2a": MagicMock(), "a2a.types": mock_a2a_types},
        ),
        patch(
            "litellm.a2a_protocol.main.A2A_SDK_AVAILABLE",
            True,
        ),
    ):
        # Imported while the patches are active (presumably so any a2a imports
        # it triggers resolve to the stand-ins above).
        from litellm.proxy.agent_endpoints.a2a_endpoints import invoke_agent_a2a

        await invoke_agent_a2a(
            agent_id="test-agent",
            request=mock_request,
            fastapi_response=mock_fastapi_response,
            user_api_key_dict=mock_user_api_key_dict,
        )
    return mock_asend
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_static_headers_forwarded():
    """An agent's static headers reach asend_message via agent_extra_headers."""
    agent = _make_mock_agent(static_headers={"Authorization": "Bearer token123"})
    request = _make_mock_request()

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None, "agent_extra_headers should not be None"
    assert forwarded.get("Authorization") == "Bearer token123"
@pytest.mark.asyncio
async def test_dynamic_headers_forwarded():
    """Header names listed in extra_headers are copied from the client request."""
    agent = _make_mock_agent(extra_headers=["x-api-key"])
    request = _make_mock_request(extra_headers={"x-api-key": "secret"})

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("x-api-key") == "secret"
@pytest.mark.asyncio
async def test_static_overrides_dynamic():
    """A statically configured header beats the same header sent by the client."""
    agent = _make_mock_agent(
        static_headers={"Authorization": "Bearer static-token"},
        extra_headers=["Authorization"],
    )
    # The client attempts to supply its own Authorization value.
    client_headers = {"Authorization": "Bearer dynamic-token"}
    request = _make_mock_request(extra_headers=client_headers)

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("Authorization") == "Bearer static-token"
@pytest.mark.asyncio
async def test_no_headers():
    """With no header configuration, asend_message gets agent_extra_headers=None."""
    asend = await _invoke(_make_mock_agent(), _make_mock_request(), None)
    assert asend.call_args.kwargs.get("agent_extra_headers") is None
# ---------------------------------------------------------------------------
# Convention-based x-a2a-{agent_id/name}-{header_name} tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_convention_header_by_agent_name():
    """x-a2a-{agent_name}-{header} is forwarded, keyed by the agent's name alias."""
    agent = _make_mock_agent()
    agent.agent_name = "my-agent"
    convention_headers = {"x-a2a-my-agent-authorization": "Bearer conv-token"}
    request = _make_mock_request(extra_headers=convention_headers)

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("authorization") == "Bearer conv-token"
@pytest.mark.asyncio
async def test_convention_header_by_agent_id():
    """x-a2a-{agent_id}-{header} is forwarded, keyed by the agent's id."""
    agent = _make_mock_agent()
    agent.agent_id = "abc-123"
    agent.agent_name = "other-name"
    convention_headers = {"x-a2a-abc-123-x-api-key": "id-secret"}
    request = _make_mock_request(extra_headers=convention_headers)

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("x-api-key") == "id-secret"
@pytest.mark.asyncio
async def test_convention_header_static_still_wins():
    """A static header still overrides a convention-based dynamic header."""
    agent = _make_mock_agent(static_headers={"authorization": "Bearer static-wins"})
    agent.agent_name = "my-agent"
    convention_headers = {"x-a2a-my-agent-authorization": "Bearer conv-value"}
    request = _make_mock_request(extra_headers=convention_headers)

    asend = await _invoke(agent, request, None)
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")

    assert forwarded is not None
    assert forwarded.get("authorization") == "Bearer static-wins"
@pytest.mark.asyncio
async def test_convention_unrelated_prefix_not_forwarded():
    """An x-a2a- header that targets some other agent is ignored."""
    agent = _make_mock_agent()
    agent.agent_id = "agent-abc"
    agent.agent_name = "my-agent"
    foreign_headers = {"x-a2a-other-agent-authorization": "Bearer wrong"}
    request = _make_mock_request(extra_headers=foreign_headers)

    asend = await _invoke(agent, request, None)

    assert asend.call_args.kwargs.get("agent_extra_headers") is None
# ---------------------------------------------------------------------------
# Databricks App OAuth M2M injection
# ---------------------------------------------------------------------------
def _mock_databricks_token_client(access_token="dbx-oauth-token"):
response = MagicMock()
response.raise_for_status = MagicMock()
response.json = MagicMock(
return_value={"access_token": access_token, "expires_in": 3600}
)
client = MagicMock()
client.post = AsyncMock(return_value=response)
return client
@pytest.mark.asyncio
async def test_databricks_oauth_header_injected():
    """A databricks_oauth block mints an outbound Bearer Authorization header."""
    from litellm.proxy.agent_endpoints import databricks_oauth

    token_cache = databricks_oauth.databricks_app_oauth_token_cache
    # Start from an empty cache so the token is minted through the mock client.
    token_cache.flush_cache()
    mock_agent = _make_mock_agent()
    mock_agent.litellm_params = {
        "databricks_oauth": {
            "client_id": "cid",
            "client_secret": "secret",
            "workspace_url": "https://dbc.cloud.databricks.com",
        }
    }
    mock_request = _make_mock_request()
    try:
        with patch(
            "litellm.proxy.agent_endpoints.databricks_oauth.get_async_httpx_client",
            return_value=_mock_databricks_token_client("minted-token"),
        ):
            mock_asend = await _invoke(mock_agent, mock_request, None)
    finally:
        # The cache lives at module level: drop the fake token so it cannot
        # leak into later tests that reuse the same credentials.
        token_cache.flush_cache()
    headers = mock_asend.call_args.kwargs.get("agent_extra_headers")
    assert headers is not None
    assert headers.get("Authorization") == "Bearer minted-token"
@pytest.mark.asyncio
async def test_databricks_oauth_overrides_static_authorization():
    """The minted OAuth token wins over a statically configured Authorization."""
    from litellm.proxy.agent_endpoints import databricks_oauth

    token_cache = databricks_oauth.databricks_app_oauth_token_cache
    # Start from an empty cache so the token is minted through the mock client.
    token_cache.flush_cache()
    mock_agent = _make_mock_agent(static_headers={"Authorization": "Bearer static-pat"})
    mock_agent.litellm_params = {
        "databricks_oauth": {
            "client_id": "cid",
            "client_secret": "secret",
            "workspace_url": "https://dbc.cloud.databricks.com",
        }
    }
    mock_request = _make_mock_request()
    try:
        with patch(
            "litellm.proxy.agent_endpoints.databricks_oauth.get_async_httpx_client",
            return_value=_mock_databricks_token_client("oauth-wins"),
        ):
            mock_asend = await _invoke(mock_agent, mock_request, None)
    finally:
        # The cache lives at module level: drop the fake token so it cannot
        # leak into later tests that reuse the same credentials.
        token_cache.flush_cache()
    headers = mock_asend.call_args.kwargs.get("agent_extra_headers")
    assert headers is not None
    assert headers.get("Authorization") == "Bearer oauth-wins"
@pytest.mark.asyncio
async def test_non_databricks_agent_skips_oauth_resolution():
    """An agent without a databricks_oauth block never calls the OAuth resolver."""
    agent = _make_mock_agent(static_headers={"x-custom": "v"})
    agent.litellm_params = {"require_trace_id_on_calls_to_agent": False}
    request = _make_mock_request()

    resolver_target = (
        "litellm.proxy.agent_endpoints.a2a_endpoints.resolve_databricks_app_auth_header"
    )
    with patch(resolver_target, new_callable=AsyncMock) as resolver:
        asend = await _invoke(agent, request, None)

    resolver.assert_not_called()
    forwarded = asend.call_args.kwargs.get("agent_extra_headers")
    # Only the static header is forwarded; no Authorization was minted.
    assert forwarded == {"x-custom": "v"}
    assert "Authorization" not in forwarded
# ---------------------------------------------------------------------------
# Direct unit test for the merge utility
# ---------------------------------------------------------------------------
def test_merge_agent_headers_util_dynamic_only():
    """Dynamic headers alone are returned unchanged."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    assert merge_agent_headers(dynamic_headers={"x-key": "val"}) == {"x-key": "val"}
def test_merge_agent_headers_util_static_only():
    """Static headers alone are returned unchanged."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    expected = {"Authorization": "Bearer tok"}
    assert merge_agent_headers(static_headers={"Authorization": "Bearer tok"}) == expected
def test_merge_agent_headers_util_static_wins():
    """On a key clash the static value wins; non-clashing dynamic keys survive."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    dynamic = {"Authorization": "dynamic", "x-extra": "d"}
    static = {"Authorization": "static"}
    merged = merge_agent_headers(dynamic_headers=dynamic, static_headers=static)
    assert merged == {"Authorization": "static", "x-extra": "d"}
def test_merge_agent_headers_util_none_returns_none():
    """Calling with no arguments yields None rather than an empty dict."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    assert merge_agent_headers() is None
def test_merge_agent_headers_util_empty_dicts_returns_none():
    """Two empty header dicts also collapse to None."""
    from litellm.proxy.agent_endpoints.utils import merge_agent_headers

    assert merge_agent_headers(dynamic_headers={}, static_headers={}) is None