fix: working backend agent tracing with MCP tool calls

This commit is contained in:
Krrish Dholakia
2026-03-03 11:28:36 -08:00
parent bde027f0d4
commit 6feb9babc1
4 changed files with 71 additions and 6 deletions
@@ -32,7 +32,8 @@ from litellm.proxy._experimental.mcp_server.utils import (
LITELLM_MCP_SERVER_VERSION)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy.litellm_pre_call_utils import (LiteLLMProxyRequestSetup,
get_chain_id_from_headers)
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall
@@ -313,6 +314,11 @@ if MCP_AVAILABLE:
try:
# Create a body date for logging
body_data = {"name": name, "arguments": arguments}
# Set trace/session id from raw_headers so spend logs and logging_obj stay consistent (same as A2A)
chain_id = get_chain_id_from_headers(raw_headers)
if chain_id:
body_data["litellm_trace_id"] = chain_id
body_data["litellm_session_id"] = chain_id
request = Request(
scope={
@@ -863,6 +869,10 @@ if MCP_AVAILABLE:
# This is intentionally minimal: only async_success_handler / post_call_failure_hook
rules_obj = Rules()
list_tools_call_id = str(uuid.uuid4())
# Derive trace_id from raw_headers when not explicitly passed (same as A2A / MCP call_tool)
effective_litellm_trace_id = litellm_trace_id or get_chain_id_from_headers(
raw_headers
)
spend_logs_metadata: Dict[str, Any] = {
"mcp_operation": "list_tools",
}
@@ -875,7 +885,7 @@ if MCP_AVAILABLE:
"model": "MCP: list_tools",
"call_type": CallTypes.list_mcp_tools.value,
"litellm_call_id": list_tools_call_id,
"litellm_trace_id": litellm_trace_id,
"litellm_trace_id": effective_litellm_trace_id,
"metadata": {
"spend_logs_metadata": spend_logs_metadata,
},
+39 -4
View File
@@ -86,11 +86,35 @@ async def batch_upsert_tools(
verbose_proxy_logger.error("tool_registry_writer batch_upsert_tools error: %s", e)
async def _get_agent_ids_for_key_hashes(
prisma_client: "PrismaClient",
key_hashes: List[str],
) -> Dict[str, str]:
"""Resolve agent_id from key table for each key_hash. Returns map token -> agent_id."""
if not key_hashes:
return {}
try:
key_records = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": key_hashes}},
select={"token": True, "agent_id": True},
)
return {
k.token: k.agent_id
for k in key_records
if k.agent_id is not None
}
except Exception as e:
verbose_proxy_logger.debug(
"tool_registry_writer _get_agent_ids_for_key_hashes error: %s", e
)
return {}
async def list_tools(
prisma_client: "PrismaClient",
call_policy: Optional[ToolCallPolicy] = None,
) -> List[LiteLLM_ToolTableRow]:
"""Return all tools, optionally filtered by call_policy."""
"""Return all tools, optionally filtered by call_policy. Enriches each row with agent_id from key table."""
try:
if call_policy is not None:
rows = await prisma_client.db.query_raw(
@@ -105,7 +129,12 @@ async def list_tools(
'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by '
'FROM "LiteLLM_ToolTable" ORDER BY created_at DESC',
)
return [_row_to_model(row) for row in rows]
tools = [_row_to_model(row) for row in rows]
key_hashes = list({t.key_hash for t in tools if t.key_hash})
key_to_agent = await _get_agent_ids_for_key_hashes(prisma_client, key_hashes)
for t in tools:
t.agent_id = key_to_agent.get(t.key_hash) if t.key_hash else None
return tools
except Exception as e:
verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
return []
@@ -115,7 +144,7 @@ async def get_tool(
prisma_client: "PrismaClient",
tool_name: str,
) -> Optional[LiteLLM_ToolTableRow]:
"""Return a single tool row by tool_name."""
"""Return a single tool row by tool_name. Enriches with agent_id from key table if key_hash is set."""
try:
rows = await prisma_client.db.query_raw(
'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, '
@@ -125,7 +154,13 @@ async def get_tool(
)
if not rows:
return None
return _row_to_model(rows[0])
tool = _row_to_model(rows[0])
if tool.key_hash:
key_to_agent = await _get_agent_ids_for_key_hashes(
prisma_client, [tool.key_hash]
)
tool.agent_id = key_to_agent.get(tool.key_hash)
return tool
except Exception as e:
verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
return None
+19
View File
@@ -82,6 +82,25 @@ def _get_metadata_variable_name(request: Request) -> str:
return "metadata"
def get_chain_id_from_headers(headers: Optional[Dict[str, str]]) -> Optional[str]:
"""
Extract chain id for call chaining from request headers.
x-litellm-trace-id and x-litellm-session-id are interchangeable; when both
are present, x-litellm-trace-id takes precedence. Header keys are matched
case-insensitively so this works with raw header dicts from any transport.
Used by MCP (and other paths that have raw_headers but no Request) to set
litellm_trace_id/litellm_session_id for spend logs and logging consistency.
"""
if not headers:
return None
normalized = {k.lower(): v for k, v in headers.items() if isinstance(k, str)}
return normalized.get("x-litellm-trace-id") or normalized.get(
"x-litellm-session-id"
)
def safe_add_api_version_from_query_params(data: dict, request: Request):
try:
if hasattr(request, "query_params"):
+1
View File
@@ -20,6 +20,7 @@ class LiteLLM_ToolTableRow(BaseModel):
key_hash: Optional[str] = None
team_id: Optional[str] = None
key_alias: Optional[str] = None
agent_id: Optional[str] = None # resolved from key table (key_hash -> key.agent_id)
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
created_by: Optional[str] = None