diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 86a89f1638..f2dcdf9f48 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -32,7 +32,8 @@ from litellm.proxy._experimental.mcp_server.utils import ( LITELLM_MCP_SERVER_VERSION) from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.ip_address_utils import IPAddressUtils -from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup +from litellm.proxy.litellm_pre_call_utils import (LiteLLMProxyRequestSetup, + get_chain_id_from_headers) from litellm.types.mcp import MCPAuth from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall @@ -313,6 +314,11 @@ if MCP_AVAILABLE: try: # Create a body date for logging body_data = {"name": name, "arguments": arguments} + # Set trace/session id from raw_headers so spend logs and logging_obj stay consistent (same as A2A) + chain_id = get_chain_id_from_headers(raw_headers) + if chain_id: + body_data["litellm_trace_id"] = chain_id + body_data["litellm_session_id"] = chain_id request = Request( scope={ @@ -863,6 +869,10 @@ if MCP_AVAILABLE: # This is intentionally minimal: only async_success_handler / post_call_failure_hook rules_obj = Rules() list_tools_call_id = str(uuid.uuid4()) + # Derive trace_id from raw_headers when not explicitly passed (same as A2A / MCP call_tool) + effective_litellm_trace_id = litellm_trace_id or get_chain_id_from_headers( + raw_headers + ) spend_logs_metadata: Dict[str, Any] = { "mcp_operation": "list_tools", } @@ -875,7 +885,7 @@ if MCP_AVAILABLE: "model": "MCP: list_tools", "call_type": CallTypes.list_mcp_tools.value, "litellm_call_id": list_tools_call_id, - "litellm_trace_id": litellm_trace_id, + "litellm_trace_id": effective_litellm_trace_id, "metadata": { "spend_logs_metadata": spend_logs_metadata, }, diff --git a/litellm/proxy/db/tool_registry_writer.py b/litellm/proxy/db/tool_registry_writer.py index 064bf2ff65..445d215193 100644 --- a/litellm/proxy/db/tool_registry_writer.py +++ b/litellm/proxy/db/tool_registry_writer.py @@ -86,11 +86,35 @@ async def batch_upsert_tools( verbose_proxy_logger.error("tool_registry_writer batch_upsert_tools error: %s", e) +async def _get_agent_ids_for_key_hashes( + prisma_client: "PrismaClient", + key_hashes: List[str], +) -> Dict[str, str]: + """Resolve agent_id from key table for each key_hash. Returns map token -> agent_id.""" + if not key_hashes: + return {} + try: + key_records = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": key_hashes}}, + select={"token": True, "agent_id": True}, + ) + return { + k.token: k.agent_id + for k in key_records + if k.agent_id is not None + } + except Exception as e: + verbose_proxy_logger.debug( + "tool_registry_writer _get_agent_ids_for_key_hashes error: %s", e + ) + return {} + + async def list_tools( prisma_client: "PrismaClient", call_policy: Optional[ToolCallPolicy] = None, ) -> List[LiteLLM_ToolTableRow]: - """Return all tools, optionally filtered by call_policy.""" + """Return all tools, optionally filtered by call_policy. Enriches each row with agent_id from key table.""" try: if call_policy is not None: rows = await prisma_client.db.query_raw( @@ -105,7 +129,12 @@ async def list_tools( 'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by ' 'FROM "LiteLLM_ToolTable" ORDER BY created_at DESC', ) - return [_row_to_model(row) for row in rows] + tools = [_row_to_model(row) for row in rows] + key_hashes = list({t.key_hash for t in tools if t.key_hash}) + key_to_agent = await _get_agent_ids_for_key_hashes(prisma_client, key_hashes) + for t in tools: + t.agent_id = key_to_agent.get(t.key_hash) if t.key_hash else None + return tools except Exception as e: verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e) return [] @@ -115,7 +144,7 @@ async def get_tool( prisma_client: "PrismaClient", tool_name: str, ) -> Optional[LiteLLM_ToolTableRow]: - """Return a single tool row by tool_name.""" + """Return a single tool row by tool_name. Enriches with agent_id from key table if key_hash is set.""" try: rows = await prisma_client.db.query_raw( 'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, ' @@ -125,7 +154,13 @@ async def get_tool( ) if not rows: return None - return _row_to_model(rows[0]) + tool = _row_to_model(rows[0]) + if tool.key_hash: + key_to_agent = await _get_agent_ids_for_key_hashes( + prisma_client, [tool.key_hash] + ) + tool.agent_id = key_to_agent.get(tool.key_hash) + return tool except Exception as e: verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e) return None diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9861026ba3..5b3b7049a6 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -82,6 +82,25 @@ def _get_metadata_variable_name(request: Request) -> str: return "metadata" +def get_chain_id_from_headers(headers: Optional[Dict[str, str]]) -> Optional[str]: + """ + Extract chain id for call chaining from request headers. + + x-litellm-trace-id and x-litellm-session-id are interchangeable; when both + are present, x-litellm-trace-id takes precedence. Header keys are matched + case-insensitively so this works with raw header dicts from any transport. + + Used by MCP (and other paths that have raw_headers but no Request) to set + litellm_trace_id/litellm_session_id for spend logs and logging consistency. + """ + if not headers: + return None + normalized = {k.lower(): v for k, v in headers.items() if isinstance(k, str)} + return normalized.get("x-litellm-trace-id") or normalized.get( + "x-litellm-session-id" + ) + + def safe_add_api_version_from_query_params(data: dict, request: Request): try: if hasattr(request, "query_params"): diff --git a/litellm/types/tool_management.py b/litellm/types/tool_management.py index 8704ff2775..12fe58029e 100644 --- a/litellm/types/tool_management.py +++ b/litellm/types/tool_management.py @@ -20,6 +20,7 @@ class LiteLLM_ToolTableRow(BaseModel): key_hash: Optional[str] = None team_id: Optional[str] = None key_alias: Optional[str] = None + agent_id: Optional[str] = None # resolved from key table (key_hash -> key.agent_id) created_at: Optional[datetime] = None updated_at: Optional[datetime] = None created_by: Optional[str] = None