diff --git a/enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py b/enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py index e481cdc995..b6c9104b23 100644 --- a/enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py +++ b/enterprise/litellm_enterprise/enterprise_callbacks/pagerduty/pagerduty.py @@ -1,309 +1,311 @@ -""" -PagerDuty Alerting Integration - -Handles two types of alerts: -- High LLM API Failure Rate. Configure X fails in Y seconds to trigger an alert. -- High Number of Hanging LLM Requests. Configure X hangs in Y seconds to trigger an alert. - -Note: This is a Free feature on the regular litellm docker image. - -However, this is under the enterprise license -""" - -import asyncio -import os -from datetime import datetime, timedelta, timezone -from typing import List, Literal, Optional, Union - -from litellm._logging import verbose_logger -from litellm.caching import DualCache -from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting -from litellm.llms.custom_httpx.http_handler import ( - AsyncHTTPHandler, - get_async_httpx_client, - httpxSpecialProvider, -) -from litellm.proxy._types import UserAPIKeyAuth -from litellm.types.integrations.pagerduty import ( - AlertingConfig, - PagerDutyInternalEvent, - PagerDutyPayload, - PagerDutyRequestBody, -) -from litellm.types.utils import ( - CallTypesLiteral, - StandardLoggingPayload, - StandardLoggingPayloadErrorInformation, -) - -PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60 -PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60 -PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60 -PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600 - - -class PagerDutyAlerting(SlackAlerting): - """ - Tracks failed requests and hanging requests separately. - If threshold is crossed for either type, triggers a PagerDuty alert. - """ - - def __init__( - self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs - ): - super().__init__() - _api_key = os.getenv("PAGERDUTY_API_KEY") - if not _api_key: - raise ValueError("PAGERDUTY_API_KEY is not set") - - self.api_key: str = _api_key - alerting_args = alerting_args or {} - self.pagerduty_alerting_args: AlertingConfig = AlertingConfig( - failure_threshold=alerting_args.get( - "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD - ), - failure_threshold_window_seconds=alerting_args.get( - "failure_threshold_window_seconds", - PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS, - ), - hanging_threshold_seconds=alerting_args.get( - "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS - ), - hanging_threshold_window_seconds=alerting_args.get( - "hanging_threshold_window_seconds", - PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, - ), - ) - - # Separate storage for failures vs. hangs - self._failure_events: List[PagerDutyInternalEvent] = [] - self._hanging_events: List[PagerDutyInternalEvent] = [] - - # ------------------ MAIN LOGIC ------------------ # - - async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): - """ - Record a failure event. Only send an alert to PagerDuty if the - configured *failure* threshold is exceeded in the specified window. - """ - now = datetime.now(timezone.utc) - standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( - "standard_logging_object" - ) - if not standard_logging_payload: - raise ValueError( - "standard_logging_object is required for PagerDutyAlerting" - ) - - # Extract error details - error_info: Optional[StandardLoggingPayloadErrorInformation] = ( - standard_logging_payload.get("error_information") or {} - ) - _meta = standard_logging_payload.get("metadata") or {} - - self._failure_events.append( - PagerDutyInternalEvent( - failure_event_type="failed_response", - timestamp=now, - error_class=error_info.get("error_class"), - error_code=error_info.get("error_code"), - error_llm_provider=error_info.get("llm_provider"), - user_api_key_hash=_meta.get("user_api_key_hash"), - user_api_key_alias=_meta.get("user_api_key_alias"), - user_api_key_spend=_meta.get("user_api_key_spend"), - user_api_key_max_budget=_meta.get("user_api_key_max_budget"), - user_api_key_budget_reset_at=_meta.get("user_api_key_budget_reset_at"), - user_api_key_org_id=_meta.get("user_api_key_org_id"), - user_api_key_team_id=_meta.get("user_api_key_team_id"), - user_api_key_user_id=_meta.get("user_api_key_user_id"), - user_api_key_team_alias=_meta.get("user_api_key_team_alias"), - user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"), - user_api_key_user_email=_meta.get("user_api_key_user_email"), - user_api_key_request_route=_meta.get("user_api_key_request_route"), - user_api_key_auth_metadata=_meta.get("user_api_key_auth_metadata"), - ) - ) - - # Prune + Possibly alert - window_seconds = self.pagerduty_alerting_args.get( - "failure_threshold_window_seconds", 60 - ) - threshold = self.pagerduty_alerting_args.get("failure_threshold", 1) - - # If threshold is crossed, send PD alert for failures - await self._send_alert_if_thresholds_crossed( - events=self._failure_events, - window_seconds=window_seconds, - threshold=threshold, - alert_prefix="High LLM API Failure Rate", - ) - - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache: DualCache, - data: dict, - call_type: CallTypesLiteral, - ) -> Optional[Union[Exception, str, dict]]: - """ - Example of detecting hanging requests by waiting a given threshold. - If the request didn't finish by then, we treat it as 'hanging'. - """ - verbose_logger.info("Inside Proxy Logging Pre-call hook!") - asyncio.create_task( - self.hanging_response_handler( - request_data=data, user_api_key_dict=user_api_key_dict - ) - ) - return None - - async def hanging_response_handler( - self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth - ): - """ - Checks if request completed by the time 'hanging_threshold_seconds' elapses. - If not, we classify it as a hanging request. - """ - verbose_logger.debug( - f"Inside Hanging Response Handler!..sleeping for {self.pagerduty_alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds" - ) - await asyncio.sleep( - self.pagerduty_alerting_args.get( - "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS - ) - ) - - if await self._request_is_completed(request_data=request_data): - return # It's not hanging if completed - - # Otherwise, record it as hanging - self._hanging_events.append( - PagerDutyInternalEvent( - failure_event_type="hanging_response", - timestamp=datetime.now(timezone.utc), - error_class="HangingRequest", - error_code="HangingRequest", - error_llm_provider="HangingRequest", - user_api_key_hash=user_api_key_dict.api_key, - user_api_key_alias=user_api_key_dict.key_alias, - user_api_key_spend=user_api_key_dict.spend, - user_api_key_max_budget=user_api_key_dict.max_budget, - user_api_key_budget_reset_at=( - user_api_key_dict.budget_reset_at.isoformat() - if user_api_key_dict.budget_reset_at - else None - ), - user_api_key_org_id=user_api_key_dict.org_id, - user_api_key_team_id=user_api_key_dict.team_id, - user_api_key_user_id=user_api_key_dict.user_id, - user_api_key_team_alias=user_api_key_dict.team_alias, - user_api_key_end_user_id=user_api_key_dict.end_user_id, - user_api_key_user_email=user_api_key_dict.user_email, - user_api_key_request_route=user_api_key_dict.request_route, - user_api_key_auth_metadata=user_api_key_dict.metadata, - ) - ) - - # Prune + Possibly alert - window_seconds = self.pagerduty_alerting_args.get( - "hanging_threshold_window_seconds", - PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, - ) - threshold: int = self.pagerduty_alerting_args.get( - "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS - ) - - # If threshold is crossed, send PD alert for hangs - await self._send_alert_if_thresholds_crossed( - events=self._hanging_events, - window_seconds=window_seconds, - threshold=threshold, - alert_prefix="High Number of Hanging LLM Requests", - ) - - # ------------------ HELPERS ------------------ # - - async def _send_alert_if_thresholds_crossed( - self, - events: List[PagerDutyInternalEvent], - window_seconds: int, - threshold: int, - alert_prefix: str, - ): - """ - 1. Prune old events - 2. If threshold is reached, build alert, send to PagerDuty - 3. Clear those events - """ - cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds) - pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff] - - # Update the reference list - events.clear() - events.extend(pruned) - - # Check threshold - verbose_logger.debug( - f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}" - ) - if len(events) >= threshold: - # Build short summary of last N events - error_summaries = self._build_error_summaries(events, max_errors=5) - alert_message = ( - f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds." - ) - custom_details = {"recent_errors": error_summaries} - - await self.send_alert_to_pagerduty( - alert_message=alert_message, - custom_details=custom_details, - ) - - # Clear them after sending an alert, so we don't spam - events.clear() - - def _build_error_summaries( - self, events: List[PagerDutyInternalEvent], max_errors: int = 5 - ) -> List[PagerDutyInternalEvent]: - """ - Build short text summaries for the last `max_errors`. - Example: "ValueError (code: 500, provider: openai)" - """ - recent = events[-max_errors:] - summaries = [] - for fe in recent: - # If any of these is None, show "N/A" to avoid messing up the summary string - fe.pop("timestamp") - summaries.append(fe) - return summaries - - async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict): - """ - Send [critical] Alert to PagerDuty - - https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api - """ - try: - verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}") - async_client: AsyncHTTPHandler = get_async_httpx_client( - llm_provider=httpxSpecialProvider.LoggingCallback - ) - payload: PagerDutyRequestBody = PagerDutyRequestBody( - payload=PagerDutyPayload( - summary=alert_message, - severity="critical", - source="LiteLLM Alert", - component="LiteLLM", - custom_details=custom_details, - ), - routing_key=self.api_key, - event_action="trigger", - ) - - return await async_client.post( - url="https://events.pagerduty.com/v2/enqueue", - json=dict(payload), - headers={"Content-Type": "application/json"}, - ) - except Exception as e: - verbose_logger.exception(f"Error sending alert to PagerDuty: {e}") +""" +PagerDuty Alerting Integration + +Handles two types of alerts: +- High LLM API Failure Rate. Configure X fails in Y seconds to trigger an alert. +- High Number of Hanging LLM Requests. Configure X hangs in Y seconds to trigger an alert. + +Note: This is a Free feature on the regular litellm docker image. + +However, this is under the enterprise license +""" + +import asyncio +import os +from datetime import datetime, timedelta, timezone +from typing import List, Optional, Union + +from litellm._logging import verbose_logger +from litellm.caching import DualCache +from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.integrations.pagerduty import ( + AlertingConfig, + PagerDutyInternalEvent, + PagerDutyPayload, + PagerDutyRequestBody, +) +from litellm.types.utils import ( + CallTypesLiteral, + StandardLoggingPayload, + StandardLoggingPayloadErrorInformation, +) + +PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60 +PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60 +PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60 +PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600 + + +class PagerDutyAlerting(SlackAlerting): + """ + Tracks failed requests and hanging requests separately. + If threshold is crossed for either type, triggers a PagerDuty alert. + """ + + def __init__( + self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs + ): + super().__init__() + _api_key = os.getenv("PAGERDUTY_API_KEY") + if not _api_key: + raise ValueError("PAGERDUTY_API_KEY is not set") + + self.api_key: str = _api_key + alerting_args = alerting_args or {} + self.pagerduty_alerting_args: AlertingConfig = AlertingConfig( + failure_threshold=alerting_args.get( + "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD + ), + failure_threshold_window_seconds=alerting_args.get( + "failure_threshold_window_seconds", + PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS, + ), + hanging_threshold_seconds=alerting_args.get( + "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ), + hanging_threshold_window_seconds=alerting_args.get( + "hanging_threshold_window_seconds", + PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, + ), + ) + + # Separate storage for failures vs. hangs + self._failure_events: List[PagerDutyInternalEvent] = [] + self._hanging_events: List[PagerDutyInternalEvent] = [] + + # ------------------ MAIN LOGIC ------------------ # + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + """ + Record a failure event. Only send an alert to PagerDuty if the + configured *failure* threshold is exceeded in the specified window. + """ + now = datetime.now(timezone.utc) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object" + ) + if not standard_logging_payload: + raise ValueError( + "standard_logging_object is required for PagerDutyAlerting" + ) + + # Extract error details + error_info: Optional[StandardLoggingPayloadErrorInformation] = ( + standard_logging_payload.get("error_information") or {} + ) + _meta = standard_logging_payload.get("metadata") or {} + + self._failure_events.append( + PagerDutyInternalEvent( + failure_event_type="failed_response", + timestamp=now, + error_class=error_info.get("error_class"), + error_code=error_info.get("error_code"), + error_llm_provider=error_info.get("llm_provider"), + user_api_key_hash=_meta.get("user_api_key_hash"), + user_api_key_alias=_meta.get("user_api_key_alias"), + user_api_key_spend=_meta.get("user_api_key_spend"), + user_api_key_max_budget=_meta.get("user_api_key_max_budget"), + user_api_key_budget_reset_at=_meta.get("user_api_key_budget_reset_at"), + user_api_key_org_id=_meta.get("user_api_key_org_id"), + user_api_key_team_id=_meta.get("user_api_key_team_id"), + user_api_key_project_id=_meta.get("user_api_key_project_id"), + user_api_key_user_id=_meta.get("user_api_key_user_id"), + user_api_key_team_alias=_meta.get("user_api_key_team_alias"), + user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"), + user_api_key_user_email=_meta.get("user_api_key_user_email"), + user_api_key_request_route=_meta.get("user_api_key_request_route"), + user_api_key_auth_metadata=_meta.get("user_api_key_auth_metadata"), + ) + ) + + # Prune + Possibly alert + window_seconds = self.pagerduty_alerting_args.get( + "failure_threshold_window_seconds", 60 + ) + threshold = self.pagerduty_alerting_args.get("failure_threshold", 1) + + # If threshold is crossed, send PD alert for failures + await self._send_alert_if_thresholds_crossed( + events=self._failure_events, + window_seconds=window_seconds, + threshold=threshold, + alert_prefix="High LLM API Failure Rate", + ) + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: CallTypesLiteral, + ) -> Optional[Union[Exception, str, dict]]: + """ + Example of detecting hanging requests by waiting a given threshold. + If the request didn't finish by then, we treat it as 'hanging'. + """ + verbose_logger.info("Inside Proxy Logging Pre-call hook!") + asyncio.create_task( + self.hanging_response_handler( + request_data=data, user_api_key_dict=user_api_key_dict + ) + ) + return None + + async def hanging_response_handler( + self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth + ): + """ + Checks if request completed by the time 'hanging_threshold_seconds' elapses. + If not, we classify it as a hanging request. + """ + verbose_logger.debug( + f"Inside Hanging Response Handler!..sleeping for {self.pagerduty_alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds" + ) + await asyncio.sleep( + self.pagerduty_alerting_args.get( + "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ) + ) + + if await self._request_is_completed(request_data=request_data): + return # It's not hanging if completed + + # Otherwise, record it as hanging + self._hanging_events.append( + PagerDutyInternalEvent( + failure_event_type="hanging_response", + timestamp=datetime.now(timezone.utc), + error_class="HangingRequest", + error_code="HangingRequest", + error_llm_provider="HangingRequest", + user_api_key_hash=user_api_key_dict.api_key, + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_spend=user_api_key_dict.spend, + user_api_key_max_budget=user_api_key_dict.max_budget, + user_api_key_budget_reset_at=( + user_api_key_dict.budget_reset_at.isoformat() + if user_api_key_dict.budget_reset_at + else None + ), + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_project_id=user_api_key_dict.project_id, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + user_api_key_user_email=user_api_key_dict.user_email, + user_api_key_request_route=user_api_key_dict.request_route, + user_api_key_auth_metadata=user_api_key_dict.metadata, + ) + ) + + # Prune + Possibly alert + window_seconds = self.pagerduty_alerting_args.get( + "hanging_threshold_window_seconds", + PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS, + ) + threshold: int = self.pagerduty_alerting_args.get( + "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS + ) + + # If threshold is crossed, send PD alert for hangs + await self._send_alert_if_thresholds_crossed( + events=self._hanging_events, + window_seconds=window_seconds, + threshold=threshold, + alert_prefix="High Number of Hanging LLM Requests", + ) + + # ------------------ HELPERS ------------------ # + + async def _send_alert_if_thresholds_crossed( + self, + events: List[PagerDutyInternalEvent], + window_seconds: int, + threshold: int, + alert_prefix: str, + ): + """ + 1. Prune old events + 2. If threshold is reached, build alert, send to PagerDuty + 3. Clear those events + """ + cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds) + pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff] + + # Update the reference list + events.clear() + events.extend(pruned) + + # Check threshold + verbose_logger.debug( + f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}" + ) + if len(events) >= threshold: + # Build short summary of last N events + error_summaries = self._build_error_summaries(events, max_errors=5) + alert_message = ( + f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds." + ) + custom_details = {"recent_errors": error_summaries} + + await self.send_alert_to_pagerduty( + alert_message=alert_message, + custom_details=custom_details, + ) + + # Clear them after sending an alert, so we don't spam + events.clear() + + def _build_error_summaries( + self, events: List[PagerDutyInternalEvent], max_errors: int = 5 + ) -> List[PagerDutyInternalEvent]: + """ + Build short text summaries for the last `max_errors`. + Example: "ValueError (code: 500, provider: openai)" + """ + recent = events[-max_errors:] + summaries = [] + for fe in recent: + # If any of these is None, show "N/A" to avoid messing up the summary string + fe.pop("timestamp") + summaries.append(fe) + return summaries + + async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict): + """ + Send [critical] Alert to PagerDuty + + https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api + """ + try: + verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}") + async_client: AsyncHTTPHandler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + payload: PagerDutyRequestBody = PagerDutyRequestBody( + payload=PagerDutyPayload( + summary=alert_message, + severity="critical", + source="LiteLLM Alert", + component="LiteLLM", + custom_details=custom_details, + ), + routing_key=self.api_key, + event_action="trigger", + ) + + return await async_client.post( + url="https://events.pagerduty.com/v2/enqueue", + json=dict(payload), + headers={"Content-Type": "application/json"}, + ) + except Exception as e: + verbose_logger.exception(f"Error sending alert to PagerDuty: {e}") diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 6a14e42c48..5f555d83cf 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1,5444 +1,5446 @@ -# What is this? -## Common Utility file for Logging handler -# Logging function -> log the exact model details + what's being sent | Non-Blocking -import copy -import datetime -import json -import os -import re -import subprocess -import sys -import time -import traceback -from datetime import datetime as dt_object -from functools import lru_cache -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Literal, - Optional, - Tuple, - Type, - Union, - cast, -) - -from httpx import Response -from pydantic import BaseModel - -import litellm -from litellm import ( - _custom_logger_compatible_callbacks_literal, - json_logs, - log_raw_request_response, - turn_off_message_logging, -) -from litellm._logging import _is_debugging_on, verbose_logger -from litellm._uuid import uuid -from litellm.batches.batch_utils import _handle_completed_batch -from litellm.caching.caching import DualCache, InMemoryCache -from litellm.caching.caching_handler import LLMCachingHandler -from litellm.constants import ( - DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT, - DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT, - SENTRY_DENYLIST, - SENTRY_PII_DENYLIST, -) -from litellm.cost_calculator import ( - RealtimeAPITokenUsageProcessor, - _select_model_name_for_cost_calc, -) -from litellm.integrations.agentops import AgentOps -from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook -from litellm.integrations.arize.arize import ArizeLogger -from litellm.integrations.custom_guardrail import CustomGuardrail -from litellm.integrations.custom_logger import CustomLogger -from litellm.integrations.deepeval.deepeval import DeepEvalLogger -from litellm.integrations.mlflow import MlflowLogger -from litellm.integrations.sqs import SQSLogger -from litellm.litellm_core_utils.core_helpers import reconstruct_model_name -from litellm.litellm_core_utils.get_litellm_params import get_litellm_params -from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( - StandardBuiltInToolCostTracking, -) -from litellm.litellm_core_utils.model_param_helper import ModelParamHelper -from litellm.litellm_core_utils.redact_messages import ( - redact_message_input_output_from_custom_logger, - redact_message_input_output_from_logging, -) -from litellm.llms.base_llm.ocr.transformation import OCRResponse -from litellm.llms.base_llm.search.transformation import SearchResponse -from litellm.responses.utils import ResponseAPILoggingUtils -from litellm.types.agents import LiteLLMSendMessageResponse -from litellm.types.containers.main import ContainerObject -from litellm.types.llms.openai import ( - AllMessageValues, - Batch, - FineTuningJob, - HttpxBinaryResponseContent, - OpenAIFileObject, - OpenAIModerationResponse, - ResponseAPIUsage, - ResponseCompletedEvent, - ResponsesAPIResponse, -) -from litellm.types.mcp import MCPPostCallResponseObject -from litellm.types.prompts.init_prompts import PromptSpec -from litellm.types.rerank import RerankResponse -from litellm.types.utils import ( - CachingDetails, - CallTypes, - CostBreakdown, - CostResponseTypes, - CustomPricingLiteLLMParams, - DynamicPromptManagementParamLiteral, - EmbeddingResponse, - GuardrailStatus, - ImageResponse, - LiteLLMBatch, - LiteLLMLoggingBaseClass, - LiteLLMRealtimeStreamLoggingObject, - ModelResponse, - ModelResponseStream, - RawRequestTypedDict, - StandardBuiltInToolsParams, - StandardCallbackDynamicParams, - StandardLoggingAdditionalHeaders, - StandardLoggingHiddenParams, - StandardLoggingMCPToolCall, - StandardLoggingMetadata, - StandardLoggingModelCostFailureDebugInformation, - StandardLoggingModelInformation, - StandardLoggingPayload, - StandardLoggingPayloadErrorInformation, - StandardLoggingPayloadStatus, - StandardLoggingPayloadStatusFields, - StandardLoggingPromptManagementMetadata, - StandardLoggingVectorStoreRequest, - TextCompletionResponse, - TranscriptionResponse, - Usage, -) -from litellm.types.videos.main import VideoObject -from litellm.utils import _get_base_model_from_metadata, executor, print_verbose - -from ..integrations.argilla import ArgillaLogger -from ..integrations.arize.arize_phoenix import ArizePhoenixLogger -from ..integrations.athina import AthinaLogger -from ..integrations.azure_sentinel.azure_sentinel import AzureSentinelLogger -from ..integrations.azure_storage.azure_storage import AzureBlobStorageLogger -from ..integrations.custom_prompt_management import CustomPromptManagement -from ..integrations.datadog.datadog import DataDogLogger -from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger -from ..integrations.dotprompt import DotpromptManager -from ..integrations.dynamodb import DyanmoDBLogger -from ..integrations.galileo import GalileoObserve -from ..integrations.gcs_bucket.gcs_bucket import GCSBucketLogger -from ..integrations.gcs_pubsub.pub_sub import GcsPubSubLogger -from ..integrations.greenscale import GreenscaleLogger -from ..integrations.helicone import HeliconeLogger -from ..integrations.humanloop import HumanloopLogger -from ..integrations.lago import LagoLogger -from ..integrations.langfuse.langfuse import LangFuseLogger -from ..integrations.langfuse.langfuse_handler import LangFuseHandler -from ..integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement -from ..integrations.langsmith import LangsmithLogger -from ..integrations.literal_ai import LiteralAILogger -from ..integrations.logfire_logger import LogfireLevel, LogfireLogger -from ..integrations.lunary import LunaryLogger -from ..integrations.openmeter import OpenMeterLogger -from ..integrations.opik.opik import OpikLogger -from ..integrations.posthog import PostHogLogger -from ..integrations.prompt_layer import PromptLayerLogger -from ..integrations.s3 import S3Logger -from ..integrations.s3_v2 import S3Logger as S3V2Logger -from ..integrations.supabase import Supabase -from ..integrations.traceloop import TraceloopLogger -from .exception_mapping_utils import _get_response_headers -from .initialize_dynamic_callback_params import ( - initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params, -) -from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache - -if TYPE_CHECKING: - from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig -try: - from litellm_enterprise.enterprise_callbacks.callback_controls import ( - EnterpriseCallbackControls, - ) - from litellm_enterprise.enterprise_callbacks.pagerduty.pagerduty import ( - PagerDutyAlerting, - ) - from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ( - ResendEmailLogger, - ) - from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import ( - SendGridEmailLogger, - ) - from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import ( - SMTPEmailLogger, - ) - from litellm_enterprise.litellm_core_utils.litellm_logging import ( - StandardLoggingPayloadSetup as EnterpriseStandardLoggingPayloadSetup, - ) - - from litellm.integrations.generic_api.generic_api_callback import GenericAPILogger - - EnterpriseStandardLoggingPayloadSetupVAR: Optional[ - Type[EnterpriseStandardLoggingPayloadSetup] - ] = EnterpriseStandardLoggingPayloadSetup -except Exception as e: - verbose_logger.debug( - f"[Non-Blocking] Unable to import GenericAPILogger - LiteLLM Enterprise Feature - {str(e)}" - ) - GenericAPILogger = CustomLogger # type: ignore - ResendEmailLogger = CustomLogger # type: ignore - SendGridEmailLogger = CustomLogger # type: ignore - SMTPEmailLogger = CustomLogger # type: ignore - PagerDutyAlerting = CustomLogger # type: ignore - EnterpriseCallbackControls = None # type: ignore - EnterpriseStandardLoggingPayloadSetupVAR = None -_in_memory_loggers: List[Any] = [] - -_STANDARD_LOGGING_METADATA_KEYS: frozenset = frozenset( - StandardLoggingMetadata.__annotations__.keys() -) - -### GLOBAL VARIABLES ### - -# Cache custom pricing keys as frozenset for O(1) lookups instead of looping through 49 keys -_CUSTOM_PRICING_KEYS: frozenset = frozenset( - CustomPricingLiteLLMParams.model_fields.keys() -) - -sentry_sdk_instance = None -capture_exception = None -add_breadcrumb = None -slack_app = None -alerts_channel = None -heliconeLogger = None -athinaLogger = None -promptLayerLogger = None -logfireLogger = None -weightsBiasesLogger = None -customLogger = None -langFuseLogger = None -openMeterLogger = None -lagoLogger = None -dataDogLogger = None -prometheusLogger = None -dynamoLogger = None -s3Logger = None -greenscaleLogger = None -lunaryLogger = None -supabaseClient = None -deepevalLogger = None -callback_list: Optional[List[str]] = [] -user_logger_fn = None -additional_details: Optional[Dict[str, str]] = {} -local_cache: Optional[Dict[str, str]] = {} -last_fetched_at = None -last_fetched_at_keys = None - - -#### -class ServiceTraceIDCache: - def __init__(self) -> None: - self.cache = InMemoryCache() - - def get_cache(self, litellm_call_id: str, service_name: str) -> Optional[str]: - key_name = "{}:{}".format(service_name, litellm_call_id) - response = self.cache.get_cache(key=key_name) - return response - - def set_cache(self, litellm_call_id: str, service_name: str, trace_id: str) -> None: - key_name = "{}:{}".format(service_name, litellm_call_id) - self.cache.set_cache(key=key_name, value=trace_id) - return None - - -in_memory_trace_id_cache = ServiceTraceIDCache() -in_memory_dynamic_logger_cache = DynamicLoggingCache() - -# Cached lazy import for PrometheusLogger -# Module-level cache to avoid repeated imports while preserving memory benefits -_PrometheusLogger = None - - -def _get_cached_prometheus_logger(): - """ - Get cached PrometheusLogger class. - Lazy imports on first call to avoid loading prometheus.py and utils.py at import time (60MB saved). - Subsequent calls use cached class for better performance. - """ - global _PrometheusLogger - if _PrometheusLogger is None: - from litellm.integrations.prometheus import PrometheusLogger - - _PrometheusLogger = PrometheusLogger - return _PrometheusLogger - - -class Logging(LiteLLMLoggingBaseClass): - global supabaseClient, promptLayerLogger, weightsBiasesLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger, logfireLogger, prometheusLogger, slack_app - custom_pricing: bool = False - stream_options = None - litellm_request_debug: bool = False - - def __init__( - self, - model: str, - messages, - stream, - call_type, - start_time, - litellm_call_id: str, - function_id: str, - litellm_trace_id: Optional[str] = None, - dynamic_input_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = None, - dynamic_success_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = None, - dynamic_async_success_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = None, - dynamic_failure_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = None, - dynamic_async_failure_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = None, - applied_guardrails: Optional[List[str]] = None, - kwargs: Optional[Dict] = None, - log_raw_request_response: bool = False, - ): - _input: Optional[str] = messages # save original value of messages - if messages is not None: - if isinstance(messages, str): - messages = [ - {"role": "user", "content": messages} - ] # convert text completion input to the chat completion format - elif ( - isinstance(messages, list) - and len(messages) > 0 - and isinstance(messages[0], str) - ): - new_messages = [] - for m in messages: - new_messages.append({"role": "user", "content": m}) - messages = new_messages - - self.model = model - self.messages = copy.deepcopy(messages) if messages is not None else None - self.stream = stream - self.start_time = start_time # log the call start time - self.call_type = call_type - self.litellm_call_id = litellm_call_id - self.litellm_trace_id: str = ( - litellm_trace_id if litellm_trace_id else str(uuid.uuid4()) - ) - self.function_id = function_id - self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[ - Any - ] = [] # for generating complete stream response - self.log_raw_request_response = log_raw_request_response - - # Initialize dynamic callbacks - self.dynamic_input_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = dynamic_input_callbacks - self.dynamic_success_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = dynamic_success_callbacks - self.dynamic_async_success_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = dynamic_async_success_callbacks - self.dynamic_failure_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = dynamic_failure_callbacks - self.dynamic_async_failure_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = dynamic_async_failure_callbacks - - # Process dynamic callbacks - self.process_dynamic_callbacks() - - ## DYNAMIC LANGFUSE / GCS / logging callback KEYS ## - self.standard_callback_dynamic_params: StandardCallbackDynamicParams = ( - self.initialize_standard_callback_dynamic_params(kwargs) - ) - self.standard_built_in_tools_params: StandardBuiltInToolsParams = ( - self.initialize_standard_built_in_tools_params(kwargs) - ) - ## TIME TO FIRST TOKEN LOGGING ## - self.completion_start_time: Optional[datetime.datetime] = None - self._llm_caching_handler: Optional[LLMCachingHandler] = None - - # INITIAL LITELLM_PARAMS - litellm_params = {} - if kwargs is not None: - litellm_params = get_litellm_params(**kwargs) - litellm_params = scrub_sensitive_keys_in_metadata(litellm_params) - - self.litellm_params = litellm_params - - # Initialize cost breakdown field - self.cost_breakdown: Optional[CostBreakdown] = None - - # Init Caching related details - self.caching_details: Optional[CachingDetails] = None - - # Passthrough endpoint guardrails config for field targeting - self.passthrough_guardrails_config: Optional[Dict[str, Any]] = None - - self.model_call_details: Dict[str, Any] = { - "litellm_trace_id": litellm_trace_id, - "litellm_call_id": litellm_call_id, - "input": _input, - "litellm_params": litellm_params, - "applied_guardrails": applied_guardrails, - "model": model, - } - - def process_dynamic_callbacks(self): - """ - Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks - - If a callback is in litellm._known_custom_logger_compatible_callbacks, it needs to be intialized and added to the respective dynamic_* callback list. - """ - # Process input callbacks - self.dynamic_input_callbacks = self._process_dynamic_callback_list( - self.dynamic_input_callbacks, dynamic_callbacks_type="input" - ) - - # Process failure callbacks - self.dynamic_failure_callbacks = self._process_dynamic_callback_list( - self.dynamic_failure_callbacks, dynamic_callbacks_type="failure" - ) - - # Process async failure callbacks - self.dynamic_async_failure_callbacks = self._process_dynamic_callback_list( - self.dynamic_async_failure_callbacks, dynamic_callbacks_type="async_failure" - ) - - # Process success callbacks - self.dynamic_success_callbacks = self._process_dynamic_callback_list( - self.dynamic_success_callbacks, dynamic_callbacks_type="success" - ) - - # Process async success callbacks - self.dynamic_async_success_callbacks = self._process_dynamic_callback_list( - self.dynamic_async_success_callbacks, dynamic_callbacks_type="async_success" - ) - - def _process_dynamic_callback_list( - self, - callback_list: Optional[List[Union[str, Callable, CustomLogger]]], - dynamic_callbacks_type: Literal[ - "input", "success", "failure", "async_success", "async_failure" - ], - ) -> Optional[List[Union[str, Callable, CustomLogger]]]: - """ - Helper function to initialize CustomLogger compatible callbacks in self.dynamic_* callbacks - - - If a callback is in litellm._known_custom_logger_compatible_callbacks, - replace the string with the initialized callback class. - - If dynamic callback is a "success" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_async_success_callbacks - - If dynamic callback is a "failure" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_failure_callbacks - """ - if callback_list is None: - return None - - processed_list: List[Union[str, Callable, CustomLogger]] = [] - for callback in callback_list: - if ( - isinstance(callback, str) - and callback in litellm._known_custom_logger_compatible_callbacks - ): - callback_class = _init_custom_logger_compatible_class( - callback, internal_usage_cache=None, llm_router=None # type: ignore - ) - if callback_class is not None: - processed_list.append(callback_class) - - # If processing dynamic_success_callbacks, add to dynamic_async_success_callbacks - if dynamic_callbacks_type == "success": - if self.dynamic_async_success_callbacks is None: - self.dynamic_async_success_callbacks = [] - self.dynamic_async_success_callbacks.append(callback_class) - elif dynamic_callbacks_type == "failure": - if self.dynamic_async_failure_callbacks is None: - self.dynamic_async_failure_callbacks = [] - self.dynamic_async_failure_callbacks.append(callback_class) - else: - processed_list.append(callback) - return processed_list - - def initialize_standard_callback_dynamic_params( - self, kwargs: Optional[Dict] = None - ) -> StandardCallbackDynamicParams: - """ - Initialize the standard callback dynamic params from the kwargs - - checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams - """ - - return _initialize_standard_callback_dynamic_params(kwargs) - - def initialize_standard_built_in_tools_params( - self, kwargs: Optional[Dict] = None - ) -> StandardBuiltInToolsParams: - """ - Initialize the standard built-in tools params from the kwargs - - checks if web_search_options in kwargs or tools and sets the corresponding attribute in StandardBuiltInToolsParams - """ - return StandardBuiltInToolsParams( - web_search_options=StandardBuiltInToolCostTracking._get_web_search_options( - kwargs or {} - ), - file_search=StandardBuiltInToolCostTracking._get_file_search_tool_call( - kwargs or {} - ), - ) - - def update_environment_variables( - self, - litellm_params: Dict, - optional_params: Dict, - model: Optional[str] = None, - user: Optional[str] = None, - **additional_params, - ): - self.optional_params = optional_params - if model is not None: - self.model = model - self.user = user - self.litellm_params = { - **self.litellm_params, - **scrub_sensitive_keys_in_metadata(litellm_params), - } - self.litellm_request_debug = litellm_params.get("litellm_request_debug", False) - self.logger_fn = litellm_params.get("logger_fn", None) - if _is_debugging_on() or self.litellm_request_debug: - verbose_logger.debug(f"self.optional_params: {self.optional_params}") - - self.model_call_details.update( - { - "model": self.model, - "messages": self.messages, - "optional_params": self.optional_params, - "litellm_params": self.litellm_params, - "start_time": self.start_time, - "stream": self.stream, - "user": user, - "call_type": str(self.call_type), - "litellm_call_id": self.litellm_call_id, - "completion_start_time": self.completion_start_time, - "standard_callback_dynamic_params": self.standard_callback_dynamic_params, - **self.optional_params, - **additional_params, - } - ) - - ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation - if "stream_options" in additional_params: - self.stream_options = additional_params["stream_options"] - ## check if custom pricing set ## - if any( - litellm_params.get(key) is not None - for key in _CUSTOM_PRICING_KEYS & litellm_params.keys() - ): - self.custom_pricing = True - - if "custom_llm_provider" in self.model_call_details: - self.custom_llm_provider = self.model_call_details["custom_llm_provider"] - - def update_messages(self, messages: List[AllMessageValues]): - """ - Update the logged value of the messages in the model_call_details - - Allows pre-call hooks to update the messages before the call is made - """ - self.messages = messages - self.model_call_details["messages"] = messages - - def should_run_prompt_management_hooks( - self, - non_default_params: Dict, - prompt_id: Optional[str] = None, - tools: Optional[List[Dict]] = None, - ) -> bool: - """ - Return True if prompt management hooks should be run - """ - if prompt_id: - return True - - if self._should_run_prompt_management_hooks_without_prompt_id( - non_default_params=non_default_params, - tools=tools, - ): - return True - - return False - - def _should_run_prompt_management_hooks_without_prompt_id( - self, - non_default_params: Dict, - tools: Optional[List[Dict]] = None, - ) -> bool: - """ - Certain prompt management hooks don't need a `prompt_id` to be passed in, they are triggered by dynamic params - - eg. AnthropicCacheControlHook and BedrockKnowledgeBaseHook both don't require a `prompt_id` to be passed in, they are triggered by dynamic params - """ - for param in non_default_params: - if param in DynamicPromptManagementParamLiteral.list_all_params(): - return True - - ############################################################################# - # Check if Vector Store / Knowledge Base hooks should be applied to the prompt - ############################################################################# - if litellm.vector_store_registry is not None: - if litellm.vector_store_registry.get_vector_store_to_run( - non_default_params=non_default_params, tools=tools - ): - return True - return False - - def get_chat_completion_prompt( - self, - model: str, - messages: List[AllMessageValues], - non_default_params: Dict, - prompt_variables: Optional[dict], - prompt_id: Optional[str] = None, - prompt_spec: Optional[PromptSpec] = None, - prompt_management_logger: Optional[CustomLogger] = None, - prompt_label: Optional[str] = None, - prompt_version: Optional[int] = None, - ) -> Tuple[str, List[AllMessageValues], dict]: - custom_logger = ( - prompt_management_logger - or self.get_custom_logger_for_prompt_management( - model=model, - non_default_params=non_default_params, - prompt_id=prompt_id, - prompt_spec=prompt_spec, - dynamic_callback_params=self.standard_callback_dynamic_params, - ) - ) - - if custom_logger: - ( - model, - messages, - non_default_params, - ) = custom_logger.get_chat_completion_prompt( - model=model, - messages=messages, - non_default_params=non_default_params or {}, - prompt_id=prompt_id, - prompt_spec=prompt_spec, - prompt_variables=prompt_variables, - dynamic_callback_params=self.standard_callback_dynamic_params, - prompt_label=prompt_label, - prompt_version=prompt_version, - ) - self.messages = messages - return model, messages, non_default_params - - async def async_get_chat_completion_prompt( - self, - model: str, - messages: List[AllMessageValues], - non_default_params: Dict, - prompt_variables: Optional[dict], - prompt_id: Optional[str] = None, - prompt_spec: Optional[PromptSpec] = None, - prompt_management_logger: Optional[CustomLogger] = None, - tools: Optional[List[Dict]] = None, - prompt_label: Optional[str] = None, - prompt_version: Optional[int] = None, - ) -> Tuple[str, List[AllMessageValues], dict]: - custom_logger = ( - prompt_management_logger - or self.get_custom_logger_for_prompt_management( - model=model, - tools=tools, - non_default_params=non_default_params, - prompt_id=prompt_id, - prompt_spec=prompt_spec, - dynamic_callback_params=self.standard_callback_dynamic_params, - ) - ) - - if custom_logger: - ( - model, - messages, - non_default_params, - ) = await custom_logger.async_get_chat_completion_prompt( - model=model, - messages=messages, - non_default_params=non_default_params or {}, - prompt_id=prompt_id, - prompt_spec=prompt_spec, - prompt_variables=prompt_variables, - dynamic_callback_params=self.standard_callback_dynamic_params, - litellm_logging_obj=self, - tools=tools, - prompt_label=prompt_label, - prompt_version=prompt_version, - ) - self.messages = messages - return model, messages, non_default_params - - def _auto_detect_prompt_management_logger( - self, - prompt_id: str, - prompt_spec: Optional[PromptSpec], - dynamic_callback_params: StandardCallbackDynamicParams, - ) -> Optional[CustomLogger]: - """ - Auto-detect which prompt management system owns the given prompt_id. - - This allows a user to just pass prompt_id in the completion call and it will be auto-detected which system owns this prompt. - - Args: - prompt_id: The prompt ID to check - dynamic_callback_params: Dynamic callback parameters for should_run_prompt_management checks - - Returns: - A CustomLogger instance if a matching prompt management system is found, None otherwise - """ - prompt_management_loggers = ( - litellm.logging_callback_manager.get_custom_loggers_for_type( - callback_type=CustomPromptManagement - ) - ) - - for logger in prompt_management_loggers: - if isinstance(logger, CustomPromptManagement): - try: - if logger.should_run_prompt_management( - prompt_id=prompt_id, - prompt_spec=prompt_spec, - dynamic_callback_params=dynamic_callback_params, - ): - self.model_call_details[ - "prompt_integration" - ] = logger.__class__.__name__ - return logger - except Exception: - # If check fails, continue to next logger - continue - - return None - - def get_custom_logger_for_prompt_management( - self, - model: str, - non_default_params: Dict, - tools: Optional[List[Dict]] = None, - prompt_id: Optional[str] = None, - prompt_spec: Optional[PromptSpec] = None, - dynamic_callback_params: Optional[StandardCallbackDynamicParams] = None, - ) -> Optional[CustomLogger]: - """ - Get a custom logger for prompt management based on model name or available callbacks. - - Args: - model: The model name to check for prompt management integration - non_default_params: Non-default parameters passed to the completion call - tools: Optional tools passed to the completion call - prompt_id: Optional prompt ID to auto-detect which system owns this prompt - dynamic_callback_params: Dynamic callback parameters for should_run_prompt_management checks - - Returns: - A CustomLogger instance if one is found, None otherwise - """ - # First check if model starts with a known custom logger compatible callback - # This takes precedence for backward compatibility - for callback_name in litellm._known_custom_logger_compatible_callbacks: - if model.startswith(callback_name): - custom_logger = _init_custom_logger_compatible_class( - logging_integration=callback_name, - internal_usage_cache=None, - llm_router=None, - ) - if custom_logger is not None: - self.model_call_details["prompt_integration"] = model.split("/")[0] - return custom_logger - - # If prompt_id is provided, try to auto-detect which system has this prompt - if prompt_id and dynamic_callback_params is not None: - auto_detected_logger = self._auto_detect_prompt_management_logger( - prompt_id=prompt_id, - prompt_spec=prompt_spec, - dynamic_callback_params=dynamic_callback_params, - ) - if auto_detected_logger is not None: - return auto_detected_logger - - # Then check for any registered CustomPromptManagement loggers (fallback) - prompt_management_loggers = ( - litellm.logging_callback_manager.get_custom_loggers_for_type( - callback_type=CustomPromptManagement - ) - ) - - if prompt_management_loggers: - logger = prompt_management_loggers[0] - self.model_call_details["prompt_integration"] = logger.__class__.__name__ - return logger - - if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook( - non_default_params - ): - self.model_call_details[ - "prompt_integration" - ] = anthropic_cache_control_logger.__class__.__name__ - return anthropic_cache_control_logger - - ######################################################### - # Vector Store / Knowledge Base hooks - ######################################################### - if litellm.vector_store_registry is not None: - vector_store_custom_logger = _init_custom_logger_compatible_class( - logging_integration="vector_store_pre_call_hook", - internal_usage_cache=None, - llm_router=None, - ) - self.model_call_details[ - "prompt_integration" - ] = vector_store_custom_logger.__class__.__name__ - # Add to global callbacks so post-call hooks are invoked - if ( - vector_store_custom_logger - and vector_store_custom_logger not in litellm.callbacks - ): - litellm.logging_callback_manager.add_litellm_callback( - vector_store_custom_logger - ) - return vector_store_custom_logger - - return None - - def get_custom_logger_for_anthropic_cache_control_hook( - self, non_default_params: Dict - ) -> Optional[CustomLogger]: - if non_default_params.get("cache_control_injection_points", None): - custom_logger = _init_custom_logger_compatible_class( - logging_integration="anthropic_cache_control_hook", - internal_usage_cache=None, - llm_router=None, - ) - return custom_logger - return None - - def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict: - if data is None: - return {"error": "Received empty dictionary for raw request body"} - if isinstance(data, str): - try: - return json.loads(data) - except Exception: - return { - "error": "Unable to parse raw request body. Got - {}".format(data) - } - return data - - def _get_masked_api_base(self, api_base: str) -> str: - if "key=" in api_base: - # Find the position of "key=" in the string - key_index = api_base.find("key=") + 4 - # Mask the last 5 characters after "key=" - masked_api_base = api_base[:key_index] + "*" * 5 + api_base[-4:] - else: - masked_api_base = api_base - return str(masked_api_base) - - def _pre_call(self, input, api_key, model=None, additional_args={}): - """ - Common helper function across the sync + async pre-call function - """ - - self.model_call_details["input"] = input - self.model_call_details["api_key"] = api_key - self.model_call_details["additional_args"] = additional_args - self.model_call_details["log_event_type"] = "pre_api_call" - if ( - model - ): # if model name was changes pre-call, overwrite the initial model call name with the new one - self.model_call_details["model"] = model - self.model_call_details["litellm_params"][ - "api_base" - ] = self._get_masked_api_base(additional_args.get("api_base", "")) - - def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 - # Log the exact input to the LLM API - litellm.error_logs["PRE_CALL"] = locals() - try: - self._pre_call( - input=input, - api_key=api_key, - model=model, - additional_args=additional_args, - ) - - # User Logging -> if you pass in a custom logging function - self._print_llm_call_debugging_log( - api_base=additional_args.get("api_base", ""), - headers=additional_args.get("headers", {}), - additional_args=additional_args, - ) - # log raw request to provider (like LangFuse) -- if opted in. - if ( - self.log_raw_request_response is True - or log_raw_request_response is True - ): - _litellm_params = self.model_call_details.get("litellm_params", {}) - _metadata = _litellm_params.get("metadata", {}) or {} - try: - # [Non-blocking Extra Debug Information in metadata] - if turn_off_message_logging is True: - _metadata[ - "raw_request" - ] = "redacted by litellm. \ - 'litellm.turn_off_message_logging=True'" - else: - curl_command = self._get_request_curl_command( - api_base=additional_args.get("api_base", ""), - headers=additional_args.get("headers", {}), - additional_args=additional_args, - data=additional_args.get("complete_input_dict", {}), - ) - - _metadata["raw_request"] = str(curl_command) - # split up, so it's easier to parse in the UI - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - # NOTE: setting ignore_sensitive_headers to True will cause - # the Authorization header to be leaked when calls to the health - # endpoint are made and fail. - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ), - error=None, - ) - except Exception as e: - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - error=str(e), - ) - _metadata[ - "raw_request" - ] = "Unable to Log \ - raw request: {}".format( - str(e) - ) - if getattr(self, "logger_fn", None) and callable(self.logger_fn): - try: - self.logger_fn( - self.model_call_details - ) # Expectation: any logger function passed in by the user should accept a dict object - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( - str(e) - ) - ) - - self.model_call_details["api_call_start_time"] = datetime.datetime.now() - # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made - callbacks = litellm.input_callback + (self.dynamic_input_callbacks or []) - for callback in callbacks: - try: - if callback == "supabase" and supabaseClient is not None: - verbose_logger.debug("reaches supabase for logging!") - model = self.model_call_details["model"] - messages = self.model_call_details["input"] - verbose_logger.debug(f"supabaseClient: {supabaseClient}") - supabaseClient.input_log_event( - model=model, - messages=messages, - end_user=self.model_call_details.get("user", "default"), - litellm_call_id=self.litellm_params["litellm_call_id"], - print_verbose=print_verbose, - ) - elif callback == "sentry" and add_breadcrumb: - try: - details_to_log = copy.deepcopy(self.model_call_details) - except Exception: - details_to_log = self.model_call_details - if litellm.turn_off_message_logging: - # make a copy of the _model_Call_details and log it - details_to_log.pop("messages", None) - details_to_log.pop("input", None) - details_to_log.pop("prompt", None) - - add_breadcrumb( - category="litellm.llm_call", - message=f"Model Call Details pre-call: {details_to_log}", - level="info", - ) - - elif isinstance(callback, CustomLogger): # custom logger class - callback.log_pre_api_call( - model=self.model, - messages=self.messages, - kwargs=self.model_call_details, - ) - elif ( - callable(callback) and customLogger is not None - ): # custom logger functions - customLogger.log_input_event( - model=self.model, - messages=self.messages, - kwargs=self.model_call_details, - print_verbose=print_verbose, - callback_func=callback, - ) - except Exception as e: - verbose_logger.exception( - "litellm.Logging.pre_call(): Exception occured - {}".format( - str(e) - ) - ) - verbose_logger.debug( - f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" - ) - if capture_exception: # log this error to sentry for debugging - capture_exception(e) - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( - str(e) - ) - ) - verbose_logger.error( - f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" - ) - if capture_exception: # log this error to sentry for debugging - capture_exception(e) - - def _print_llm_call_debugging_log( - self, - api_base: str, - headers: dict, - additional_args: dict, - ): - """ - Internal debugging helper function - - Prints the RAW curl command sent from LiteLLM - """ - if _is_debugging_on() or self.litellm_request_debug: - if json_logs: - masked_headers = self._get_masked_headers(headers) - if self.litellm_request_debug: - verbose_logger.warning( # .warning ensures this shows up in all environments - "POST Request Sent from LiteLLM", - extra={"api_base": {api_base}, **masked_headers}, - ) - else: - verbose_logger.debug( - "POST Request Sent from LiteLLM", - extra={"api_base": {api_base}, **masked_headers}, - ) - else: - headers = additional_args.get("headers", {}) - if headers is None: - headers = {} - data = additional_args.get("complete_input_dict", {}) - api_base = str(additional_args.get("api_base", "")) - curl_command = self._get_request_curl_command( - api_base=api_base, - headers=headers, - additional_args=additional_args, - data=data, - ) - if self.litellm_request_debug: - verbose_logger.warning( - f"\033[92m{curl_command}\033[0m\n" - ) # .warning ensures this shows up in all environments - else: - verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n") - - def _get_request_body(self, data: dict) -> str: - return str(data) - - def _get_request_curl_command( - self, api_base: str, headers: Optional[dict], additional_args: dict, data: dict - ) -> str: - masked_api_base = self._get_masked_api_base(api_base) - if headers is None: - headers = {} - curl_command = "\n\nPOST Request Sent from LiteLLM:\n" - curl_command += "curl -X POST \\\n" - curl_command += f"{masked_api_base} \\\n" - masked_headers = self._get_masked_headers(headers) - formatted_headers = " ".join( - [f"-H '{k}: {v}'" for k, v in masked_headers.items()] - ) - curl_command += ( - f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" - ) - curl_command += f"-d '{self._get_request_body(data)}'\n" - if additional_args.get("request_str", None) is not None: - # print the sagemaker / bedrock client request - curl_command = "\nRequest Sent from LiteLLM:\n" - request_str = additional_args.get("request_str", "") - curl_command += request_str - elif api_base == "": - curl_command = str(self.model_call_details) - return curl_command - - def _get_masked_headers( - self, headers: dict, ignore_sensitive_headers: bool = False - ) -> dict: - """ - Internal debugging helper function - - Masks the headers of the request sent from LiteLLM - """ - return _get_masked_values( - headers, ignore_sensitive_values=ignore_sensitive_headers - ) - - def post_call( - self, original_response, input=None, api_key=None, additional_args={} - ): - # Log the exact result from the LLM API, for streaming - log the type of response received - litellm.error_logs["POST_CALL"] = locals() - if isinstance(original_response, dict): - original_response = json.dumps(original_response) - try: - self.model_call_details["input"] = input - self.model_call_details["api_key"] = api_key - self.model_call_details["original_response"] = original_response - self.model_call_details["additional_args"] = additional_args - self.model_call_details["log_event_type"] = "post_api_call" - - if self.litellm_request_debug: - attr = "warning" - else: - attr = "debug" - - if json_logs: - callattr = getattr(verbose_logger, attr) - callattr( - "RAW RESPONSE:\n{}\n\n".format( - self.model_call_details.get( - "original_response", self.model_call_details - ) - ), - ) - else: - callattr = getattr(verbose_logger, attr) - callattr( - "RAW RESPONSE:\n{}\n\n".format( - self.model_call_details.get( - "original_response", self.model_call_details - ) - ) - ) - if getattr(self, "logger_fn", None) and callable(self.logger_fn): - try: - self.logger_fn( - self.model_call_details - ) # Expectation: any logger function passed in by the user should accept a dict object - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( - str(e) - ) - ) - original_response = redact_message_input_output_from_logging( - model_call_details=( - self.model_call_details - if hasattr(self, "model_call_details") - else {} - ), - result=original_response, - ) - # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made - - callbacks = litellm.input_callback + (self.dynamic_input_callbacks or []) - for callback in callbacks: - try: - if callback == "sentry" and add_breadcrumb: - verbose_logger.debug("reaches sentry breadcrumbing") - try: - details_to_log = copy.deepcopy(self.model_call_details) - except Exception: - details_to_log = self.model_call_details - if litellm.turn_off_message_logging: - # make a copy of the _model_Call_details and log it - details_to_log.pop("messages", None) - details_to_log.pop("input", None) - details_to_log.pop("prompt", None) - - add_breadcrumb( - category="litellm.llm_call", - message=f"Model Call Details post-call: {details_to_log}", - level="info", - ) - elif isinstance(callback, CustomLogger): # custom logger class - callback.log_post_api_call( - kwargs=self.model_call_details, - response_obj=None, - start_time=self.start_time, - end_time=None, - ) - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while post-call logging with integrations {}".format( - str(e) - ) - ) - verbose_logger.debug( - f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" - ) - if capture_exception: # log this error to sentry for debugging - capture_exception(e) - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( - str(e) - ) - ) - - async def async_post_mcp_tool_call_hook( - self, - kwargs: dict, - response_obj: Any, - start_time: datetime.datetime, - end_time: datetime.datetime, - ): - """ - Post MCP Tool Call Hook - - Use this to modify the MCP tool call response before it is returned to the user. - """ - from litellm.types.llms.base import HiddenParams - from litellm.types.mcp import MCPPostCallResponseObject - - callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=self.dynamic_success_callbacks, - global_callbacks=litellm.success_callback, - ) - post_mcp_tool_call_response_obj: MCPPostCallResponseObject = ( - MCPPostCallResponseObject( - mcp_tool_call_response=response_obj, hidden_params=HiddenParams() - ) - ) - for callback in callbacks: - try: - if isinstance(callback, CustomLogger): - response: Optional[ - MCPPostCallResponseObject - ] = await callback.async_post_mcp_tool_call_hook( - kwargs=kwargs, - response_obj=post_mcp_tool_call_response_obj, - start_time=start_time, - end_time=end_time, - ) - ###################################################################### - # if any of the callbacks modify the response, use the modified response - # current implementation returns the first modified response - ###################################################################### - if response is not None: - response_obj = self._parse_post_mcp_call_hook_response( - response=response - ) - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( - str(e) - ) - ) - return response_obj - - def _parse_post_mcp_call_hook_response( - self, response: Optional[MCPPostCallResponseObject] - ) -> Any: - """ - Parse the response from the post_mcp_tool_call_hook - - 1. Unpack the mcp_tool_call_response - 2. save the updated response_cost to the model_call_details - """ - if response is None: - return None - self.model_call_details["response_cost"] = response.hidden_params.response_cost - return response.mcp_tool_call_response - - def get_response_ms(self) -> float: - return ( - self.model_call_details.get("end_time", datetime.datetime.now()) - - self.model_call_details.get("start_time", datetime.datetime.now()) - ).total_seconds() * 1000 - - def set_cost_breakdown( - self, - input_cost: float, - output_cost: float, - total_cost: float, - cost_for_built_in_tools_cost_usd_dollar: float, - additional_costs: Optional[dict] = None, - original_cost: Optional[float] = None, - discount_percent: Optional[float] = None, - discount_amount: Optional[float] = None, - margin_percent: Optional[float] = None, - margin_fixed_amount: Optional[float] = None, - margin_total_amount: Optional[float] = None, - ) -> None: - """ - Helper method to store cost breakdown in the logging object. - - Args: - input_cost: Cost of input/prompt tokens - output_cost: Cost of output/completion tokens - cost_for_built_in_tools_cost_usd_dollar: Cost of built-in tools - total_cost: Total cost of request - additional_costs: Free-form additional costs dict (e.g., {"azure_model_router_flat_cost": 0.00014}) - original_cost: Cost before discount - discount_percent: Discount percentage (0.05 = 5%) - discount_amount: Discount amount in USD - margin_percent: Margin percentage applied (0.10 = 10%) - margin_fixed_amount: Fixed margin amount in USD - margin_total_amount: Total margin added in USD - """ - - self.cost_breakdown = CostBreakdown( - input_cost=input_cost, - output_cost=output_cost, - total_cost=total_cost, - tool_usage_cost=cost_for_built_in_tools_cost_usd_dollar, - ) - - # Store additional costs if provided (free-form dict for extensibility) - if ( - additional_costs - and isinstance(additional_costs, dict) - and len(additional_costs) > 0 - ): - self.cost_breakdown["additional_costs"] = additional_costs - - # Store discount information if provided - if original_cost is not None: - self.cost_breakdown["original_cost"] = original_cost - if discount_percent is not None: - self.cost_breakdown["discount_percent"] = discount_percent - if discount_amount is not None: - self.cost_breakdown["discount_amount"] = discount_amount - - # Store margin information if provided - if margin_percent is not None: - self.cost_breakdown["margin_percent"] = margin_percent - if margin_fixed_amount is not None: - self.cost_breakdown["margin_fixed_amount"] = margin_fixed_amount - if margin_total_amount is not None: - self.cost_breakdown["margin_total_amount"] = margin_total_amount - - def _response_cost_calculator( - self, - result: Union[ - ModelResponse, - ModelResponseStream, - EmbeddingResponse, - ImageResponse, - TranscriptionResponse, - TextCompletionResponse, - HttpxBinaryResponseContent, - RerankResponse, - Batch, - FineTuningJob, - ResponsesAPIResponse, - ResponseCompletedEvent, - OpenAIFileObject, - LiteLLMRealtimeStreamLoggingObject, - OpenAIModerationResponse, - "SearchResponse", - ], - cache_hit: Optional[bool] = None, - litellm_model_name: Optional[str] = None, - router_model_id: Optional[str] = None, - ) -> Optional[float]: - """ - Calculate response cost using result + logging object variables. - - used for consistent cost calculation across response headers + logging integrations. - """ - - if isinstance(result, BaseModel) and hasattr(result, "_hidden_params"): - hidden_params = getattr(result, "_hidden_params", {}) - if ( - "response_cost" in hidden_params - and hidden_params["response_cost"] is not None - ): # use cost if already calculated - return hidden_params["response_cost"] - elif ( - router_model_id is None and "model_id" in hidden_params - ): # use model_id if not already set - router_model_id = hidden_params["model_id"] - - ## RESPONSE COST ## - custom_pricing = use_custom_pricing_for_model( - litellm_params=( - self.litellm_params if hasattr(self, "litellm_params") else None - ) - ) - - prompt = "" # use for tts cost calc - _input = self.model_call_details.get("input", None) - if _input is not None and isinstance(_input, str): - prompt = _input - - if cache_hit is None: - cache_hit = self.model_call_details.get("cache_hit", False) - - try: - response_cost_calculator_kwargs = { - "response_object": result, - "model": litellm_model_name or self.model, - "cache_hit": cache_hit, - "custom_llm_provider": self.model_call_details.get( - "custom_llm_provider", None - ), - "base_model": _get_base_model_from_metadata( - model_call_details=self.model_call_details - ), - "call_type": self.call_type, - "optional_params": self.optional_params, - "custom_pricing": custom_pricing, - "prompt": prompt, - "standard_built_in_tools_params": self.standard_built_in_tools_params, - "router_model_id": router_model_id, - "litellm_logging_obj": self, - "service_tier": ( - self.optional_params.get("service_tier") - if self.optional_params - else None - ), - } - except Exception as e: # error creating kwargs for cost calculation - debug_info = StandardLoggingModelCostFailureDebugInformation( - error_str=str(e), - traceback_str=_get_traceback_str_for_error(str(e)), - ) - verbose_logger.debug( - f"response_cost_failure_debug_information: {debug_info}" - ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info - return None - - try: - response_cost = litellm.response_cost_calculator( - **response_cost_calculator_kwargs - ) - - verbose_logger.debug(f"response_cost: {response_cost}") - return response_cost - except Exception as e: # error calculating cost - debug_info = StandardLoggingModelCostFailureDebugInformation( - error_str=str(e), - traceback_str=_get_traceback_str_for_error(str(e)), - model=response_cost_calculator_kwargs["model"], - cache_hit=response_cost_calculator_kwargs["cache_hit"], - custom_llm_provider=response_cost_calculator_kwargs[ - "custom_llm_provider" - ], - base_model=response_cost_calculator_kwargs["base_model"], - call_type=response_cost_calculator_kwargs["call_type"], - custom_pricing=response_cost_calculator_kwargs["custom_pricing"], - ) - verbose_logger.debug( - f"response_cost_failure_debug_information: {debug_info}" - ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info - - return None - - async def _response_cost_calculator_async( - self, - result: Union[ - ModelResponse, - ModelResponseStream, - EmbeddingResponse, - ImageResponse, - TranscriptionResponse, - TextCompletionResponse, - HttpxBinaryResponseContent, - RerankResponse, - Batch, - FineTuningJob, - ], - cache_hit: Optional[bool] = None, - ) -> Optional[float]: - return self._response_cost_calculator(result=result, cache_hit=cache_hit) - - def should_run_logging( - self, - event_type: Literal[ - "async_success", "sync_success", "async_failure", "sync_failure" - ], - stream: bool = False, - ) -> bool: - try: - if self.model_call_details.get(f"has_logged_{event_type}", False) is True: - return False - - return True - except Exception: - return True - - def has_run_logging( - self, - event_type: Literal[ - "async_success", "sync_success", "async_failure", "sync_failure" - ], - ) -> None: - if self.stream is not None and self.stream is True: - """ - Ignore check on stream, as there can be multiple chunks - """ - return - self.model_call_details[f"has_logged_{event_type}"] = True - return - - def should_run_callback( - self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str - ) -> bool: - if litellm.global_disable_no_log_param: - return True - - if litellm_params.get("no-log", False) is True: - # proxy cost tracking cal backs should run - - if not ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - verbose_logger.debug( - f"no-log request, skipping logging for {event_hook} event" - ) - return False - - # Check for dynamically disabled callbacks via headers - if ( - EnterpriseCallbackControls is not None - and EnterpriseCallbackControls.is_callback_disabled_dynamically( - callback=callback, - litellm_params=litellm_params, - standard_callback_dynamic_params=self.standard_callback_dynamic_params, - ) - ): - verbose_logger.debug( - f"Callback {callback} disabled via x-litellm-disable-callbacks header for {event_hook} event" - ) - return False - - return True - - def _update_completion_start_time(self, completion_start_time: datetime.datetime): - self.completion_start_time = completion_start_time - self.model_call_details["completion_start_time"] = self.completion_start_time - - def normalize_logging_result(self, result: Any) -> Any: - """ - Some endpoints return a different type of result than what is expected by the logging system. - This function is used to normalize the result to the expected type. - """ - logging_result = result - if self.call_type == CallTypes.arealtime.value and isinstance(result, list): - combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results( - results=result - ) - logging_result = ( - RealtimeAPITokenUsageProcessor.create_logging_realtime_object( - usage=combined_usage_object, - results=result, - ) - ) - - elif ( - self.call_type == CallTypes.llm_passthrough_route.value - or self.call_type == CallTypes.allm_passthrough_route.value - ) and isinstance(result, Response): - from litellm.utils import ProviderConfigManager - - provider_config = ProviderConfigManager.get_provider_passthrough_config( - provider=self.model_call_details.get("custom_llm_provider", ""), - model=self.model, - ) - if provider_config is not None: - logging_result = provider_config.logging_non_streaming_response( - model=self.model, - custom_llm_provider=self.model_call_details.get( - "custom_llm_provider", "" - ), - httpx_response=result, - request_data=self.model_call_details.get("request_data", {}), - logging_obj=self, - endpoint=self.model_call_details.get("endpoint", ""), - ) - return logging_result - - def _process_hidden_params_and_response_cost( - self, - logging_result, - start_time, - end_time, - ): - hidden_params = getattr(logging_result, "_hidden_params", {}) - if hidden_params: - if self.model_call_details.get("litellm_params") is not None: - self.model_call_details["litellm_params"].setdefault("metadata", {}) - if self.model_call_details["litellm_params"]["metadata"] is None: - self.model_call_details["litellm_params"]["metadata"] = {} - self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = getattr(logging_result, "_hidden_params", {}) # type: ignore - - if "response_cost" in hidden_params: - self.model_call_details["response_cost"] = hidden_params["response_cost"] - else: - self.model_call_details["response_cost"] = self._response_cost_calculator( - result=logging_result - ) - - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=logging_result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - - def _transform_usage_objects(self, result): - if isinstance(result, ResponsesAPIResponse): - result = result.model_copy() - transformed_usage = ( - ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( - result.usage - ) - ) - setattr(result, "usage", transformed_usage) - if ( - standard_logging_payload := self.model_call_details.get( - "standard_logging_object" - ) - ) is not None: - response_dict = ( - result.model_dump() - if hasattr(result, "model_dump") - else dict(result) - ) - # Ensure usage is properly included with transformed chat format - if transformed_usage is not None: - response_dict["usage"] = ( - transformed_usage.model_dump() - if hasattr(transformed_usage, "model_dump") - else dict(transformed_usage) - ) - standard_logging_payload["response"] = response_dict - elif isinstance(result, TranscriptionResponse): - from litellm.litellm_core_utils.llm_cost_calc.usage_object_transformation import ( - TranscriptionUsageObjectTransformation, - ) - - result = result.model_copy() - transformed_usage = TranscriptionUsageObjectTransformation.transform_transcription_usage_object(result.usage) # type: ignore - setattr(result, "usage", transformed_usage) - return result - - def _success_handler_helper_fn( - self, - result=None, - start_time=None, - end_time=None, - cache_hit=None, - standard_logging_object: Optional[StandardLoggingPayload] = None, - ): - try: - if start_time is None: - start_time = self.start_time - if end_time is None: - end_time = datetime.datetime.now() - if self.completion_start_time is None: - self.completion_start_time = end_time - self.model_call_details[ - "completion_start_time" - ] = self.completion_start_time - - self.model_call_details["log_event_type"] = "successful_api_call" - self.model_call_details["end_time"] = end_time - self.model_call_details["cache_hit"] = cache_hit - - if self.call_type == CallTypes.anthropic_messages.value: - result = self._handle_anthropic_messages_response_logging(result=result) - elif ( - self.call_type == CallTypes.generate_content.value - or self.call_type == CallTypes.agenerate_content.value - ): - result = self._handle_non_streaming_google_genai_generate_content_response_logging( - result=result - ) - elif ( - self.call_type == CallTypes.asend_message.value - or self.call_type == CallTypes.send_message.value - ): - result = self._handle_a2a_response_logging(result=result) - - logging_result = self.normalize_logging_result(result=result) - - if ( - standard_logging_object is None - and result is not None - and self.stream is not True - ): - if self._is_recognized_call_type_for_logging( - logging_result=logging_result - ): - self._process_hidden_params_and_response_cost( - logging_result=logging_result, - start_time=start_time, - end_time=end_time, - ) - elif isinstance(result, dict) or isinstance(result, list): - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - elif standard_logging_object is not None: - self.model_call_details[ - "standard_logging_object" - ] = standard_logging_object - else: - self.model_call_details["response_cost"] = None - - result = self._transform_usage_objects(result=result) - - if ( - litellm.max_budget - and self.stream is False - and result is not None - and isinstance(result, dict) - and "content" in result - ): - time_diff = (end_time - start_time).total_seconds() - float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost( - model=self.model, - prompt="", - completion=getattr(result, "content", ""), - total_time=float_diff, - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - - return start_time, end_time, result - except Exception as e: - raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") - - def _is_recognized_call_type_for_logging( - self, - logging_result: Any, - ): - """ - Returns True if the call type is recognized for logging (eg. ModelResponse, ModelResponseStream, etc.) - """ - if ( - isinstance(logging_result, ModelResponse) - or isinstance(logging_result, ModelResponseStream) - or isinstance(logging_result, EmbeddingResponse) - or isinstance(logging_result, ImageResponse) - or isinstance(logging_result, TranscriptionResponse) - or isinstance(logging_result, TextCompletionResponse) - or isinstance(logging_result, HttpxBinaryResponseContent) # tts - or isinstance(logging_result, RerankResponse) - or isinstance(logging_result, FineTuningJob) - or isinstance(logging_result, LiteLLMBatch) - or isinstance(logging_result, ResponsesAPIResponse) - or isinstance(logging_result, OpenAIFileObject) - or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject) - or isinstance(logging_result, OpenAIModerationResponse) - or isinstance(logging_result, OCRResponse) # OCR - or isinstance(logging_result, SearchResponse) # Search API - or isinstance(logging_result, dict) - and logging_result.get("object") == "vector_store.search_results.page" - or isinstance(logging_result, dict) - and logging_result.get("object") == "search" # Search API (dict format) - or isinstance(logging_result, VideoObject) - or isinstance(logging_result, ContainerObject) - or isinstance(logging_result, LiteLLMSendMessageResponse) # A2A - or (self.call_type == CallTypes.call_mcp_tool.value) - ): - return True - return False - - def _flush_passthrough_collected_chunks_helper( - self, - raw_bytes: List[bytes], - provider_config: "BasePassthroughConfig", - ) -> Optional["CostResponseTypes"]: - all_chunks = provider_config._convert_raw_bytes_to_str_lines(raw_bytes) - complete_streaming_response = provider_config.handle_logging_collected_chunks( - all_chunks=all_chunks, - litellm_logging_obj=self, - model=self.model, - custom_llm_provider=self.model_call_details.get("custom_llm_provider", ""), - endpoint=self.model_call_details.get("endpoint", ""), - ) - return complete_streaming_response - - def flush_passthrough_collected_chunks( - self, - raw_bytes: List[bytes], - provider_config: "BasePassthroughConfig", - ): - """ - Flush collected chunks from the logging object - This is used to log the collected chunks once streaming is done on passthrough endpoints - - 1. Decode the raw bytes to string lines - 2. Get the complete streaming response from the provider config - 3. Log the complete streaming response (trigger success handler) - This is used for passthrough endpoints - """ - complete_streaming_response = self._flush_passthrough_collected_chunks_helper( - raw_bytes=raw_bytes, - provider_config=provider_config, - ) - - if complete_streaming_response is not None: - self.success_handler(result=complete_streaming_response) - return - - async def async_flush_passthrough_collected_chunks( - self, - raw_bytes: List[bytes], - provider_config: "BasePassthroughConfig", - ): - complete_streaming_response = self._flush_passthrough_collected_chunks_helper( - raw_bytes=raw_bytes, - provider_config=provider_config, - ) - - if complete_streaming_response is not None: - await self.async_success_handler(result=complete_streaming_response) - return - - def success_handler( # noqa: PLR0915 - self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs - ): - verbose_logger.debug( - f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}" - ) - if not self.should_run_logging( - event_type="sync_success" - ): # prevent double logging - return - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, - end_time=end_time, - result=result, - cache_hit=cache_hit, - standard_logging_object=kwargs.get("standard_logging_object", None), - ) - litellm_params = self.model_call_details.get("litellm_params", {}) - is_sync_request = ( - litellm_params.get(CallTypes.acompletion.value, False) is not True - and litellm_params.get(CallTypes.aresponses.value, False) is not True - and litellm_params.get(CallTypes.aembedding.value, False) is not True - and litellm_params.get(CallTypes.aimage_generation.value, False) is not True - and litellm_params.get(CallTypes.atranscription.value, False) is not True - ) - try: - ## BUILD COMPLETE STREAMED RESPONSE - complete_streaming_response: Optional[ - Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] - ] = None - if "complete_streaming_response" in self.model_call_details: - return # break out of this. - complete_streaming_response = self._get_assembled_streaming_response( - result=result, - start_time=start_time, - end_time=end_time, - is_async=False, - streaming_chunks=self.sync_streaming_chunks, - ) - if complete_streaming_response is not None: - verbose_logger.debug( - "Logging Details LiteLLM-Success Call streaming complete" - ) - self.model_call_details[ - "complete_streaming_response" - ] = complete_streaming_response - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator(result=complete_streaming_response) - ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - if ( - standard_logging_payload := self.model_call_details.get( - "standard_logging_object" - ) - ) is not None: - # Only emit for sync requests (async_success_handler handles async) - if is_sync_request: - emit_standard_logging_payload(standard_logging_payload) - callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=self.dynamic_success_callbacks, - global_callbacks=litellm.success_callback, - ) - - ## REDACT MESSAGES ## - result = redact_message_input_output_from_logging( - model_call_details=( - self.model_call_details - if hasattr(self, "model_call_details") - else {} - ), - result=result, - ) - ## LOGGING HOOK ## - for callback in callbacks: - if isinstance(callback, CustomLogger): - self.model_call_details, result = callback.logging_hook( - kwargs=self.model_call_details, - result=result, - call_type=self.call_type, - ) - - self.has_run_logging(event_type="sync_success") - for callback in callbacks: - try: - should_run = self.should_run_callback( - callback=callback, - litellm_params=litellm_params, - event_hook="success_handler", - ) - if not should_run: - continue - if callback == "promptlayer" and promptLayerLogger is not None: - print_verbose("reaches promptlayer for logging!") - promptLayerLogger.log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - if callback == "supabase" and supabaseClient is not None: - print_verbose("reaches supabase for logging!") - kwargs = self.model_call_details - - # this only logs streaming once, complete_streaming_response exists i.e when stream ends - if self.stream: - if "complete_streaming_response" not in kwargs: - continue - else: - print_verbose("reaches supabase for streaming logging!") - result = kwargs["complete_streaming_response"] - - model = kwargs["model"] - messages = kwargs["messages"] - optional_params = kwargs.get("optional_params", {}) - litellm_params = kwargs.get("litellm_params", {}) - supabaseClient.log_event( - model=model, - messages=messages, - end_user=optional_params.get("user", "default"), - response_obj=result, - start_time=start_time, - end_time=end_time, - litellm_call_id=( - current_call_id - if ( - current_call_id := litellm_params.get( - "litellm_call_id" - ) - ) - is not None - else str(uuid.uuid4()) - ), - print_verbose=print_verbose, - ) - if callback == "wandb" and weightsBiasesLogger is not None: - print_verbose("reaches wandb for logging!") - weightsBiasesLogger.log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - if callback == "logfire" and logfireLogger is not None: - verbose_logger.debug("reaches logfire for success logging!") - kwargs = {} - for k, v in self.model_call_details.items(): - if ( - k != "original_response" - ): # copy.deepcopy raises errors as this could be a coroutine - kwargs[k] = v - - # this only logs streaming once, complete_streaming_response exists i.e when stream ends - if self.stream: - if "complete_streaming_response" not in kwargs: - continue - else: - print_verbose("reaches logfire for streaming logging!") - result = kwargs["complete_streaming_response"] - - logfireLogger.log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - level=LogfireLevel.INFO.value, # type: ignore - ) - - if callback == "lunary" and lunaryLogger is not None: - print_verbose("reaches lunary for logging!") - model = self.model - kwargs = self.model_call_details - - input = kwargs.get("messages", kwargs.get("input", None)) - - type = ( - "embed" - if self.call_type == CallTypes.embedding.value - else "llm" - ) - - # this only logs streaming once, complete_streaming_response exists i.e when stream ends - if self.stream: - if "complete_streaming_response" not in kwargs: - continue - else: - result = kwargs["complete_streaming_response"] - - lunaryLogger.log_event( - type=type, - kwargs=kwargs, - event="end", - model=model, - input=input, - user_id=kwargs.get("user", None), - # user_props=self.model_call_details.get("user_props", None), - extra=kwargs.get("optional_params", {}), - response_obj=result, - start_time=start_time, - end_time=end_time, - run_id=self.litellm_call_id, - print_verbose=print_verbose, - ) - if callback == "helicone" and heliconeLogger is not None: - print_verbose("reaches helicone for logging!") - model = self.model - messages = self.model_call_details["input"] - kwargs = self.model_call_details - - # this only logs streaming once, complete_streaming_response exists i.e when stream ends - if self.stream: - if "complete_streaming_response" not in kwargs: - continue - else: - print_verbose("reaches helicone for streaming logging!") - result = kwargs["complete_streaming_response"] - - heliconeLogger.log_success( - model=model, - messages=messages, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - kwargs=kwargs, - ) - if callback == "langfuse": - global langFuseLogger - print_verbose("reaches langfuse for success logging!") - kwargs = {} - for k, v in self.model_call_details.items(): - if ( - k != "original_response" - ): # copy.deepcopy raises errors as this could be a coroutine - kwargs[k] = v - # this only logs streaming once, complete_streaming_response exists i.e when stream ends - if self.stream: - verbose_logger.debug( - f"is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}" - ) - if complete_streaming_response is None: - continue - else: - print_verbose("reaches langfuse for streaming logging!") - result = kwargs["complete_streaming_response"] - - langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request( - globalLangfuseLogger=langFuseLogger, - standard_callback_dynamic_params=self.standard_callback_dynamic_params, - in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, - ) - if langfuse_logger_to_use is not None: - _response = langfuse_logger_to_use.log_event_on_langfuse( - kwargs=kwargs, - response_obj=result, - start_time=start_time, - end_time=end_time, - user_id=kwargs.get("user", None), - ) - if _response is not None and isinstance(_response, dict): - _trace_id = _response.get("trace_id", None) - if _trace_id is not None: - in_memory_trace_id_cache.set_cache( - litellm_call_id=self.litellm_call_id, - service_name="langfuse", - trace_id=_trace_id, - ) - if callback == "greenscale" and greenscaleLogger is not None: - kwargs = {} - for k, v in self.model_call_details.items(): - if ( - k != "original_response" - ): # copy.deepcopy raises errors as this could be a coroutine - kwargs[k] = v - # this only logs streaming once, complete_streaming_response exists i.e when stream ends - if self.stream: - verbose_logger.debug( - f"is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}" - ) - if complete_streaming_response is None: - continue - else: - print_verbose( - "reaches greenscale for streaming logging!" - ) - result = kwargs["complete_streaming_response"] - - greenscaleLogger.log_event( - kwargs=kwargs, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - if callback == "athina" and athinaLogger is not None: - deep_copy = {} - for k, v in self.model_call_details.items(): - deep_copy[k] = v - athinaLogger.log_event( - kwargs=deep_copy, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - if callback == "traceloop": - deep_copy = {} - for k, v in self.model_call_details.items(): - if k != "original_response": - deep_copy[k] = v - traceloopLogger.log_event( - kwargs=deep_copy, - response_obj=result, - start_time=start_time, - end_time=end_time, - user_id=kwargs.get("user", None), - print_verbose=print_verbose, - ) - if callback == "s3": - global s3Logger - if s3Logger is None: - s3Logger = S3Logger() - if self.stream: - if "complete_streaming_response" in self.model_call_details: - print_verbose( - "S3Logger Logger: Got Stream Event - Completed Stream Response" - ) - s3Logger.log_event( - kwargs=self.model_call_details, - response_obj=self.model_call_details[ - "complete_streaming_response" - ], - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - else: - print_verbose( - "S3Logger Logger: Got Stream Event - No complete stream response as yet" - ) - else: - s3Logger.log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - - if callback == "openmeter" and is_sync_request: - global openMeterLogger - if openMeterLogger is None: - print_verbose("Instantiates openmeter client") - openMeterLogger = OpenMeterLogger() - if self.stream and complete_streaming_response is None: - openMeterLogger.log_stream_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - else: - if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} - ) - result = self.model_call_details["complete_response"] - openMeterLogger.log_success_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - if ( - isinstance(callback, CustomLogger) - and is_sync_request - and self.call_type - != CallTypes.pass_through.value # pass-through endpoints call async_log_success_event - ): # custom logger class - if self.stream and complete_streaming_response is None: - callback.log_stream_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - else: - if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} - ) - result = self.model_call_details["complete_response"] - - callback.log_success_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - if ( - callable(callback) is True - and is_sync_request - and customLogger is not None - ): # custom logger functions - print_verbose( - "success callbacks: Running Custom Callback Function - {}".format( - callback - ) - ) - - customLogger.log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - callback_func=callback, - ) - - except Exception as e: - print_verbose( - f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging with integrations {traceback.format_exc()}" - ) - print_verbose( - f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" - ) - if capture_exception: # log this error to sentry for debugging - capture_exception(e) - # Track callback logging failures in Prometheus - try: - self._handle_callback_failure(callback=callback) - except Exception: - pass - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {}".format( - str(e) - ), - ) - - async def async_success_handler( # noqa: PLR0915 - self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs - ): - """ - Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. - """ - print_verbose( - "Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit) - ) - if not self.should_run_logging( - event_type="async_success" - ): # prevent double logging - return - - ## CALCULATE COST FOR BATCH JOBS - if self.call_type == CallTypes.aretrieve_batch.value and isinstance( - result, LiteLLMBatch - ): - litellm_params = self.litellm_params or {} - litellm_metadata = litellm_params.get("litellm_metadata") or {} - if ( - litellm_metadata.get("batch_ignore_default_logging", False) is True - ): # polling job will query these frequently, don't spam db logs - return - - from litellm.proxy.openai_files_endpoints.common_utils import ( - _is_base64_encoded_unified_file_id, - ) - - # check if file id is a unified file id - is_base64_unified_file_id = _is_base64_encoded_unified_file_id(result.id) - - batch_cost = kwargs.get("batch_cost", None) - batch_usage = kwargs.get("batch_usage", None) - batch_models = kwargs.get("batch_models", None) - has_explicit_batch_data = all( - x is not None for x in (batch_cost, batch_usage, batch_models) - ) - - should_compute_batch_data = ( - not is_base64_unified_file_id - or not has_explicit_batch_data - and result.status == "completed" - ) - if has_explicit_batch_data: - result._hidden_params["response_cost"] = batch_cost - result._hidden_params["batch_models"] = batch_models - result.usage = batch_usage - - elif should_compute_batch_data: - ( - response_cost, - batch_usage, - batch_models, - ) = await _handle_completed_batch( - batch=result, - custom_llm_provider=self.custom_llm_provider, - litellm_params=self.litellm_params, - ) - - result._hidden_params["response_cost"] = response_cost - result._hidden_params["batch_models"] = batch_models - result.usage = batch_usage - - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, - end_time=end_time, - result=result, - cache_hit=cache_hit, - standard_logging_object=kwargs.get("standard_logging_object", None), - ) - - ## BUILD COMPLETE STREAMED RESPONSE - if "async_complete_streaming_response" in self.model_call_details: - return # break out of this. - complete_streaming_response: Optional[ - Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] - ] = self._get_assembled_streaming_response( - result=result, - start_time=start_time, - end_time=end_time, - is_async=True, - streaming_chunks=self.streaming_chunks, - ) - - if complete_streaming_response is not None: - print_verbose("Async success callbacks: Got a complete streaming response") - - self.model_call_details[ - "async_complete_streaming_response" - ] = complete_streaming_response - - try: - if self.model_call_details.get("cache_hit", False) is True: - self.model_call_details["response_cost"] = 0.0 - else: - # check if base_model set on azure - _get_base_model_from_metadata( - model_call_details=self.model_call_details - ) - # base_model defaults to None if not set on model_info - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator( - result=complete_streaming_response - ) - - verbose_logger.debug( - f"Model={self.model}; cost={self.model_call_details['response_cost']}" - ) - except litellm.NotFoundError: - verbose_logger.warning( - f"Model={self.model} not found in completion cost map. Setting 'response_cost' to None" - ) - self.model_call_details["response_cost"] = None - - ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=complete_streaming_response, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - - # print standard logging payload - if ( - standard_logging_payload := self.model_call_details.get( - "standard_logging_object" - ) - ) is not None: - emit_standard_logging_payload(standard_logging_payload) - elif self.call_type == "pass_through_endpoint": - print_verbose( - "Async success callbacks: Got a pass-through endpoint response" - ) - - self.model_call_details["async_complete_streaming_response"] = result - - # cost calculation not possible for pass-through - self.model_call_details["response_cost"] = None - - ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj=result, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="success", - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - - # print standard logging payload - if ( - standard_logging_payload := self.model_call_details.get( - "standard_logging_object" - ) - ) is not None: - emit_standard_logging_payload(standard_logging_payload) - callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=self.dynamic_async_success_callbacks, - global_callbacks=litellm._async_success_callback, - ) - - result = redact_message_input_output_from_logging( - model_call_details=( - self.model_call_details if hasattr(self, "model_call_details") else {} - ), - result=result, - ) - - ## LOGGING HOOK ## - - for callback in callbacks: - if isinstance(callback, CustomGuardrail): - from litellm.types.guardrails import GuardrailEventHooks - - if ( - callback.should_run_guardrail( - data=self.model_call_details, - event_type=GuardrailEventHooks.logging_only, - ) - is not True - ): - continue - - self.model_call_details, result = await callback.async_logging_hook( - kwargs=self.model_call_details, - result=result, - call_type=self.call_type, - ) - elif isinstance(callback, CustomLogger): - result = redact_message_input_output_from_custom_logger( - result=result, litellm_logging_obj=self, custom_logger=callback - ) - self.model_call_details, result = await callback.async_logging_hook( - kwargs=self.model_call_details, - result=result, - call_type=self.call_type, - ) - - self.has_run_logging(event_type="async_success") - - for callback in callbacks: - # check if callback can run for this request - litellm_params = self.model_call_details.get("litellm_params", {}) - should_run = self.should_run_callback( - callback=callback, - litellm_params=litellm_params, - event_hook="async_success_handler", - ) - if not should_run: - continue - try: - if callback == "openmeter" and openMeterLogger is not None: - if self.stream is True: - if ( - "async_complete_streaming_response" - in self.model_call_details - ): - await openMeterLogger.async_log_success_event( - kwargs=self.model_call_details, - response_obj=self.model_call_details[ - "async_complete_streaming_response" - ], - start_time=start_time, - end_time=end_time, - ) - else: - await openMeterLogger.async_log_stream_event( # [TODO]: move this to being an async log stream event function - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - else: - await openMeterLogger.async_log_success_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - - if isinstance(callback, CustomLogger): # custom logger class - model_call_details: Dict = self.model_call_details - ################################## - # call redaction hook for custom logger - model_call_details = callback.redact_standard_logging_payload_from_model_call_details( - model_call_details=model_call_details - ) - ################################## - if self.stream is True: - if "async_complete_streaming_response" in model_call_details: - await callback.async_log_success_event( - kwargs=model_call_details, - response_obj=model_call_details[ - "async_complete_streaming_response" - ], - start_time=start_time, - end_time=end_time, - ) - else: - await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function - kwargs=model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - else: - await callback.async_log_success_event( - kwargs=model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) - if callable(callback): # custom logger functions - global customLogger - if customLogger is None: - customLogger = CustomLogger() - if self.stream: - if ( - "async_complete_streaming_response" - in self.model_call_details - ): - await customLogger.async_log_event( - kwargs=self.model_call_details, - response_obj=self.model_call_details[ - "async_complete_streaming_response" - ], - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - callback_func=callback, - ) - else: - await customLogger.async_log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - callback_func=callback, - ) - if callback == "dynamodb": - global dynamoLogger - if dynamoLogger is None: - dynamoLogger = DyanmoDBLogger() - if self.stream: - if ( - "async_complete_streaming_response" - in self.model_call_details - ): - print_verbose( - "DynamoDB Logger: Got Stream Event - Completed Stream Response" - ) - await dynamoLogger._async_log_event( - kwargs=self.model_call_details, - response_obj=self.model_call_details[ - "async_complete_streaming_response" - ], - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - else: - print_verbose( - "DynamoDB Logger: Got Stream Event - No complete stream response as yet" - ) - else: - await dynamoLogger._async_log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - except Exception: - verbose_logger.error( - f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" - ) - self._handle_callback_failure(callback=callback) - pass - - def _handle_callback_failure(self, callback: Any): - """ - Handle callback logging failures by incrementing Prometheus metrics. - - Works for both sync and async contexts since Prometheus counter increment is synchronous. - - Args: - callback: The callback that failed - """ - try: - callback_name = self._get_callback_name(callback) - - all_callbacks = litellm.logging_callback_manager._get_all_callbacks() - - for callback_obj in all_callbacks: - if hasattr(callback_obj, "increment_callback_logging_failure"): - callback_obj.increment_callback_logging_failure(callback_name=callback_name) # type: ignore - break # Only increment once - - except Exception as e: - verbose_logger.debug(f"Error in _handle_callback_failure: {str(e)}") - - def _failure_handler_helper_fn( - self, exception, traceback_exception, start_time=None, end_time=None - ): - if start_time is None: - start_time = self.start_time - if end_time is None: - end_time = datetime.datetime.now() - - # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions - if not hasattr(self, "model_call_details"): - self.model_call_details = {} - - self.model_call_details["log_event_type"] = "failed_api_call" - self.model_call_details["exception"] = exception - self.model_call_details["traceback_exception"] = traceback_exception - self.model_call_details["end_time"] = end_time - self.model_call_details.setdefault("original_response", None) - self.model_call_details["response_cost"] = 0 - - if hasattr(exception, "headers") and isinstance(exception.headers, dict): - self.model_call_details.setdefault("litellm_params", {}) - metadata = ( - self.model_call_details["litellm_params"].get("metadata", {}) or {} - ) - metadata.update(exception.headers) - - ## STANDARDIZED LOGGING PAYLOAD - - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, - ) - return start_time, end_time - - async def special_failure_handlers(self, exception: Exception): - """ - Custom events, emitted for specific failures. - - Currently just for router model group rate limit error - """ - from litellm.types.router import RouterErrors - - litellm_params: dict = self.model_call_details.get("litellm_params") or {} - metadata = litellm_params.get("metadata") or {} - - ## BASE CASE ## check if rate limit error for model group size 1 - is_base_case = False - if metadata.get("model_group_size") is not None: - model_group_size = metadata.get("model_group_size") - if isinstance(model_group_size, int) and model_group_size == 1: - is_base_case = True - ## check if special error ## - if ( - RouterErrors.no_deployments_available.value not in str(exception) - and is_base_case is False - ): - return - - ## get original model group ## - - model_group = metadata.get("model_group") or None - for callback in litellm._async_failure_callback: - if isinstance(callback, CustomLogger): # custom logger class - await callback.log_model_group_rate_limit_error( - exception=exception, - original_model_group=model_group, - kwargs=self.model_call_details, - ) # type: ignore - - def failure_handler( # noqa: PLR0915 - self, exception, traceback_exception, start_time=None, end_time=None - ): - verbose_logger.debug( - f"Logging Details LiteLLM-Failure Call: {litellm.failure_callback}" - ) - if not self.should_run_logging( - event_type="sync_failure" - ): # prevent double logging - return - litellm_params = self.model_call_details.get("litellm_params", {}) - is_sync_request = ( - litellm_params.get(CallTypes.acompletion.value, False) is not True - and litellm_params.get(CallTypes.aresponses.value, False) is not True - and litellm_params.get(CallTypes.aembedding.value, False) is not True - and litellm_params.get(CallTypes.aimage_generation.value, False) is not True - and litellm_params.get(CallTypes.atranscription.value, False) is not True - ) - - try: - start_time, end_time = self._failure_handler_helper_fn( - exception=exception, - traceback_exception=traceback_exception, - start_time=start_time, - end_time=end_time, - ) - callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=self.dynamic_failure_callbacks, - global_callbacks=litellm.failure_callback, - ) - - result = None # result sent to all loggers, init this to None incase it's not created - - result = redact_message_input_output_from_logging( - model_call_details=( - self.model_call_details - if hasattr(self, "model_call_details") - else {} - ), - result=result, - ) - self.has_run_logging(event_type="sync_failure") - for callback in callbacks: - try: - should_run = self.should_run_callback( - callback=callback, - litellm_params=litellm_params, - event_hook="failure_handler", - ) - if not should_run: - continue - if callback == "lunary" and lunaryLogger is not None: - print_verbose("reaches lunary for logging error!") - - model = self.model - - input = self.model_call_details["input"] - - _type = ( - "embed" - if self.call_type == CallTypes.embedding.value - else "llm" - ) - - lunaryLogger.log_event( - kwargs=self.model_call_details, - type=_type, - event="error", - user_id=self.model_call_details.get("user", "default"), - model=model, - input=input, - error=traceback_exception, - run_id=self.litellm_call_id, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - if callback == "sentry": - print_verbose("sending exception to sentry") - if capture_exception: - capture_exception(exception) - else: - print_verbose( - f"capture exception not initialized: {capture_exception}" - ) - elif callback == "supabase" and supabaseClient is not None: - print_verbose("reaches supabase for logging!") - print_verbose(f"supabaseClient: {supabaseClient}") - supabaseClient.log_event( - model=self.model if hasattr(self, "model") else "", - messages=self.messages, - end_user=self.model_call_details.get("user", "default"), - response_obj=result, - start_time=start_time, - end_time=end_time, - litellm_call_id=self.model_call_details["litellm_call_id"], - print_verbose=print_verbose, - ) - if ( - callable(callback) and customLogger is not None - ): # custom logger functions - customLogger.log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - callback_func=callback, - ) - if ( - isinstance(callback, CustomLogger) and is_sync_request - ): # custom logger class - callback.log_failure_event( - start_time=start_time, - end_time=end_time, - response_obj=result, - kwargs=self.model_call_details, - ) - if callback == "langfuse": - global langFuseLogger - verbose_logger.debug("reaches langfuse for logging failure") - kwargs = {} - for k, v in self.model_call_details.items(): - if ( - k != "original_response" - ): # copy.deepcopy raises errors as this could be a coroutine - kwargs[k] = v - # this only logs streaming once, complete_streaming_response exists i.e when stream ends - langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request( - globalLangfuseLogger=langFuseLogger, - standard_callback_dynamic_params=self.standard_callback_dynamic_params, - in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, - ) - _response = langfuse_logger_to_use.log_event_on_langfuse( - start_time=start_time, - end_time=end_time, - response_obj=None, - user_id=kwargs.get("user", None), - status_message=str(exception), - level="ERROR", - kwargs=self.model_call_details, - ) - if _response is not None and isinstance(_response, dict): - _trace_id = _response.get("trace_id", None) - if _trace_id is not None: - in_memory_trace_id_cache.set_cache( - litellm_call_id=self.litellm_call_id, - service_name="langfuse", - trace_id=_trace_id, - ) - if callback == "traceloop": - traceloopLogger.log_event( - start_time=start_time, - end_time=end_time, - response_obj=None, - user_id=self.model_call_details.get("user", None), - print_verbose=print_verbose, - status_message=str(exception), - level="ERROR", - kwargs=self.model_call_details, - ) - if callback == "logfire" and logfireLogger is not None: - verbose_logger.debug("reaches logfire for failure logging!") - kwargs = {} - for k, v in self.model_call_details.items(): - if ( - k != "original_response" - ): # copy.deepcopy raises errors as this could be a coroutine - kwargs[k] = v - kwargs["exception"] = exception - - logfireLogger.log_event( - kwargs=kwargs, - response_obj=result, - start_time=start_time, - end_time=end_time, - level=LogfireLevel.ERROR.value, # type: ignore - print_verbose=print_verbose, - ) - - except Exception as e: - print_verbose( - f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging with integrations {str(e)}" - ) - print_verbose( - f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" - ) - if capture_exception: # log this error to sentry for debugging - capture_exception(e) - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {}".format( - str(e) - ) - ) - - async def async_failure_handler( - self, exception, traceback_exception, start_time=None, end_time=None - ): - """ - Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. - """ - await self.special_failure_handlers(exception=exception) - if not self.should_run_logging( - event_type="async_failure" - ): # prevent double logging - return - start_time, end_time = self._failure_handler_helper_fn( - exception=exception, - traceback_exception=traceback_exception, - start_time=start_time, - end_time=end_time, - ) - - callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=self.dynamic_async_failure_callbacks, - global_callbacks=litellm._async_failure_callback, - ) - - result = None # result sent to all loggers, init this to None incase it's not created - - self.has_run_logging(event_type="async_failure") - for callback in callbacks: - try: - litellm_params = self.model_call_details.get("litellm_params", {}) - should_run = self.should_run_callback( - callback=callback, - litellm_params=litellm_params, - event_hook="async_failure_handler", - ) - if not should_run: - continue - if isinstance(callback, CustomLogger): # custom logger class - await callback.async_log_failure_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - ) # type: ignore - if ( - callable(callback) and customLogger is not None - ): # custom logger functions - await customLogger.async_log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - callback_func=callback, - ) - except Exception as e: - verbose_logger.exception( - "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure \ - logging {}\nCallback={}".format( - str(e), callback - ) - ) - # Track callback logging failures in Prometheus - self._handle_callback_failure(callback=callback) - - def _get_trace_id(self, service_name: Literal["langfuse"]) -> Optional[str]: - """ - For the given service (e.g. langfuse), return the trace_id actually logged. - - Used for constructing the url in slack alerting. - - Returns: - - str: The logged trace id - - None: If trace id not yet emitted. - """ - trace_id: Optional[str] = None - if service_name == "langfuse": - trace_id = in_memory_trace_id_cache.get_cache( - litellm_call_id=self.litellm_call_id, service_name=service_name - ) - - return trace_id - - def _get_callback_object(self, service_name: Literal["langfuse"]) -> Optional[Any]: - """ - Return dynamic callback object. - - Meant to solve issue when doing key-based/team-based logging - """ - global langFuseLogger - - if service_name == "langfuse": - if langFuseLogger is None or ( - ( - self.standard_callback_dynamic_params.get("langfuse_public_key") - is not None - and self.standard_callback_dynamic_params.get("langfuse_public_key") - != langFuseLogger.public_key - ) - or ( - self.standard_callback_dynamic_params.get("langfuse_public_key") - is not None - and self.standard_callback_dynamic_params.get("langfuse_public_key") - != langFuseLogger.public_key - ) - or ( - self.standard_callback_dynamic_params.get("langfuse_host") - is not None - and self.standard_callback_dynamic_params.get("langfuse_host") - != langFuseLogger.langfuse_host - ) - ): - return LangFuseLogger( - langfuse_public_key=self.standard_callback_dynamic_params.get( - "langfuse_public_key" - ), - langfuse_secret=self.standard_callback_dynamic_params.get( - "langfuse_secret" - ), - langfuse_host=self.standard_callback_dynamic_params.get( - "langfuse_host" - ), - ) - return langFuseLogger - - return None - - def handle_sync_success_callbacks_for_async_calls( - self, - result: Any, - start_time: datetime.datetime, - end_time: datetime.datetime, - cache_hit: Optional[Any] = None, - ) -> None: - """ - Handles calling success callbacks for Async calls. - - Why: Some callbacks - `langfuse`, `s3` are sync callbacks. We need to call them in the executor. - """ - if self._should_run_sync_callbacks_for_async_calls() is False: - return - - executor.submit( - self.success_handler, - result, - start_time, - end_time, - cache_hit, - ) - - def _should_run_sync_callbacks_for_async_calls(self) -> bool: - """ - Returns: - - bool: True if sync callbacks should be run for async calls. eg. `langfuse`, `s3` - """ - _combined_sync_callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=self.dynamic_success_callbacks, - global_callbacks=litellm.success_callback, - ) - _filtered_success_callbacks = self._remove_internal_custom_logger_callbacks( - _combined_sync_callbacks - ) - _filtered_success_callbacks = self._remove_internal_litellm_callbacks( - _filtered_success_callbacks - ) - return len(_filtered_success_callbacks) > 0 - - def get_combined_callback_list( - self, dynamic_success_callbacks: Optional[List], global_callbacks: List - ) -> List: - if dynamic_success_callbacks is None: - return list(global_callbacks) - return list(set(dynamic_success_callbacks + global_callbacks)) - - def _remove_internal_litellm_callbacks(self, callbacks: List) -> List: - """ - Creates a filtered list of callbacks, excluding internal LiteLLM callbacks. - - Args: - callbacks: List of callback functions/strings to filter - - Returns: - List of filtered callbacks with internal ones removed - """ - filtered = [ - cb for cb in callbacks if not self._is_internal_litellm_proxy_callback(cb) - ] - - verbose_logger.debug(f"Filtered callbacks: {filtered}") - return filtered - - def _get_callback_name(self, cb) -> str: - """ - Helper to get the name of a callback function - - Args: - cb: The callback object/function/string to get the name of - - Returns: - The name of the callback - """ - if isinstance(cb, str): - return cb - if hasattr(cb, "__name__"): - return cb.__name__ - if hasattr(cb, "__func__"): - return cb.__func__.__name__ - if hasattr(cb, "__class__"): - return cb.__class__.__name__ - return str(cb) - - def _is_internal_litellm_proxy_callback(self, cb) -> bool: - """Helper to check if a callback is internal""" - INTERNAL_PREFIXES = [ - "_PROXY", - "_service_logger.ServiceLogging", - "sync_deployment_callback_on_success", - ] - if isinstance(cb, str): - return False - - if not callable(cb): - return True - - cb_name = self._get_callback_name(cb) - return any(prefix in cb_name for prefix in INTERNAL_PREFIXES) - - def _remove_internal_custom_logger_callbacks(self, callbacks: List) -> List: - """ - Removes internal custom logger callbacks from the list. - """ - _new_callbacks = [] - for _c in callbacks: - if isinstance(_c, CustomLogger): - continue - elif ( - isinstance(_c, str) - and _c in litellm._known_custom_logger_compatible_callbacks - ): - continue - _new_callbacks.append(_c) - return _new_callbacks - - def _get_assembled_streaming_response( - self, - result: Union[ - ModelResponse, - TextCompletionResponse, - ModelResponseStream, - ResponseCompletedEvent, - Any, - ], - start_time: datetime.datetime, - end_time: datetime.datetime, - is_async: bool, - streaming_chunks: List[Any], - ) -> Optional[Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]]: - if isinstance(result, ModelResponse): - return result - elif isinstance(result, TextCompletionResponse): - return result - elif isinstance(result, ResponseCompletedEvent): - ## return unified Usage object - if isinstance(result.response.usage, ResponseAPIUsage): - transformed_usage = ( - ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( - result.response.usage - ) - ) - # Set as dict instead of Usage object so model_dump() serializes it correctly - setattr( - result.response, - "usage", - ( - transformed_usage.model_dump() - if hasattr(transformed_usage, "model_dump") - else dict(transformed_usage) - ), - ) - return result.response - else: - return None - return None - - def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse: - """ - Handles logging for Anthropic messages responses. - - Args: - result: The response object from the model call - - Returns: - The the response object from the model call - - - For Non-streaming responses, we need to transform the response to a ModelResponse object. - - For streaming responses, anthropic_messages handler calls success_handler with a assembled ModelResponse. - """ - import httpx - - if self.stream and isinstance(result, ModelResponse): - return result - elif isinstance(result, ModelResponse): - return result - - httpx_response = self.model_call_details.get("httpx_response", None) - if httpx_response and isinstance(httpx_response, httpx.Response): - result = litellm.AnthropicConfig().transform_response( - raw_response=httpx_response, - model_response=litellm.ModelResponse(), - model=self.model, - messages=[], - logging_obj=self, - optional_params={}, - api_key="", - request_data={}, - encoding=litellm.encoding, - json_mode=False, - litellm_params={}, - ) - else: - from litellm.types.llms.anthropic import AnthropicResponse - - pydantic_result = AnthropicResponse.model_validate(result) - import httpx - - result = litellm.AnthropicConfig().transform_parsed_response( - completion_response=pydantic_result.model_dump(), - raw_response=httpx.Response( - status_code=200, - headers={}, - ), - model_response=litellm.ModelResponse(), - json_mode=None, - ) - return result - - def _handle_non_streaming_google_genai_generate_content_response_logging( - self, result: Any - ) -> ModelResponse: - """ - Handles logging for Google GenAI generate content responses. - """ - import httpx - - httpx_response = self.model_call_details.get("httpx_response", None) - if httpx_response is None: - raise ValueError("Google GenAI Generate Content: httpx_response is None") - dict_result = httpx_response.json() - result = litellm.VertexGeminiConfig()._transform_google_generate_content_to_openai_model_response( - completion_response=dict_result, - model_response=litellm.ModelResponse(), - model=self.model, - logging_obj=self, - raw_response=httpx.Response( - status_code=200, - headers={}, - ), - ) - return result - - def _handle_a2a_response_logging(self, result: Any) -> Any: - """ - Handles logging for A2A (Agent-to-Agent) responses. - - Adds usage from model_call_details to the result if available. - Uses Pydantic's model_copy to avoid modifying the original response. - - Args: - result: The LiteLLMSendMessageResponse from the A2A call - - Returns: - The response object with usage added if available - """ - # Get usage from model_call_details (set by asend_message) - usage = self.model_call_details.get("usage") - if usage is None: - return result - - # Deep copy result and add usage - result_copy = result.model_copy(deep=True) - result_copy.usage = ( - usage.model_dump() if hasattr(usage, "model_dump") else dict(usage) - ) - return result_copy - - -def _get_masked_values( - sensitive_object: dict, - ignore_sensitive_values: bool = False, - mask_all_values: bool = False, - unmasked_length: int = 4, - number_of_asterisks: Optional[int] = 4, -) -> dict: - """ - Internal debugging helper function - - Masks the headers of the request sent from LiteLLM - - Args: - masked_length: Optional length for the masked portion (number of *). If set, will use exactly this many * - regardless of original string length. The total length will be unmasked_length + masked_length. - """ - sensitive_keywords = [ - "authorization", - "token", - "key", - "secret", - "vertex_credentials", - ] - return { - k: ( - # If ignore_sensitive_values is True, or if this key doesn't contain sensitive keywords, return original value - v - if ignore_sensitive_values - or not any( - sensitive_keyword in k.lower() - for sensitive_keyword in sensitive_keywords - ) - else ( - # Apply masking to sensitive keys - ( - v[: unmasked_length // 2] - + "*" * number_of_asterisks - + v[-unmasked_length // 2 :] - ) - if ( - isinstance(v, str) - and len(v) > unmasked_length - and number_of_asterisks is not None - ) - else ( - ( - v[: unmasked_length // 2] - + "*" * (len(v) - unmasked_length) - + v[-unmasked_length // 2 :] - ) - if (isinstance(v, str) and len(v) > unmasked_length) - else ("*****" if isinstance(v, str) else v) - ) - ) - ) - for k, v in sensitive_object.items() - } - - -def set_callbacks(callback_list, function_id=None): # noqa: PLR0915 - """ - Globally sets the callback client - """ - global sentry_sdk_instance, capture_exception, add_breadcrumb, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, supabaseClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger, deepevalLogger - - try: - for callback in callback_list: - if callback == "sentry": - try: - import sentry_sdk - except ImportError: - print_verbose("Package 'sentry_sdk' is missing. Installing it...") - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "sentry_sdk"] - ) - import sentry_sdk - from sentry_sdk.scrubber import EventScrubber - - sentry_sdk_instance = sentry_sdk - sentry_trace_rate = ( - os.environ.get("SENTRY_API_TRACE_RATE") - if "SENTRY_API_TRACE_RATE" in os.environ - else "1.0" - ) - sentry_sample_rate = ( - os.environ.get("SENTRY_API_SAMPLE_RATE") - if "SENTRY_API_SAMPLE_RATE" in os.environ - else "1.0" - ) - sentry_sdk_instance.init( - dsn=os.environ.get("SENTRY_DSN"), - traces_sample_rate=float(sentry_trace_rate), # type: ignore - sample_rate=float( - sentry_sample_rate if sentry_sample_rate else 1.0 - ), - send_default_pii=False, # Prevent sending Personal Identifiable Information - event_scrubber=EventScrubber( - denylist=SENTRY_DENYLIST, pii_denylist=SENTRY_PII_DENYLIST - ), - environment=os.environ.get("SENTRY_ENVIRONMENT", "production"), - ) - capture_exception = sentry_sdk_instance.capture_exception - add_breadcrumb = sentry_sdk_instance.add_breadcrumb - elif callback == "slack": - try: - from slack_bolt import App - except ImportError: - print_verbose("Package 'slack_bolt' is missing. Installing it...") - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "slack_bolt"] - ) - from slack_bolt import App - slack_app = App( - token=os.environ.get("SLACK_API_TOKEN"), - signing_secret=os.environ.get("SLACK_API_SECRET"), - ) - alerts_channel = os.environ["SLACK_API_CHANNEL"] - print_verbose(f"Initialized Slack App: {slack_app}") - elif callback == "traceloop": - traceloopLogger = TraceloopLogger() - elif callback == "athina": - athinaLogger = AthinaLogger() - print_verbose("Initialized Athina Logger") - elif callback == "helicone": - heliconeLogger = HeliconeLogger() - elif callback == "lunary": - lunaryLogger = LunaryLogger() - elif callback == "promptlayer": - promptLayerLogger = PromptLayerLogger() - elif callback == "langfuse": - langFuseLogger = LangFuseLogger( - langfuse_public_key=None, langfuse_secret=None, langfuse_host=None - ) - elif callback == "openmeter": - openMeterLogger = OpenMeterLogger() - elif callback == "datadog": - dataDogLogger = DataDogLogger() - elif callback == "dynamodb": - dynamoLogger = DyanmoDBLogger() - elif callback == "s3": - s3Logger = S3Logger() - elif callback == "wandb": - from litellm.integrations.weights_biases import WeightsBiasesLogger - - weightsBiasesLogger = WeightsBiasesLogger() - elif callback == "logfire": - logfireLogger = LogfireLogger() - elif callback == "supabase": - print_verbose("instantiating supabase") - supabaseClient = Supabase() - elif callback == "greenscale": - greenscaleLogger = GreenscaleLogger() - print_verbose("Initialized Greenscale Logger") - elif callable(callback): - customLogger = CustomLogger() - except Exception as e: - raise e - return None - - -def _init_custom_logger_compatible_class( # noqa: PLR0915 - logging_integration: _custom_logger_compatible_callbacks_literal, - internal_usage_cache: Optional[DualCache], - llm_router: Optional[ - Any - ], # expect litellm.Router, but typing errors due to circular import - custom_logger_init_args: Optional[dict] = {}, -) -> Optional[CustomLogger]: - """ - Initialize a custom logger compatible class - """ - try: - custom_logger_init_args = custom_logger_init_args or {} - if logging_integration == "agentops": # Add AgentOps initialization - for callback in _in_memory_loggers: - if isinstance(callback, AgentOps): - return callback # type: ignore - - agentops_logger = AgentOps() - _in_memory_loggers.append(agentops_logger) - return agentops_logger # type: ignore - elif logging_integration == "lago": - for callback in _in_memory_loggers: - if isinstance(callback, LagoLogger): - return callback # type: ignore - - lago_logger = LagoLogger() - _in_memory_loggers.append(lago_logger) - return lago_logger # type: ignore - elif logging_integration == "openmeter": - for callback in _in_memory_loggers: - if isinstance(callback, OpenMeterLogger): - return callback # type: ignore - - _openmeter_logger = OpenMeterLogger() - _in_memory_loggers.append(_openmeter_logger) - return _openmeter_logger # type: ignore - elif logging_integration == "posthog": - for callback in _in_memory_loggers: - if isinstance(callback, PostHogLogger): - return callback # type: ignore - - _posthog_logger = PostHogLogger() - _in_memory_loggers.append(_posthog_logger) - return _posthog_logger # type: ignore - elif logging_integration == "braintrust": - from litellm.integrations.braintrust_logging import BraintrustLogger - - for callback in _in_memory_loggers: - if isinstance(callback, BraintrustLogger): - return callback # type: ignore - - braintrust_logger = BraintrustLogger() - _in_memory_loggers.append(braintrust_logger) - return braintrust_logger # type: ignore - elif logging_integration == "langsmith": - for callback in _in_memory_loggers: - if isinstance(callback, LangsmithLogger): - return callback # type: ignore - - _langsmith_logger = LangsmithLogger() - _in_memory_loggers.append(_langsmith_logger) - return _langsmith_logger # type: ignore - elif logging_integration == "argilla": - for callback in _in_memory_loggers: - if isinstance(callback, ArgillaLogger): - return callback # type: ignore - - _argilla_logger = ArgillaLogger() - _in_memory_loggers.append(_argilla_logger) - return _argilla_logger # type: ignore - elif logging_integration == "literalai": - for callback in _in_memory_loggers: - if isinstance(callback, LiteralAILogger): - return callback # type: ignore - - _literalai_logger = LiteralAILogger() - _in_memory_loggers.append(_literalai_logger) - return _literalai_logger # type: ignore - elif logging_integration == "prometheus": - PrometheusLogger = _get_cached_prometheus_logger() - - for callback in _in_memory_loggers: - if isinstance(callback, PrometheusLogger): - return callback # type: ignore - - _prometheus_logger = PrometheusLogger() - _in_memory_loggers.append(_prometheus_logger) - return _prometheus_logger # type: ignore - elif logging_integration == "datadog": - for callback in _in_memory_loggers: - if isinstance(callback, DataDogLogger): - return callback # type: ignore - - _datadog_logger = DataDogLogger() - _in_memory_loggers.append(_datadog_logger) - return _datadog_logger # type: ignore - elif logging_integration == "datadog_llm_observability": - _datadog_llm_obs_logger = DataDogLLMObsLogger() - _in_memory_loggers.append(_datadog_llm_obs_logger) - return _datadog_llm_obs_logger # type: ignore - elif logging_integration == "azure_sentinel": - for callback in _in_memory_loggers: - if isinstance(callback, AzureSentinelLogger): - return callback # type: ignore - - _azure_sentinel_logger = AzureSentinelLogger() - _in_memory_loggers.append(_azure_sentinel_logger) - return _azure_sentinel_logger # type: ignore - elif logging_integration == "gcs_bucket": - for callback in _in_memory_loggers: - if isinstance(callback, GCSBucketLogger): - return callback # type: ignore - - _gcs_bucket_logger = GCSBucketLogger() - _in_memory_loggers.append(_gcs_bucket_logger) - return _gcs_bucket_logger # type: ignore - elif logging_integration == "s3_v2": - for callback in _in_memory_loggers: - if isinstance(callback, S3V2Logger): - return callback # type: ignore - - _s3_v2_logger = S3V2Logger() - _in_memory_loggers.append(_s3_v2_logger) - return _s3_v2_logger # type: ignore - elif logging_integration == "aws_sqs": - for callback in _in_memory_loggers: - if isinstance(callback, SQSLogger): - return callback # type: ignore - - _aws_sqs_logger = SQSLogger() - _in_memory_loggers.append(_aws_sqs_logger) - return _aws_sqs_logger # type: ignore - elif logging_integration == "azure_storage": - for callback in _in_memory_loggers: - if isinstance(callback, AzureBlobStorageLogger): - return callback # type: ignore - - _azure_storage_logger = AzureBlobStorageLogger() - _in_memory_loggers.append(_azure_storage_logger) - return _azure_storage_logger # type: ignore - elif logging_integration == "opik": - for callback in _in_memory_loggers: - if isinstance(callback, OpikLogger): - return callback # type: ignore - - _opik_logger = OpikLogger() - _in_memory_loggers.append(_opik_logger) - return _opik_logger # type: ignore - elif logging_integration == "arize": - from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) - - arize_config = ArizeLogger.get_arize_config() - if arize_config.endpoint is None: - raise ValueError( - "No valid endpoint found for Arize, please set 'ARIZE_ENDPOINT' to your GRPC endpoint or 'ARIZE_HTTP_ENDPOINT' to your HTTP endpoint" - ) - otel_config = OpenTelemetryConfig( - exporter=arize_config.protocol, - endpoint=arize_config.endpoint, - service_name=arize_config.project_name, - ) - - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" - for callback in _in_memory_loggers: - if ( - isinstance(callback, ArizeLogger) - and callback.callback_name == "arize" - ): - return callback # type: ignore - _arize_otel_logger = ArizeLogger(config=otel_config, callback_name="arize") - _in_memory_loggers.append(_arize_otel_logger) - return _arize_otel_logger # type: ignore - elif logging_integration == "arize_phoenix": - from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) - - arize_phoenix_config = ArizePhoenixLogger.get_arize_phoenix_config() - otel_config = OpenTelemetryConfig( - exporter=arize_phoenix_config.protocol, - endpoint=arize_phoenix_config.endpoint, - headers=arize_phoenix_config.otlp_auth_headers, - ) - if arize_phoenix_config.project_name: - existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") - # Add openinference.project.name attribute - if existing_attrs: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" - else: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"openinference.project.name={arize_phoenix_config.project_name}" - - # Set Phoenix project name from environment variable - phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None) - if phoenix_project_name: - existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") - # Add openinference.project.name attribute - if existing_attrs: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"{existing_attrs},openinference.project.name={phoenix_project_name}" - else: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"openinference.project.name={phoenix_project_name}" - - # auth can be disabled on local deployments of arize phoenix - if arize_phoenix_config.otlp_auth_headers is not None: - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = arize_phoenix_config.otlp_auth_headers - - for callback in _in_memory_loggers: - if ( - isinstance(callback, ArizePhoenixLogger) - and callback.callback_name == "arize_phoenix" - ): - return callback # type: ignore - _arize_phoenix_otel_logger = ArizePhoenixLogger( - config=otel_config, callback_name="arize_phoenix" - ) - _in_memory_loggers.append(_arize_phoenix_otel_logger) - return _arize_phoenix_otel_logger # type: ignore - elif logging_integration == "levo": - from litellm.integrations.levo.levo import LevoLogger - from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) - - levo_config = LevoLogger.get_levo_config() - otel_config = OpenTelemetryConfig( - exporter=levo_config.protocol, - endpoint=levo_config.endpoint, - headers=levo_config.otlp_auth_headers, - ) - - # Check if LevoLogger instance already exists - for callback in _in_memory_loggers: - if ( - isinstance(callback, LevoLogger) - and callback.callback_name == "levo" - ): - return callback # type: ignore - - _levo_otel_logger = LevoLogger(config=otel_config, callback_name="levo") - _in_memory_loggers.append(_levo_otel_logger) - return _levo_otel_logger # type: ignore - elif logging_integration == "otel": - from litellm.integrations.opentelemetry import OpenTelemetry - - for callback in _in_memory_loggers: - if type(callback) is OpenTelemetry: - return callback # type: ignore - otel_logger = OpenTelemetry( - **_get_custom_logger_settings_from_proxy_server( - callback_name=logging_integration - ) - ) - _in_memory_loggers.append(otel_logger) - return otel_logger # type: ignore - - elif logging_integration == "galileo": - for callback in _in_memory_loggers: - if isinstance(callback, GalileoObserve): - return callback # type: ignore - - galileo_logger = GalileoObserve() - _in_memory_loggers.append(galileo_logger) - return galileo_logger # type: ignore - elif logging_integration == "cloudzero": - from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger - - for callback in _in_memory_loggers: - if isinstance(callback, CloudZeroLogger): - return callback # type: ignore - cloudzero_logger = CloudZeroLogger() - _in_memory_loggers.append(cloudzero_logger) - return cloudzero_logger # type: ignore - elif logging_integration == "focus": - from litellm.integrations.focus.focus_logger import FocusLogger - - for callback in _in_memory_loggers: - if isinstance(callback, FocusLogger): - return callback # type: ignore - focus_logger = FocusLogger() - _in_memory_loggers.append(focus_logger) - return focus_logger # type: ignore - elif logging_integration == "deepeval": - for callback in _in_memory_loggers: - if isinstance(callback, DeepEvalLogger): - return callback # type: ignore - deepeval_logger = DeepEvalLogger() - _in_memory_loggers.append(deepeval_logger) - return deepeval_logger # type: ignore - - elif logging_integration == "logfire": - if "LOGFIRE_TOKEN" not in os.environ: - raise ValueError("LOGFIRE_TOKEN not found in environment variables") - from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) - - logfire_base_url = os.getenv( - "LOGFIRE_BASE_URL", "https://logfire-api.pydantic.dev" - ) - otel_config = OpenTelemetryConfig( - exporter="otlp_http", - endpoint=f"{logfire_base_url.rstrip('/')}/v1/traces", - headers=f"Authorization={os.getenv('LOGFIRE_TOKEN')}", - ) - for callback in _in_memory_loggers: - if isinstance(callback, OpenTelemetry): - return callback # type: ignore - _otel_logger = OpenTelemetry(config=otel_config) - _in_memory_loggers.append(_otel_logger) - return _otel_logger # type: ignore - elif logging_integration == "dynamic_rate_limiter": - from litellm.proxy.hooks.dynamic_rate_limiter import ( - _PROXY_DynamicRateLimitHandler, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, _PROXY_DynamicRateLimitHandler): - return callback # type: ignore - - if internal_usage_cache is None: - raise Exception( - "Internal Error: Cache cannot be empty - internal_usage_cache={}".format( - internal_usage_cache - ) - ) - - dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler( - internal_usage_cache=internal_usage_cache - ) - - if llm_router is not None and isinstance(llm_router, litellm.Router): - dynamic_rate_limiter_obj.update_variables(llm_router=llm_router) - _in_memory_loggers.append(dynamic_rate_limiter_obj) - return dynamic_rate_limiter_obj # type: ignore - elif logging_integration == "dynamic_rate_limiter_v3": - from litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( - _PROXY_DynamicRateLimitHandlerV3, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): - return callback # type: ignore - - if internal_usage_cache is None: - raise Exception( - "Internal Error: Cache cannot be empty - internal_usage_cache={}".format( - internal_usage_cache - ) - ) - - dynamic_rate_limiter_obj_v3 = _PROXY_DynamicRateLimitHandlerV3( - internal_usage_cache=internal_usage_cache - ) - - if llm_router is not None and isinstance(llm_router, litellm.Router): - dynamic_rate_limiter_obj_v3.update_variables(llm_router=llm_router) - _in_memory_loggers.append(dynamic_rate_limiter_obj_v3) - return dynamic_rate_limiter_obj_v3 # type: ignore - elif logging_integration == "langtrace": - if "LANGTRACE_API_KEY" not in os.environ: - raise ValueError("LANGTRACE_API_KEY not found in environment variables") - - from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) - - otel_config = OpenTelemetryConfig( - exporter="otlp_http", - endpoint="https://langtrace.ai/api/trace", - ) - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" - for callback in _in_memory_loggers: - if ( - isinstance(callback, OpenTelemetry) - and callback.callback_name == "langtrace" - ): - return callback # type: ignore - _otel_logger = OpenTelemetry(config=otel_config, callback_name="langtrace") - _in_memory_loggers.append(_otel_logger) - return _otel_logger # type: ignore - - elif logging_integration == "mlflow": - for callback in _in_memory_loggers: - if isinstance(callback, MlflowLogger): - return callback # type: ignore - - _mlflow_logger = MlflowLogger() - _in_memory_loggers.append(_mlflow_logger) - return _mlflow_logger # type: ignore - elif logging_integration == "langfuse": - for callback in _in_memory_loggers: - if isinstance(callback, LangfusePromptManagement): - return callback - - langfuse_logger = LangfusePromptManagement() - _in_memory_loggers.append(langfuse_logger) - return langfuse_logger # type: ignore - elif logging_integration == "langfuse_otel": - from litellm.integrations.langfuse.langfuse_otel import LangfuseOtelLogger - - for callback in _in_memory_loggers: - if ( - isinstance(callback, LangfuseOtelLogger) - and callback.callback_name == "langfuse_otel" - ): - return callback # type: ignore - # Allow LangfuseOtelLogger to initialize its own config safely - # This prevents startup crashes if LANGFUSE keys are not in env (e.g. for dynamic usage) - _otel_logger = LangfuseOtelLogger( - config=None, callback_name="langfuse_otel" - ) - _in_memory_loggers.append(_otel_logger) - return _otel_logger # type: ignore - elif logging_integration == "weave_otel": - from litellm.integrations.opentelemetry import OpenTelemetryConfig - from litellm.integrations.weave.weave_otel import ( - WeaveOtelLogger, - get_weave_otel_config, - ) - - weave_otel_config = get_weave_otel_config() - - otel_config = OpenTelemetryConfig( - exporter=weave_otel_config.protocol, - endpoint=weave_otel_config.endpoint, - headers=weave_otel_config.otlp_auth_headers, - ) - - for callback in _in_memory_loggers: - if ( - isinstance(callback, WeaveOtelLogger) - and callback.callback_name == "weave_otel" - ): - return callback # type: ignore - _otel_logger = WeaveOtelLogger( - config=otel_config, callback_name="weave_otel" - ) - _in_memory_loggers.append(_otel_logger) - return _otel_logger # type: ignore - elif logging_integration == "pagerduty": - for callback in _in_memory_loggers: - if isinstance(callback, PagerDutyAlerting): - return callback - pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args) - _in_memory_loggers.append(pagerduty_logger) - return pagerduty_logger # type: ignore - elif logging_integration == "anthropic_cache_control_hook": - for callback in _in_memory_loggers: - if isinstance(callback, AnthropicCacheControlHook): - return callback - anthropic_cache_control_hook = AnthropicCacheControlHook() - _in_memory_loggers.append(anthropic_cache_control_hook) - return anthropic_cache_control_hook # type: ignore - elif logging_integration == "vector_store_pre_call_hook": - from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( - VectorStorePreCallHook, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, VectorStorePreCallHook): - return callback - vector_store_pre_call_hook = VectorStorePreCallHook() - _in_memory_loggers.append(vector_store_pre_call_hook) - return vector_store_pre_call_hook # type: ignore - elif logging_integration == "gcs_pubsub": - for callback in _in_memory_loggers: - if isinstance(callback, GcsPubSubLogger): - return callback - _gcs_pubsub_logger = GcsPubSubLogger() - _in_memory_loggers.append(_gcs_pubsub_logger) - return _gcs_pubsub_logger # type: ignore - elif logging_integration == "generic_api": - for callback in _in_memory_loggers: - if isinstance(callback, GenericAPILogger): - return callback - generic_api_logger = GenericAPILogger() - _in_memory_loggers.append(generic_api_logger) - return generic_api_logger # type: ignore - elif logging_integration == "resend_email": - for callback in _in_memory_loggers: - if isinstance(callback, ResendEmailLogger): - return callback - resend_email_logger = ResendEmailLogger() - _in_memory_loggers.append(resend_email_logger) - return resend_email_logger # type: ignore - elif logging_integration == "sendgrid_email": - for callback in _in_memory_loggers: - if isinstance(callback, SendGridEmailLogger): - return callback - sendgrid_email_logger = SendGridEmailLogger() - _in_memory_loggers.append(sendgrid_email_logger) - return sendgrid_email_logger # type: ignore - elif logging_integration == "smtp_email": - for callback in _in_memory_loggers: - if isinstance(callback, SMTPEmailLogger): - return callback - smtp_email_logger = SMTPEmailLogger() - _in_memory_loggers.append(smtp_email_logger) - return smtp_email_logger # type: ignore - elif logging_integration == "humanloop": - for callback in _in_memory_loggers: - if isinstance(callback, HumanloopLogger): - return callback - - humanloop_logger = HumanloopLogger() - _in_memory_loggers.append(humanloop_logger) - return humanloop_logger # type: ignore - elif logging_integration == "dotprompt": - for callback in _in_memory_loggers: - if isinstance(callback, DotpromptManager): - return callback - - dotprompt_logger = DotpromptManager() - _in_memory_loggers.append(dotprompt_logger) - return dotprompt_logger # type: ignore - elif logging_integration == "bitbucket": - from litellm.integrations.bitbucket.bitbucket_prompt_manager import ( - BitBucketPromptManager, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, BitBucketPromptManager): - return callback - - # Get global BitBucket config - bitbucket_config = getattr(litellm, "global_bitbucket_config", None) - if bitbucket_config is None: - raise ValueError( - "BitBucket configuration not found. Please set litellm.global_bitbucket_config first." - ) - - bitbucket_logger = BitBucketPromptManager(bitbucket_config=bitbucket_config) - _in_memory_loggers.append(bitbucket_logger) - return bitbucket_logger # type: ignore - elif logging_integration == "gitlab": - from litellm.integrations.gitlab.gitlab_prompt_manager import ( - GitLabPromptManager, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, GitLabPromptManager): - return callback - - # Get global BitBucket config - gitlab_config = getattr(litellm, "global_gitlab_config", None) - if gitlab_config is None: - raise ValueError( - "Gitlab configuration not found. Please set litellm.global_gitlab_config first." - ) - - gitlab_logger = GitLabPromptManager(gitlab_config=gitlab_config) - _in_memory_loggers.append(gitlab_logger) - return gitlab_logger # type: ignore - return None - except Exception as e: - verbose_logger.exception( - f"[Non-Blocking Error] Error initializing custom logger: {e}" - ) - return None - return None - - -def get_custom_logger_compatible_class( # noqa: PLR0915 - logging_integration: _custom_logger_compatible_callbacks_literal, -) -> Optional[CustomLogger]: - try: - if logging_integration == "lago": - for callback in _in_memory_loggers: - if isinstance(callback, LagoLogger): - return callback - elif logging_integration == "openmeter": - for callback in _in_memory_loggers: - if isinstance(callback, OpenMeterLogger): - return callback - elif logging_integration == "braintrust": - from litellm.integrations.braintrust_logging import BraintrustLogger - - for callback in _in_memory_loggers: - if isinstance(callback, BraintrustLogger): - return callback - elif logging_integration == "galileo": - for callback in _in_memory_loggers: - if isinstance(callback, GalileoObserve): - return callback - elif logging_integration == "cloudzero": - from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger - - for callback in _in_memory_loggers: - if isinstance(callback, CloudZeroLogger): - return callback - elif logging_integration == "focus": - from litellm.integrations.focus.focus_logger import FocusLogger - - for callback in _in_memory_loggers: - if isinstance(callback, FocusLogger): - return callback - elif logging_integration == "deepeval": - for callback in _in_memory_loggers: - if isinstance(callback, DeepEvalLogger): - return callback - elif logging_integration == "langsmith": - for callback in _in_memory_loggers: - if isinstance(callback, LangsmithLogger): - return callback - elif logging_integration == "argilla": - for callback in _in_memory_loggers: - if isinstance(callback, ArgillaLogger): - return callback - elif logging_integration == "literalai": - for callback in _in_memory_loggers: - if isinstance(callback, LiteralAILogger): - return callback - elif logging_integration == "prometheus": - PrometheusLogger = _get_cached_prometheus_logger() - for callback in _in_memory_loggers: - if isinstance(callback, PrometheusLogger): - return callback - elif logging_integration == "datadog": - for callback in _in_memory_loggers: - if isinstance(callback, DataDogLogger): - return callback - elif logging_integration == "datadog_llm_observability": - for callback in _in_memory_loggers: - if isinstance(callback, DataDogLLMObsLogger): - return callback - elif logging_integration == "azure_sentinel": - for callback in _in_memory_loggers: - if isinstance(callback, AzureSentinelLogger): - return callback - elif logging_integration == "gcs_bucket": - for callback in _in_memory_loggers: - if isinstance(callback, GCSBucketLogger): - return callback - elif logging_integration == "s3_v2": - for callback in _in_memory_loggers: - if isinstance(callback, S3V2Logger): - return callback - elif logging_integration == "aws_sqs": - for callback in _in_memory_loggers: - if isinstance(callback, SQSLogger): - return callback - _aws_sqs_logger = SQSLogger() - _in_memory_loggers.append(_aws_sqs_logger) - return _aws_sqs_logger # type: ignore - elif logging_integration == "azure_storage": - for callback in _in_memory_loggers: - if isinstance(callback, AzureBlobStorageLogger): - return callback - elif logging_integration == "opik": - for callback in _in_memory_loggers: - if isinstance(callback, OpikLogger): - return callback - elif logging_integration == "langfuse": - for callback in _in_memory_loggers: - if isinstance(callback, LangfusePromptManagement): - return callback - elif logging_integration == "otel": - from litellm.integrations.opentelemetry import OpenTelemetry - - for callback in _in_memory_loggers: - if isinstance(callback, OpenTelemetry): - return callback - elif logging_integration == "arize": - if "ARIZE_API_KEY" not in os.environ: - raise ValueError("ARIZE_API_KEY not found in environment variables") - for callback in _in_memory_loggers: - if ( - isinstance(callback, ArizeLogger) - and callback.callback_name == "arize" - ): - return callback - elif logging_integration == "logfire": - if "LOGFIRE_TOKEN" not in os.environ: - raise ValueError("LOGFIRE_TOKEN not found in environment variables") - from litellm.integrations.opentelemetry import OpenTelemetry - - for callback in _in_memory_loggers: - if isinstance(callback, OpenTelemetry): - return callback # type: ignore - - elif logging_integration == "dynamic_rate_limiter": - from litellm.proxy.hooks.dynamic_rate_limiter import ( - _PROXY_DynamicRateLimitHandler, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, _PROXY_DynamicRateLimitHandler): - return callback # type: ignore - elif logging_integration == "dynamic_rate_limiter_v3": - from litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( - _PROXY_DynamicRateLimitHandlerV3, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): - return callback # type: ignore - - elif logging_integration == "langtrace": - from litellm.integrations.opentelemetry import OpenTelemetry - - if "LANGTRACE_API_KEY" not in os.environ: - raise ValueError("LANGTRACE_API_KEY not found in environment variables") - - for callback in _in_memory_loggers: - if ( - isinstance(callback, OpenTelemetry) - and callback.callback_name == "langtrace" - ): - return callback - - elif logging_integration == "mlflow": - for callback in _in_memory_loggers: - if isinstance(callback, MlflowLogger): - return callback - elif logging_integration == "pagerduty": - for callback in _in_memory_loggers: - if isinstance(callback, PagerDutyAlerting): - return callback - elif logging_integration == "anthropic_cache_control_hook": - for callback in _in_memory_loggers: - if isinstance(callback, AnthropicCacheControlHook): - return callback - elif logging_integration == "vector_store_pre_call_hook": - from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( - VectorStorePreCallHook, - ) - - for callback in _in_memory_loggers: - if isinstance(callback, VectorStorePreCallHook): - return callback - elif logging_integration == "gcs_pubsub": - for callback in _in_memory_loggers: - if isinstance(callback, GcsPubSubLogger): - return callback - elif logging_integration == "generic_api": - for callback in _in_memory_loggers: - if isinstance(callback, GenericAPILogger): - return callback - elif logging_integration == "resend_email": - for callback in _in_memory_loggers: - if isinstance(callback, ResendEmailLogger): - return callback - elif logging_integration == "sendgrid_email": - for callback in _in_memory_loggers: - if isinstance(callback, SendGridEmailLogger): - return callback - elif logging_integration == "smtp_email": - for callback in _in_memory_loggers: - if isinstance(callback, SMTPEmailLogger): - return callback - return None - - except Exception as e: - verbose_logger.exception( - f"[Non-Blocking Error] Error getting custom logger: {e}" - ) - return None - - -def _get_custom_logger_settings_from_proxy_server(callback_name: str) -> Dict: - """ - Get the settings for a custom logger from the proxy server config.yaml - - Proxy server config.yaml defines callback_settings as: - - callback_settings: - otel: - message_logging: False - """ - if litellm.callback_settings: - return dict(litellm.callback_settings.get(callback_name, {})) - return {} - - -def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool: - """ - Check if the model uses custom pricing - - Returns True if any of `SPECIAL_MODEL_INFO_PARAMS` are present in `litellm_params` or `model_info` - """ - if litellm_params is None: - return False - - # Check litellm_params using set intersection (only check keys that exist in both) - matching_keys = _CUSTOM_PRICING_KEYS & litellm_params.keys() - for key in matching_keys: - if litellm_params.get(key) is not None: - return True - - # Check model_info - metadata: dict = litellm_params.get("metadata", {}) or {} - model_info: dict = metadata.get("model_info", {}) or {} - - if model_info: - matching_keys = _CUSTOM_PRICING_KEYS & model_info.keys() - for key in matching_keys: - if model_info.get(key) is not None: - return True - - return False - - -def is_valid_sha256_hash(value: str) -> bool: - # Check if the value is a valid SHA-256 hash (64 hexadecimal characters) - return bool(re.fullmatch(r"[a-fA-F0-9]{64}", value)) - - -class StandardLoggingPayloadSetup: - @staticmethod - def cleanup_timestamps( - start_time: Union[dt_object, float], - end_time: Union[dt_object, float], - completion_start_time: Union[dt_object, float], - ) -> Tuple[float, float, float]: - """ - Convert datetime objects to floats - - Args: - start_time: Union[dt_object, float] - end_time: Union[dt_object, float] - completion_start_time: Union[dt_object, float] - - Returns: - Tuple[float, float, float]: A tuple containing the start time, end time, and completion start time as floats. - """ - - if isinstance(start_time, datetime.datetime): - start_time_float = start_time.timestamp() - elif isinstance(start_time, float): - start_time_float = start_time - else: - raise ValueError( - f"start_time is required, got={start_time} of type {type(start_time)}" - ) - - if isinstance(end_time, datetime.datetime): - end_time_float = end_time.timestamp() - elif isinstance(end_time, float): - end_time_float = end_time - else: - raise ValueError( - f"end_time is required, got={end_time} of type {type(end_time)}" - ) - - if isinstance(completion_start_time, datetime.datetime): - completion_start_time_float = completion_start_time.timestamp() - elif isinstance(completion_start_time, float): - completion_start_time_float = completion_start_time - else: - completion_start_time_float = end_time_float - - return start_time_float, end_time_float, completion_start_time_float - - @staticmethod - def append_system_prompt_messages( - kwargs: Optional[Dict] = None, messages: Optional[Any] = None - ): - """ - Append system prompt messages to the messages - """ - if kwargs is not None: - if kwargs.get("system") is not None and isinstance( - kwargs.get("system"), str - ): - if messages is None: - return [{"role": "system", "content": kwargs.get("system")}] - elif isinstance(messages, list): - if len(messages) == 0: - return [{"role": "system", "content": kwargs.get("system")}] - # check for duplicates - if messages[0].get("role") == "system" and messages[0].get( - "content" - ) == kwargs.get("system"): - return messages - messages = [ - {"role": "system", "content": kwargs.get("system")} - ] + messages - elif isinstance(messages, str): - messages = [ - {"role": "system", "content": kwargs.get("system")}, - {"role": "user", "content": messages}, - ] - return messages - - return messages - - @staticmethod - def merge_litellm_metadata(litellm_params: dict) -> dict: - """ - Merge both litellm_metadata and metadata from litellm_params. - - litellm_metadata contains model-related fields, metadata contains user API key fields. - We need both for complete standard logging payload. - - Args: - litellm_params: Dictionary containing metadata and litellm_metadata - - Returns: - dict: Merged metadata with user API key fields taking precedence - """ - merged_metadata: dict = {} - - # Start with metadata (user API key fields) - but skip non-serializable objects - if litellm_params.get("metadata") and isinstance( - litellm_params.get("metadata"), dict - ): - for key, value in litellm_params["metadata"].items(): - # Skip non-serializable objects like UserAPIKeyAuth - if key == "user_api_key_auth": - continue - merged_metadata[key] = value - - # Then merge litellm_metadata (model-related fields) - this will NOT overwrite existing keys - if litellm_params.get("litellm_metadata") and isinstance( - litellm_params.get("litellm_metadata"), dict - ): - for key, value in litellm_params["litellm_metadata"].items(): - if ( - key not in merged_metadata - ): # Don't overwrite existing keys from metadata - merged_metadata[key] = value - - return merged_metadata - - @staticmethod - def get_standard_logging_metadata( - metadata: Optional[Dict[str, Any]], - litellm_params: Optional[dict] = None, - prompt_integration: Optional[str] = None, - applied_guardrails: Optional[List[str]] = None, - mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None, - vector_store_request_metadata: Optional[ - List[StandardLoggingVectorStoreRequest] - ] = None, - usage_object: Optional[dict] = None, - proxy_server_request: Optional[dict] = None, - start_time: Optional[dt_object] = None, - response_id: Optional[str] = None, - ) -> StandardLoggingMetadata: - """ - Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. - - Args: - metadata (Optional[Dict[str, Any]]): The original metadata dictionary. - - Returns: - StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata. - - Note: - - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned. - - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'. - """ - - prompt_management_metadata: Optional[ - StandardLoggingPromptManagementMetadata - ] = None - if litellm_params is not None: - prompt_id = cast(Optional[str], litellm_params.get("prompt_id", None)) - prompt_variables = cast( - Optional[dict], litellm_params.get("prompt_variables", None) - ) - - if prompt_id is not None and prompt_integration is not None: - prompt_management_metadata = StandardLoggingPromptManagementMetadata( - prompt_id=prompt_id, - prompt_variables=prompt_variables, - prompt_integration=prompt_integration, - ) - - # Initialize with default values - clean_metadata = StandardLoggingMetadata( - user_api_key_hash=None, - user_api_key_alias=None, - user_api_key_spend=None, - user_api_key_max_budget=None, - user_api_key_budget_reset_at=None, - user_api_key_team_id=None, - user_api_key_org_id=None, - user_api_key_user_id=None, - user_api_key_team_alias=None, - user_api_key_user_email=None, - user_api_key_end_user_id=None, - user_api_key_request_route=None, - spend_logs_metadata=None, - requester_ip_address=None, - user_agent=None, - requester_metadata=None, - prompt_management_metadata=prompt_management_metadata, - applied_guardrails=applied_guardrails, - mcp_tool_call_metadata=mcp_tool_call_metadata, - vector_store_request_metadata=vector_store_request_metadata, - usage_object=usage_object, - requester_custom_headers=None, - cold_storage_object_key=None, - user_api_key_auth_metadata=None, - team_alias=None, - team_id=None, - ) - if isinstance(metadata, dict): - for key in metadata.keys() & _STANDARD_LOGGING_METADATA_KEYS: - clean_metadata[key] = metadata[key] # type: ignore - - user_api_key = metadata.get("user_api_key") - if ( - user_api_key - and isinstance(user_api_key, str) - and is_valid_sha256_hash(user_api_key) - ): - clean_metadata["user_api_key_hash"] = user_api_key - _potential_requester_metadata = metadata.get( - "metadata", None - ) # check if user passed metadata in the sdk request - e.g. metadata for langsmith logging - https://docs.litellm.ai/docs/observability/langsmith_integration#set-langsmith-fields - if ( - clean_metadata["requester_metadata"] is None - and _potential_requester_metadata is not None - and isinstance(_potential_requester_metadata, dict) - ): - clean_metadata["requester_metadata"] = _potential_requester_metadata - - if ( - EnterpriseStandardLoggingPayloadSetupVAR - and proxy_server_request is not None - ): - clean_metadata = EnterpriseStandardLoggingPayloadSetupVAR.apply_enterprise_specific_metadata( - standard_logging_metadata=clean_metadata, - proxy_server_request=proxy_server_request, - ) - - # Generate cold storage object key if cold storage is configured - if start_time is not None and response_id is not None: - cold_storage_object_key = ( - StandardLoggingPayloadSetup._generate_cold_storage_object_key( - start_time=start_time, - response_id=response_id, - team_alias=clean_metadata.get("user_api_key_team_alias"), - ) - ) - if cold_storage_object_key: - clean_metadata["cold_storage_object_key"] = cold_storage_object_key - - return clean_metadata - - @staticmethod - def get_usage_from_response_obj( - response_obj: Optional[dict], combined_usage_object: Optional[Usage] = None - ) -> Usage: - ## BASE CASE ## - if combined_usage_object is not None: - return combined_usage_object - if response_obj is None: - return Usage( - prompt_tokens=0, - completion_tokens=0, - total_tokens=0, - ) - - usage = response_obj.get("usage", None) or {} - if usage is None or ( - not isinstance(usage, dict) and not isinstance(usage, Usage) - ): - return Usage( - prompt_tokens=0, - completion_tokens=0, - total_tokens=0, - ) - elif isinstance(usage, Usage): - return usage - elif isinstance(usage, ResponseAPIUsage): - return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( - usage - ) - elif isinstance(usage, dict): - if ResponseAPILoggingUtils._is_response_api_usage(usage): - return ( - ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( - usage - ) - ) - return Usage(**usage) - - raise ValueError(f"usage is required, got={usage} of type {type(usage)}") - - @staticmethod - def get_model_cost_information( - base_model: Optional[str], - custom_pricing: Optional[bool], - custom_llm_provider: Optional[str], - init_response_obj: Union[Any, BaseModel, dict], - ) -> StandardLoggingModelInformation: - model_cost_name = _select_model_name_for_cost_calc( - model=None, - completion_response=init_response_obj, # type: ignore - base_model=base_model, - custom_pricing=custom_pricing, - ) - if model_cost_name is None: - model_cost_information = StandardLoggingModelInformation( - model_map_key="", model_map_value=None - ) - else: - try: - _model_cost_information = litellm.get_model_info( - model=model_cost_name, custom_llm_provider=custom_llm_provider - ) - model_cost_information = StandardLoggingModelInformation( - model_map_key=model_cost_name, - model_map_value=_model_cost_information, - ) - except Exception: - verbose_logger.debug( # keep in debug otherwise it will trigger on every call - "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format( - model_cost_name - ) - ) - model_cost_information = StandardLoggingModelInformation( - model_map_key=model_cost_name, model_map_value=None - ) - return model_cost_information - - @staticmethod - def get_final_response_obj( - response_obj: dict, init_response_obj: Union[Any, BaseModel, dict], kwargs: dict - ) -> Optional[Union[dict, str, list]]: - """ - Get final response object after redacting the message input/output from logging - """ - if response_obj: - final_response_obj: Optional[Union[dict, str, list]] = response_obj - elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str): - final_response_obj = init_response_obj - else: - final_response_obj = {} - - modified_final_response_obj = redact_message_input_output_from_logging( - model_call_details=kwargs, - result=final_response_obj, - ) - - if modified_final_response_obj is not None and isinstance( - modified_final_response_obj, BaseModel - ): - final_response_obj = modified_final_response_obj.model_dump() - else: - final_response_obj = modified_final_response_obj - - return final_response_obj - - @staticmethod - def get_additional_headers( - additiona_headers: Optional[dict], - ) -> Optional[StandardLoggingAdditionalHeaders]: - if additiona_headers is None: - return None - - additional_logging_headers: StandardLoggingAdditionalHeaders = {} - - for key in StandardLoggingAdditionalHeaders.__annotations__.keys(): - _key = key.lower() - _key = _key.replace("_", "-") - if _key in additiona_headers: - try: - additional_logging_headers[key] = int(additiona_headers[_key]) # type: ignore - except (ValueError, TypeError): - verbose_logger.debug( - f"Could not convert {additiona_headers[_key]} to int for key {key}." - ) - return additional_logging_headers - - @staticmethod - def get_hidden_params( - hidden_params: Optional[dict], - ) -> StandardLoggingHiddenParams: - clean_hidden_params = StandardLoggingHiddenParams( - model_id=None, - cache_key=None, - api_base=None, - response_cost=None, - additional_headers=None, - litellm_overhead_time_ms=None, - batch_models=None, - litellm_model_name=None, - usage_object=None, - ) - if hidden_params is not None: - for key in StandardLoggingHiddenParams.__annotations__.keys(): - if key in hidden_params: - if key == "additional_headers": - clean_hidden_params[ - "additional_headers" - ] = StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] - ) - else: - clean_hidden_params[key] = hidden_params[key] # type: ignore - return clean_hidden_params - - @staticmethod - def strip_trailing_slash(api_base: Optional[str]) -> Optional[str]: - if api_base: - if api_base.endswith("//"): - return api_base.rstrip("/") - if api_base[-1] == "/": - return api_base[:-1] - return api_base - - @staticmethod - def _generate_cold_storage_object_key( - start_time: dt_object, - response_id: str, - team_alias: Optional[str] = None, - ) -> Optional[str]: - """ - Generate cold storage object key in the same format as S3Logger. - - Args: - start_time: The start time of the request - response_id: The response ID - team_alias: Optional team alias for team-based prefixing - - Returns: - Optional[str]: The generated object key or None if cold storage not configured - """ - # Generate object key in same format as S3Logger - from litellm.integrations.s3 import get_s3_object_key - - # Only generate object key if cold storage is configured - cold_storage_custom_logger = litellm.cold_storage_custom_logger - if cold_storage_custom_logger is None: - return None - - try: - # Generate file name in same format as litellm.utils.get_logging_id - s3_file_name = f"time-{start_time.strftime('%H-%M-%S-%f')}_{response_id}" - - # Get the actual s3_path from the configured cold storage logger instance - s3_path = "" # default value - - # Try to get the actual logger instance from the logger name - try: - custom_logger = litellm.logging_callback_manager.get_active_custom_logger_for_callback_name( - cold_storage_custom_logger - ) - if ( - custom_logger - and hasattr(custom_logger, "s3_path") - and getattr(custom_logger, "s3_path") - ): - s3_path = getattr(custom_logger, "s3_path") - except Exception: - # If any error occurs in getting the logger instance, use default empty s3_path - pass - - s3_object_key = get_s3_object_key( - s3_path=s3_path, # Use actual s3_path from logger configuration - prefix="", # Don't split by team alias for cold storage - start_time=start_time, - s3_file_name=s3_file_name, - ) - - return s3_object_key - except Exception: - # If any error occurs in generating the key, return None - return None - - @staticmethod - def get_error_information( - original_exception: Optional[Exception], - traceback_str: Optional[str] = None, - ) -> StandardLoggingPayloadErrorInformation: - from litellm.constants import MAXIMUM_TRACEBACK_LINES_TO_LOG - - # Check for 'code' first (used by ProxyException), then fall back to 'status_code' (used by LiteLLM exceptions) - # Ensure error_code is always a string for Prisma Python JSON field compatibility - error_code_attr = getattr(original_exception, "code", None) - if error_code_attr is not None and str(error_code_attr) not in ("", "None"): - error_status: str = str(error_code_attr) - else: - status_code_attr = getattr(original_exception, "status_code", None) - error_status = str(status_code_attr) if status_code_attr is not None else "" - error_class: str = ( - str(original_exception.__class__.__name__) if original_exception else "" - ) - _llm_provider_in_exception = getattr(original_exception, "llm_provider", "") - - # Get traceback information (first 100 lines) - traceback_info = traceback_str or "" - if original_exception: - tb = getattr(original_exception, "__traceback__", None) - if tb: - tb_lines = traceback.format_tb(tb) - traceback_info += "".join( - tb_lines[:MAXIMUM_TRACEBACK_LINES_TO_LOG] - ) # Limit to first 100 lines - - # Get additional error details - error_message = str(original_exception) - - return StandardLoggingPayloadErrorInformation( - error_code=error_status, - error_class=error_class, - llm_provider=_llm_provider_in_exception, - traceback=traceback_info, - error_message=error_message if original_exception else "", - ) - - @staticmethod - def get_response_time( - start_time_float: float, - end_time_float: float, - completion_start_time_float: float, - stream: bool, - ) -> float: - """ - Get the response time for the LLM response - - Args: - start_time_float: float - start time of the LLM call - end_time_float: float - end time of the LLM call - completion_start_time_float: float - time to first token of the LLM response (for streaming responses) - stream: bool - True when a stream response is returned - - Returns: - float: The response time for the LLM response - """ - if stream is True: - return completion_start_time_float - start_time_float - else: - return end_time_float - start_time_float - - @staticmethod - def _get_standard_logging_payload_trace_id( - logging_obj: Logging, - litellm_params: dict, - ) -> str: - """ - Returns the `litellm_trace_id` for this request - - This helps link sessions when multiple requests are made in a single session - """ - dynamic_litellm_session_id = litellm_params.get("litellm_session_id") - dynamic_litellm_trace_id = litellm_params.get("litellm_trace_id") - - # Note: we recommend using `litellm_session_id` for session tracking - # `litellm_trace_id` is an internal litellm param - if dynamic_litellm_session_id: - return str(dynamic_litellm_session_id) - elif dynamic_litellm_trace_id: - return str(dynamic_litellm_trace_id) - else: - return logging_obj.litellm_trace_id - - @staticmethod - def _get_user_agent_tags(proxy_server_request: dict) -> Optional[List[str]]: - """ - Return the user agent tags from the proxy server request for spend tracking - """ - if litellm.disable_add_user_agent_to_request_tags is True: - return None - user_agent_tags: Optional[List[str]] = None - headers = proxy_server_request.get("headers", {}) - if headers is not None and isinstance(headers, dict): - if "user-agent" in headers: - user_agent = headers["user-agent"] - if user_agent is not None: - if user_agent_tags is None: - user_agent_tags = [] - user_agent_part: Optional[str] = None - if "/" in user_agent: - user_agent_part = user_agent.split("/")[0] - if user_agent_part is not None: - user_agent_tags.append("User-Agent: " + user_agent_part) - if user_agent is not None: - user_agent_tags.append("User-Agent: " + user_agent) - return user_agent_tags - - @staticmethod - def _get_extra_header_tags(proxy_server_request: dict) -> Optional[List[str]]: - """ - Extract additional header tags for spend tracking based on config. - """ - extra_headers: List[str] = ( - getattr(litellm, "extra_spend_tag_headers", None) or [] - ) - if not extra_headers: - return None - - headers = proxy_server_request.get("headers", {}) - if not isinstance(headers, dict): - return None - - header_tags = [] - for header_name in extra_headers: - header_value = headers.get(header_name) - if header_value: - header_tags.append(f"{header_name}: {header_value}") - - return header_tags if header_tags else None - - @staticmethod - def _get_request_tags( - litellm_params: dict, proxy_server_request: dict - ) -> List[str]: - # check for 'tags' in both 'metadata' and 'litellm_metadata' - metadata = litellm_params.get("metadata") or {} - litellm_metadata = litellm_params.get("litellm_metadata") or {} - if metadata.get("tags", []): - request_tags = metadata.get("tags", []).copy() - elif litellm_metadata.get("tags", []): - request_tags = litellm_metadata.get("tags", []).copy() - else: - request_tags = [] - user_agent_tags = StandardLoggingPayloadSetup._get_user_agent_tags( - proxy_server_request - ) - additional_header_tags = StandardLoggingPayloadSetup._get_extra_header_tags( - proxy_server_request - ) - if user_agent_tags is not None: - request_tags.extend(user_agent_tags) - if additional_header_tags is not None: - request_tags.extend(additional_header_tags) - return request_tags - - -def _get_status_fields( - status: StandardLoggingPayloadStatus, - guardrail_information: Optional[List[dict]], - error_str: Optional[str], -) -> "StandardLoggingPayloadStatusFields": - """ - Determine status fields based on request status and guardrail information. - - Args: - status: Overall request status ("success" or "failure") - guardrail_information: Guardrail information from metadata - error_str: Error string if any - - Returns: - StandardLoggingPayloadStatusFields with llm_api_status and guardrail_status - """ - # Mapping for legacy guardrail status values to new GuardrailStatus values - GUARDRAIL_STATUS_MAP: Dict[str, GuardrailStatus] = { - "success": "success", - "blocked": "guardrail_intervened", # legacy - "guardrail_intervened": "guardrail_intervened", # direct - "failure": "guardrail_failed_to_respond", # legacy - "guardrail_failed_to_respond": "guardrail_failed_to_respond", # direct - "not_run": "not_run", - } - - # Set LLM API status - llm_api_status: StandardLoggingPayloadStatus = status - - ######################################################### - # Map - guardrail_information.guardrail_status to guardrail_status - ######################################################### - guardrail_status: GuardrailStatus = "not_run" - if guardrail_information and isinstance(guardrail_information, list): - for information in guardrail_information: - if isinstance(information, dict): - raw_status = information.get("guardrail_status", "not_run") - if raw_status != "not_run": - guardrail_status = GUARDRAIL_STATUS_MAP.get(raw_status, "not_run") - break - - return StandardLoggingPayloadStatusFields( - llm_api_status=llm_api_status, guardrail_status=guardrail_status - ) - - -def _extract_response_obj_and_hidden_params( - init_response_obj: Union[Any, BaseModel, dict], - original_exception: Optional[Exception], -) -> Tuple[dict, Optional[dict]]: - """Extract response_obj and hidden_params from init_response_obj.""" - hidden_params: Optional[dict] = None - if init_response_obj is None: - response_obj = {} - elif isinstance(init_response_obj, BaseModel): - response_obj = init_response_obj.model_dump() - hidden_params = getattr(init_response_obj, "_hidden_params", None) - elif isinstance(init_response_obj, dict): - response_obj = init_response_obj - else: - response_obj = {} - - if original_exception is not None and hidden_params is None: - response_headers = _get_response_headers(original_exception) - if response_headers is not None: - hidden_params = dict( - StandardLoggingHiddenParams( - additional_headers=StandardLoggingPayloadSetup.get_additional_headers( - dict(response_headers) - ), - model_id=None, - cache_key=None, - api_base=None, - response_cost=None, - litellm_overhead_time_ms=None, - batch_models=None, - litellm_model_name=None, - usage_object=None, - ) - ) - - return response_obj, hidden_params - - -def get_standard_logging_object_payload( - kwargs: Optional[dict], - init_response_obj: Union[Any, BaseModel, dict], - start_time: dt_object, - end_time: dt_object, - logging_obj: Logging, - status: StandardLoggingPayloadStatus, - error_str: Optional[str] = None, - original_exception: Optional[Exception] = None, - standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None, -) -> Optional[StandardLoggingPayload]: - try: - kwargs = kwargs or {} - - response_obj, hidden_params = _extract_response_obj_and_hidden_params( - init_response_obj, original_exception - ) - - # standardize this function to be used across, s3, dynamoDB, langfuse logging - litellm_params = kwargs.get("litellm_params", {}) or {} - proxy_server_request = litellm_params.get("proxy_server_request") or {} - - # Merge both litellm_metadata and metadata to get complete metadata - metadata: dict = StandardLoggingPayloadSetup.merge_litellm_metadata( - litellm_params - ) - - completion_start_time = kwargs.get("completion_start_time", end_time) - call_type = kwargs.get("call_type") - cache_hit = kwargs.get("cache_hit", False) - usage = StandardLoggingPayloadSetup.get_usage_from_response_obj( - response_obj=response_obj, - combined_usage_object=cast( - Optional[Usage], kwargs.get("combined_usage_object") - ), - ) - - id = response_obj.get("id", kwargs.get("litellm_call_id")) - - _model_id = metadata.get("model_info", {}).get("id", "") - _model_group = metadata.get("model_group", "") - - request_tags = StandardLoggingPayloadSetup._get_request_tags( - litellm_params=litellm_params, proxy_server_request=proxy_server_request - ) - - # cleanup timestamps - ( - start_time_float, - end_time_float, - completion_start_time_float, - ) = StandardLoggingPayloadSetup.cleanup_timestamps( - start_time=start_time, - end_time=end_time, - completion_start_time=completion_start_time, - ) - response_time = StandardLoggingPayloadSetup.get_response_time( - start_time_float=start_time_float, - end_time_float=end_time_float, - completion_start_time_float=completion_start_time_float, - stream=kwargs.get("stream", False), - ) - # clean up litellm hidden params - clean_hidden_params = StandardLoggingPayloadSetup.get_hidden_params( - hidden_params - ) - - # clean up litellm metadata - clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata( - metadata=metadata, - litellm_params=litellm_params, - prompt_integration=kwargs.get("prompt_integration", None), - applied_guardrails=kwargs.get("applied_guardrails", None), - mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None), - vector_store_request_metadata=kwargs.get( - "vector_store_request_metadata", None - ), - usage_object=usage.model_dump(), - proxy_server_request=proxy_server_request, - start_time=start_time, - response_id=id, - ) - _request_body = proxy_server_request.get("body", {}) - end_user_id = clean_metadata["user_api_key_end_user_id"] or _request_body.get( - "user", None - ) # maintain backwards compatibility with old request body check - - saved_cache_cost: float = 0.0 - if cache_hit is True: - id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id - saved_cache_cost = ( - logging_obj._response_cost_calculator( - result=init_response_obj, cache_hit=False # type: ignore - ) - or 0.0 - ) - - ## Get model cost information ## - base_model = _get_base_model_from_metadata(model_call_details=kwargs) - custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) - - model_cost_information = StandardLoggingPayloadSetup.get_model_cost_information( - base_model=base_model, - custom_pricing=custom_pricing, - custom_llm_provider=kwargs.get("custom_llm_provider"), - init_response_obj=init_response_obj, - ) - response_cost: float = kwargs.get("response_cost", 0) or 0.0 - - error_information = StandardLoggingPayloadSetup.get_error_information( - original_exception=original_exception, - ) - - ## get final response object ## - final_response_obj = StandardLoggingPayloadSetup.get_final_response_obj( - response_obj=response_obj, - init_response_obj=init_response_obj, - kwargs=kwargs, - ) - - stream: Optional[bool] = None - if ( - kwargs.get("complete_streaming_response") is not None - or kwargs.get("async_complete_streaming_response") is not None - ) and kwargs.get("stream") is True: - stream = True - - # Reconstruct full model name with provider prefix for logging - # This ensures Bedrock models like "us.anthropic.claude-3-5-sonnet-20240620-v1:0" - # are logged as "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0" - custom_llm_provider = cast(Optional[str], kwargs.get("custom_llm_provider")) - model_name = reconstruct_model_name( - kwargs.get("model", "") or "", custom_llm_provider, metadata - ) - - payload: StandardLoggingPayload = StandardLoggingPayload( - id=str(id), - trace_id=StandardLoggingPayloadSetup._get_standard_logging_payload_trace_id( - logging_obj=logging_obj, - litellm_params=litellm_params, - ), - call_type=call_type or "", - cache_hit=cache_hit, - stream=stream, - status=status, - status_fields=_get_status_fields( - status=status, - guardrail_information=metadata.get( - "standard_logging_guardrail_information", None - ), - error_str=error_str, - ), - custom_llm_provider=custom_llm_provider, - saved_cache_cost=saved_cache_cost, - startTime=start_time_float, - endTime=end_time_float, - completionStartTime=completion_start_time_float, - response_time=response_time, - model=model_name, - metadata=clean_metadata, - cache_key=clean_hidden_params["cache_key"], - response_cost=response_cost, - cost_breakdown=logging_obj.cost_breakdown, - total_tokens=usage.total_tokens, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens, - request_tags=request_tags, - end_user=end_user_id or "", - api_base=StandardLoggingPayloadSetup.strip_trailing_slash( - litellm_params.get("api_base", "") - ) - or "", - model_group=_model_group, - model_id=_model_id, - requester_ip_address=clean_metadata.get("requester_ip_address", None), - user_agent=clean_metadata.get("user_agent", None), - messages=StandardLoggingPayloadSetup.append_system_prompt_messages( - kwargs=kwargs, messages=kwargs.get("messages") - ), - response=final_response_obj, - model_parameters=ModelParamHelper.get_standard_logging_model_parameters( - kwargs.get("optional_params", None) or {} - ), - hidden_params=clean_hidden_params, - model_map_information=model_cost_information, - error_str=error_str, - error_information=error_information, - response_cost_failure_debug_info=kwargs.get( - "response_cost_failure_debug_information" - ), - guardrail_information=metadata.get( - "standard_logging_guardrail_information", None - ), - standard_built_in_tools_params=standard_built_in_tools_params, - ) - - # emit_standard_logging_payload(payload) - Moved to success_handler to prevent double emitting - - return payload - except Exception as e: - verbose_logger.exception( - "Error creating standard logging object - {}".format(str(e)) - ) - return None - - -def emit_standard_logging_payload(payload: StandardLoggingPayload): - if os.getenv("LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"): - print(json.dumps(payload, indent=4)) # noqa - - -def get_standard_logging_metadata( - metadata: Optional[Dict[str, Any]], -) -> StandardLoggingMetadata: - """ - Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. - - Args: - metadata (Optional[Dict[str, Any]]): The original metadata dictionary. - - Returns: - StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata. - - Note: - - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned. - - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'. - """ - # Initialize with default values - clean_metadata = StandardLoggingMetadata( - user_api_key_hash=None, - user_api_key_alias=None, - user_api_key_spend=None, - user_api_key_max_budget=None, - user_api_key_budget_reset_at=None, - user_api_key_team_id=None, - user_api_key_org_id=None, - user_api_key_user_id=None, - user_api_key_user_email=None, - user_api_key_team_alias=None, - spend_logs_metadata=None, - requester_ip_address=None, - user_agent=None, - requester_metadata=None, - user_api_key_end_user_id=None, - prompt_management_metadata=None, - applied_guardrails=None, - mcp_tool_call_metadata=None, - vector_store_request_metadata=None, - usage_object=None, - requester_custom_headers=None, - user_api_key_request_route=None, - cold_storage_object_key=None, - user_api_key_auth_metadata=None, - team_alias=None, - team_id=None, - ) - if isinstance(metadata, dict): - # Update the clean_metadata with values from input metadata that match StandardLoggingMetadata fields - for key in StandardLoggingMetadata.__annotations__.keys(): - if key in metadata: - clean_metadata[key] = metadata[key] # type: ignore - - if metadata.get("user_api_key") is not None: - if is_valid_sha256_hash(str(metadata.get("user_api_key"))): - clean_metadata["user_api_key_hash"] = metadata.get( - "user_api_key" - ) # this is the hash - return clean_metadata - - -def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): - if litellm_params is None: - litellm_params = {} - - metadata = litellm_params.get("metadata", {}) or {} - - ## Extract provider-specific callable values (like langfuse_masking_function) - ## Store them separately so only the intended logger can access them - ## This prevents callables from leaking to other logging integrations - if "langfuse_masking_function" in metadata: - masking_fn = metadata.pop("langfuse_masking_function", None) - if callable(masking_fn): - litellm_params["_langfuse_masking_function"] = masking_fn - litellm_params["metadata"] = metadata - - ## check user_api_key_metadata for sensitive logging keys - cleaned_user_api_key_metadata = {} - if "user_api_key_metadata" in metadata and isinstance( - metadata["user_api_key_metadata"], dict - ): - for k, v in metadata["user_api_key_metadata"].items(): - if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[ - k - ] = "scrubbed_by_litellm_for_sensitive_keys" - else: - cleaned_user_api_key_metadata[k] = v - - metadata["user_api_key_metadata"] = cleaned_user_api_key_metadata - litellm_params["metadata"] = metadata - - return litellm_params - - -# integration helper function -def modify_integration(integration_name, integration_params): - global supabaseClient - if integration_name == "supabase": - if "table_name" in integration_params: - Supabase.supabase_table_name = integration_params["table_name"] - - -@lru_cache(maxsize=16) -def _get_traceback_str_for_error(error_str: str) -> str: - """ - function wrapped with lru_cache to limit the number of times `traceback.format_exc()` is called - """ - return traceback.format_exc() - - -from decimal import Decimal - -# used for unit testing -from typing import Any, Dict, List, Optional, Union - - -def create_dummy_standard_logging_payload() -> StandardLoggingPayload: - # First create the nested objects with proper typing - model_info = StandardLoggingModelInformation( - model_map_key="gpt-3.5-turbo", model_map_value=None - ) - - metadata = StandardLoggingMetadata( # type: ignore - user_api_key_hash=str("test_hash"), - user_api_key_alias=str("test_alias"), - user_api_key_team_id=str("test_team"), - user_api_key_user_id=str("test_user"), - user_api_key_team_alias=str("test_team_alias"), - user_api_key_org_id=None, - spend_logs_metadata=None, - requester_ip_address=str("127.0.0.1"), - requester_metadata=None, - user_api_key_end_user_id=str("test_end_user"), - ) - - hidden_params = StandardLoggingHiddenParams( - model_id=None, - cache_key=None, - api_base=None, - response_cost=None, - additional_headers=None, - litellm_overhead_time_ms=None, - batch_models=None, - litellm_model_name=None, - usage_object=None, - ) - - # Convert numeric values to appropriate types - response_cost = Decimal("0.1") - start_time = Decimal("1234567890.0") - end_time = Decimal("1234567891.0") - completion_start_time = Decimal("1234567890.5") - saved_cache_cost = Decimal("0.0") - - # Create messages and response with proper typing - messages: List[Dict[str, str]] = [{"role": "user", "content": "Hello, world!"}] - response: Dict[str, List[Dict[str, Dict[str, str]]]] = { - "choices": [{"message": {"content": "Hi there!"}}] - } - - # Main payload initialization - return StandardLoggingPayload( # type: ignore - id=str("test_id"), - call_type=str("completion"), - stream=bool(False), - response_cost=response_cost, - response_cost_failure_debug_info=None, - status=str("success"), - total_tokens=int( - DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT - + DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT - ), - prompt_tokens=int(DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT), - completion_tokens=int(DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT), - startTime=start_time, - endTime=end_time, - completionStartTime=completion_start_time, - model_map_information=model_info, - model=str("gpt-3.5-turbo"), - model_id=str("model-123"), - model_group=str("openai-gpt"), - custom_llm_provider=str("openai"), - api_base=str("https://api.openai.com"), - metadata=metadata, - cache_hit=bool(False), - cache_key=None, - saved_cache_cost=saved_cache_cost, - request_tags=[], - end_user=None, - requester_ip_address=str("127.0.0.1"), - messages=messages, - response=response, - error_str=None, - model_parameters={"stream": True}, - hidden_params=hidden_params, - ) +# What is this? +## Common Utility file for Logging handler +# Logging function -> log the exact model details + what's being sent | Non-Blocking +import copy +import datetime +import json +import os +import re +import subprocess +import sys +import time +import traceback +from datetime import datetime as dt_object +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + Union, + cast, +) + +from httpx import Response +from pydantic import BaseModel + +import litellm +from litellm import ( + _custom_logger_compatible_callbacks_literal, + json_logs, + log_raw_request_response, + turn_off_message_logging, +) +from litellm._logging import _is_debugging_on, verbose_logger +from litellm._uuid import uuid +from litellm.batches.batch_utils import _handle_completed_batch +from litellm.caching.caching import DualCache, InMemoryCache +from litellm.caching.caching_handler import LLMCachingHandler +from litellm.constants import ( + DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT, + DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT, + SENTRY_DENYLIST, + SENTRY_PII_DENYLIST, +) +from litellm.cost_calculator import ( + RealtimeAPITokenUsageProcessor, + _select_model_name_for_cost_calc, +) +from litellm.integrations.agentops import AgentOps +from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook +from litellm.integrations.arize.arize import ArizeLogger +from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.deepeval.deepeval import DeepEvalLogger +from litellm.integrations.mlflow import MlflowLogger +from litellm.integrations.sqs import SQSLogger +from litellm.litellm_core_utils.core_helpers import reconstruct_model_name +from litellm.litellm_core_utils.get_litellm_params import get_litellm_params +from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( + StandardBuiltInToolCostTracking, +) +from litellm.litellm_core_utils.model_param_helper import ModelParamHelper +from litellm.litellm_core_utils.redact_messages import ( + redact_message_input_output_from_custom_logger, + redact_message_input_output_from_logging, +) +from litellm.llms.base_llm.ocr.transformation import OCRResponse +from litellm.llms.base_llm.search.transformation import SearchResponse +from litellm.responses.utils import ResponseAPILoggingUtils +from litellm.types.agents import LiteLLMSendMessageResponse +from litellm.types.containers.main import ContainerObject +from litellm.types.llms.openai import ( + AllMessageValues, + Batch, + FineTuningJob, + HttpxBinaryResponseContent, + OpenAIFileObject, + OpenAIModerationResponse, + ResponseAPIUsage, + ResponseCompletedEvent, + ResponsesAPIResponse, +) +from litellm.types.mcp import MCPPostCallResponseObject +from litellm.types.prompts.init_prompts import PromptSpec +from litellm.types.rerank import RerankResponse +from litellm.types.utils import ( + CachingDetails, + CallTypes, + CostBreakdown, + CostResponseTypes, + CustomPricingLiteLLMParams, + DynamicPromptManagementParamLiteral, + EmbeddingResponse, + GuardrailStatus, + ImageResponse, + LiteLLMBatch, + LiteLLMLoggingBaseClass, + LiteLLMRealtimeStreamLoggingObject, + ModelResponse, + ModelResponseStream, + RawRequestTypedDict, + StandardBuiltInToolsParams, + StandardCallbackDynamicParams, + StandardLoggingAdditionalHeaders, + StandardLoggingHiddenParams, + StandardLoggingMCPToolCall, + StandardLoggingMetadata, + StandardLoggingModelCostFailureDebugInformation, + StandardLoggingModelInformation, + StandardLoggingPayload, + StandardLoggingPayloadErrorInformation, + StandardLoggingPayloadStatus, + StandardLoggingPayloadStatusFields, + StandardLoggingPromptManagementMetadata, + StandardLoggingVectorStoreRequest, + TextCompletionResponse, + TranscriptionResponse, + Usage, +) +from litellm.types.videos.main import VideoObject +from litellm.utils import _get_base_model_from_metadata, executor, print_verbose + +from ..integrations.argilla import ArgillaLogger +from ..integrations.arize.arize_phoenix import ArizePhoenixLogger +from ..integrations.athina import AthinaLogger +from ..integrations.azure_sentinel.azure_sentinel import AzureSentinelLogger +from ..integrations.azure_storage.azure_storage import AzureBlobStorageLogger +from ..integrations.custom_prompt_management import CustomPromptManagement +from ..integrations.datadog.datadog import DataDogLogger +from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger +from ..integrations.dotprompt import DotpromptManager +from ..integrations.dynamodb import DyanmoDBLogger +from ..integrations.galileo import GalileoObserve +from ..integrations.gcs_bucket.gcs_bucket import GCSBucketLogger +from ..integrations.gcs_pubsub.pub_sub import GcsPubSubLogger +from ..integrations.greenscale import GreenscaleLogger +from ..integrations.helicone import HeliconeLogger +from ..integrations.humanloop import HumanloopLogger +from ..integrations.lago import LagoLogger +from ..integrations.langfuse.langfuse import LangFuseLogger +from ..integrations.langfuse.langfuse_handler import LangFuseHandler +from ..integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement +from ..integrations.langsmith import LangsmithLogger +from ..integrations.literal_ai import LiteralAILogger +from ..integrations.logfire_logger import LogfireLevel, LogfireLogger +from ..integrations.lunary import LunaryLogger +from ..integrations.openmeter import OpenMeterLogger +from ..integrations.opik.opik import OpikLogger +from ..integrations.posthog import PostHogLogger +from ..integrations.prompt_layer import PromptLayerLogger +from ..integrations.s3 import S3Logger +from ..integrations.s3_v2 import S3Logger as S3V2Logger +from ..integrations.supabase import Supabase +from ..integrations.traceloop import TraceloopLogger +from .exception_mapping_utils import _get_response_headers +from .initialize_dynamic_callback_params import ( + initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params, +) +from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache + +if TYPE_CHECKING: + from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig +try: + from litellm_enterprise.enterprise_callbacks.callback_controls import ( + EnterpriseCallbackControls, + ) + from litellm_enterprise.enterprise_callbacks.pagerduty.pagerduty import ( + PagerDutyAlerting, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ( + ResendEmailLogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import ( + SendGridEmailLogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import ( + SMTPEmailLogger, + ) + from litellm_enterprise.litellm_core_utils.litellm_logging import ( + StandardLoggingPayloadSetup as EnterpriseStandardLoggingPayloadSetup, + ) + + from litellm.integrations.generic_api.generic_api_callback import GenericAPILogger + + EnterpriseStandardLoggingPayloadSetupVAR: Optional[ + Type[EnterpriseStandardLoggingPayloadSetup] + ] = EnterpriseStandardLoggingPayloadSetup +except Exception as e: + verbose_logger.debug( + f"[Non-Blocking] Unable to import GenericAPILogger - LiteLLM Enterprise Feature - {str(e)}" + ) + GenericAPILogger = CustomLogger # type: ignore + ResendEmailLogger = CustomLogger # type: ignore + SendGridEmailLogger = CustomLogger # type: ignore + SMTPEmailLogger = CustomLogger # type: ignore + PagerDutyAlerting = CustomLogger # type: ignore + EnterpriseCallbackControls = None # type: ignore + EnterpriseStandardLoggingPayloadSetupVAR = None +_in_memory_loggers: List[Any] = [] + +_STANDARD_LOGGING_METADATA_KEYS: frozenset = frozenset( + StandardLoggingMetadata.__annotations__.keys() +) + +### GLOBAL VARIABLES ### + +# Cache custom pricing keys as frozenset for O(1) lookups instead of looping through 49 keys +_CUSTOM_PRICING_KEYS: frozenset = frozenset( + CustomPricingLiteLLMParams.model_fields.keys() +) + +sentry_sdk_instance = None +capture_exception = None +add_breadcrumb = None +slack_app = None +alerts_channel = None +heliconeLogger = None +athinaLogger = None +promptLayerLogger = None +logfireLogger = None +weightsBiasesLogger = None +customLogger = None +langFuseLogger = None +openMeterLogger = None +lagoLogger = None +dataDogLogger = None +prometheusLogger = None +dynamoLogger = None +s3Logger = None +greenscaleLogger = None +lunaryLogger = None +supabaseClient = None +deepevalLogger = None +callback_list: Optional[List[str]] = [] +user_logger_fn = None +additional_details: Optional[Dict[str, str]] = {} +local_cache: Optional[Dict[str, str]] = {} +last_fetched_at = None +last_fetched_at_keys = None + + +#### +class ServiceTraceIDCache: + def __init__(self) -> None: + self.cache = InMemoryCache() + + def get_cache(self, litellm_call_id: str, service_name: str) -> Optional[str]: + key_name = "{}:{}".format(service_name, litellm_call_id) + response = self.cache.get_cache(key=key_name) + return response + + def set_cache(self, litellm_call_id: str, service_name: str, trace_id: str) -> None: + key_name = "{}:{}".format(service_name, litellm_call_id) + self.cache.set_cache(key=key_name, value=trace_id) + return None + + +in_memory_trace_id_cache = ServiceTraceIDCache() +in_memory_dynamic_logger_cache = DynamicLoggingCache() + +# Cached lazy import for PrometheusLogger +# Module-level cache to avoid repeated imports while preserving memory benefits +_PrometheusLogger = None + + +def _get_cached_prometheus_logger(): + """ + Get cached PrometheusLogger class. + Lazy imports on first call to avoid loading prometheus.py and utils.py at import time (60MB saved). + Subsequent calls use cached class for better performance. + """ + global _PrometheusLogger + if _PrometheusLogger is None: + from litellm.integrations.prometheus import PrometheusLogger + + _PrometheusLogger = PrometheusLogger + return _PrometheusLogger + + +class Logging(LiteLLMLoggingBaseClass): + global supabaseClient, promptLayerLogger, weightsBiasesLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger, logfireLogger, prometheusLogger, slack_app + custom_pricing: bool = False + stream_options = None + litellm_request_debug: bool = False + + def __init__( + self, + model: str, + messages, + stream, + call_type, + start_time, + litellm_call_id: str, + function_id: str, + litellm_trace_id: Optional[str] = None, + dynamic_input_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_async_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_async_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + applied_guardrails: Optional[List[str]] = None, + kwargs: Optional[Dict] = None, + log_raw_request_response: bool = False, + ): + _input: Optional[str] = messages # save original value of messages + if messages is not None: + if isinstance(messages, str): + messages = [ + {"role": "user", "content": messages} + ] # convert text completion input to the chat completion format + elif ( + isinstance(messages, list) + and len(messages) > 0 + and isinstance(messages[0], str) + ): + new_messages = [] + for m in messages: + new_messages.append({"role": "user", "content": m}) + messages = new_messages + + self.model = model + self.messages = copy.deepcopy(messages) if messages is not None else None + self.stream = stream + self.start_time = start_time # log the call start time + self.call_type = call_type + self.litellm_call_id = litellm_call_id + self.litellm_trace_id: str = ( + litellm_trace_id if litellm_trace_id else str(uuid.uuid4()) + ) + self.function_id = function_id + self.streaming_chunks: List[Any] = [] # for generating complete stream response + self.sync_streaming_chunks: List[ + Any + ] = [] # for generating complete stream response + self.log_raw_request_response = log_raw_request_response + + # Initialize dynamic callbacks + self.dynamic_input_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_input_callbacks + self.dynamic_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_success_callbacks + self.dynamic_async_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_async_success_callbacks + self.dynamic_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_failure_callbacks + self.dynamic_async_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_async_failure_callbacks + + # Process dynamic callbacks + self.process_dynamic_callbacks() + + ## DYNAMIC LANGFUSE / GCS / logging callback KEYS ## + self.standard_callback_dynamic_params: StandardCallbackDynamicParams = ( + self.initialize_standard_callback_dynamic_params(kwargs) + ) + self.standard_built_in_tools_params: StandardBuiltInToolsParams = ( + self.initialize_standard_built_in_tools_params(kwargs) + ) + ## TIME TO FIRST TOKEN LOGGING ## + self.completion_start_time: Optional[datetime.datetime] = None + self._llm_caching_handler: Optional[LLMCachingHandler] = None + + # INITIAL LITELLM_PARAMS + litellm_params = {} + if kwargs is not None: + litellm_params = get_litellm_params(**kwargs) + litellm_params = scrub_sensitive_keys_in_metadata(litellm_params) + + self.litellm_params = litellm_params + + # Initialize cost breakdown field + self.cost_breakdown: Optional[CostBreakdown] = None + + # Init Caching related details + self.caching_details: Optional[CachingDetails] = None + + # Passthrough endpoint guardrails config for field targeting + self.passthrough_guardrails_config: Optional[Dict[str, Any]] = None + + self.model_call_details: Dict[str, Any] = { + "litellm_trace_id": litellm_trace_id, + "litellm_call_id": litellm_call_id, + "input": _input, + "litellm_params": litellm_params, + "applied_guardrails": applied_guardrails, + "model": model, + } + + def process_dynamic_callbacks(self): + """ + Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks + + If a callback is in litellm._known_custom_logger_compatible_callbacks, it needs to be intialized and added to the respective dynamic_* callback list. + """ + # Process input callbacks + self.dynamic_input_callbacks = self._process_dynamic_callback_list( + self.dynamic_input_callbacks, dynamic_callbacks_type="input" + ) + + # Process failure callbacks + self.dynamic_failure_callbacks = self._process_dynamic_callback_list( + self.dynamic_failure_callbacks, dynamic_callbacks_type="failure" + ) + + # Process async failure callbacks + self.dynamic_async_failure_callbacks = self._process_dynamic_callback_list( + self.dynamic_async_failure_callbacks, dynamic_callbacks_type="async_failure" + ) + + # Process success callbacks + self.dynamic_success_callbacks = self._process_dynamic_callback_list( + self.dynamic_success_callbacks, dynamic_callbacks_type="success" + ) + + # Process async success callbacks + self.dynamic_async_success_callbacks = self._process_dynamic_callback_list( + self.dynamic_async_success_callbacks, dynamic_callbacks_type="async_success" + ) + + def _process_dynamic_callback_list( + self, + callback_list: Optional[List[Union[str, Callable, CustomLogger]]], + dynamic_callbacks_type: Literal[ + "input", "success", "failure", "async_success", "async_failure" + ], + ) -> Optional[List[Union[str, Callable, CustomLogger]]]: + """ + Helper function to initialize CustomLogger compatible callbacks in self.dynamic_* callbacks + + - If a callback is in litellm._known_custom_logger_compatible_callbacks, + replace the string with the initialized callback class. + - If dynamic callback is a "success" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_async_success_callbacks + - If dynamic callback is a "failure" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_failure_callbacks + """ + if callback_list is None: + return None + + processed_list: List[Union[str, Callable, CustomLogger]] = [] + for callback in callback_list: + if ( + isinstance(callback, str) + and callback in litellm._known_custom_logger_compatible_callbacks + ): + callback_class = _init_custom_logger_compatible_class( + callback, internal_usage_cache=None, llm_router=None # type: ignore + ) + if callback_class is not None: + processed_list.append(callback_class) + + # If processing dynamic_success_callbacks, add to dynamic_async_success_callbacks + if dynamic_callbacks_type == "success": + if self.dynamic_async_success_callbacks is None: + self.dynamic_async_success_callbacks = [] + self.dynamic_async_success_callbacks.append(callback_class) + elif dynamic_callbacks_type == "failure": + if self.dynamic_async_failure_callbacks is None: + self.dynamic_async_failure_callbacks = [] + self.dynamic_async_failure_callbacks.append(callback_class) + else: + processed_list.append(callback) + return processed_list + + def initialize_standard_callback_dynamic_params( + self, kwargs: Optional[Dict] = None + ) -> StandardCallbackDynamicParams: + """ + Initialize the standard callback dynamic params from the kwargs + + checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams + """ + + return _initialize_standard_callback_dynamic_params(kwargs) + + def initialize_standard_built_in_tools_params( + self, kwargs: Optional[Dict] = None + ) -> StandardBuiltInToolsParams: + """ + Initialize the standard built-in tools params from the kwargs + + checks if web_search_options in kwargs or tools and sets the corresponding attribute in StandardBuiltInToolsParams + """ + return StandardBuiltInToolsParams( + web_search_options=StandardBuiltInToolCostTracking._get_web_search_options( + kwargs or {} + ), + file_search=StandardBuiltInToolCostTracking._get_file_search_tool_call( + kwargs or {} + ), + ) + + def update_environment_variables( + self, + litellm_params: Dict, + optional_params: Dict, + model: Optional[str] = None, + user: Optional[str] = None, + **additional_params, + ): + self.optional_params = optional_params + if model is not None: + self.model = model + self.user = user + self.litellm_params = { + **self.litellm_params, + **scrub_sensitive_keys_in_metadata(litellm_params), + } + self.litellm_request_debug = litellm_params.get("litellm_request_debug", False) + self.logger_fn = litellm_params.get("logger_fn", None) + if _is_debugging_on() or self.litellm_request_debug: + verbose_logger.debug(f"self.optional_params: {self.optional_params}") + + self.model_call_details.update( + { + "model": self.model, + "messages": self.messages, + "optional_params": self.optional_params, + "litellm_params": self.litellm_params, + "start_time": self.start_time, + "stream": self.stream, + "user": user, + "call_type": str(self.call_type), + "litellm_call_id": self.litellm_call_id, + "completion_start_time": self.completion_start_time, + "standard_callback_dynamic_params": self.standard_callback_dynamic_params, + **self.optional_params, + **additional_params, + } + ) + + ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation + if "stream_options" in additional_params: + self.stream_options = additional_params["stream_options"] + ## check if custom pricing set ## + if any( + litellm_params.get(key) is not None + for key in _CUSTOM_PRICING_KEYS & litellm_params.keys() + ): + self.custom_pricing = True + + if "custom_llm_provider" in self.model_call_details: + self.custom_llm_provider = self.model_call_details["custom_llm_provider"] + + def update_messages(self, messages: List[AllMessageValues]): + """ + Update the logged value of the messages in the model_call_details + + Allows pre-call hooks to update the messages before the call is made + """ + self.messages = messages + self.model_call_details["messages"] = messages + + def should_run_prompt_management_hooks( + self, + non_default_params: Dict, + prompt_id: Optional[str] = None, + tools: Optional[List[Dict]] = None, + ) -> bool: + """ + Return True if prompt management hooks should be run + """ + if prompt_id: + return True + + if self._should_run_prompt_management_hooks_without_prompt_id( + non_default_params=non_default_params, + tools=tools, + ): + return True + + return False + + def _should_run_prompt_management_hooks_without_prompt_id( + self, + non_default_params: Dict, + tools: Optional[List[Dict]] = None, + ) -> bool: + """ + Certain prompt management hooks don't need a `prompt_id` to be passed in, they are triggered by dynamic params + + eg. AnthropicCacheControlHook and BedrockKnowledgeBaseHook both don't require a `prompt_id` to be passed in, they are triggered by dynamic params + """ + for param in non_default_params: + if param in DynamicPromptManagementParamLiteral.list_all_params(): + return True + + ############################################################################# + # Check if Vector Store / Knowledge Base hooks should be applied to the prompt + ############################################################################# + if litellm.vector_store_registry is not None: + if litellm.vector_store_registry.get_vector_store_to_run( + non_default_params=non_default_params, tools=tools + ): + return True + return False + + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: Dict, + prompt_variables: Optional[dict], + prompt_id: Optional[str] = None, + prompt_spec: Optional[PromptSpec] = None, + prompt_management_logger: Optional[CustomLogger] = None, + prompt_label: Optional[str] = None, + prompt_version: Optional[int] = None, + ) -> Tuple[str, List[AllMessageValues], dict]: + custom_logger = ( + prompt_management_logger + or self.get_custom_logger_for_prompt_management( + model=model, + non_default_params=non_default_params, + prompt_id=prompt_id, + prompt_spec=prompt_spec, + dynamic_callback_params=self.standard_callback_dynamic_params, + ) + ) + + if custom_logger: + ( + model, + messages, + non_default_params, + ) = custom_logger.get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params or {}, + prompt_id=prompt_id, + prompt_spec=prompt_spec, + prompt_variables=prompt_variables, + dynamic_callback_params=self.standard_callback_dynamic_params, + prompt_label=prompt_label, + prompt_version=prompt_version, + ) + self.messages = messages + return model, messages, non_default_params + + async def async_get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: Dict, + prompt_variables: Optional[dict], + prompt_id: Optional[str] = None, + prompt_spec: Optional[PromptSpec] = None, + prompt_management_logger: Optional[CustomLogger] = None, + tools: Optional[List[Dict]] = None, + prompt_label: Optional[str] = None, + prompt_version: Optional[int] = None, + ) -> Tuple[str, List[AllMessageValues], dict]: + custom_logger = ( + prompt_management_logger + or self.get_custom_logger_for_prompt_management( + model=model, + tools=tools, + non_default_params=non_default_params, + prompt_id=prompt_id, + prompt_spec=prompt_spec, + dynamic_callback_params=self.standard_callback_dynamic_params, + ) + ) + + if custom_logger: + ( + model, + messages, + non_default_params, + ) = await custom_logger.async_get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params or {}, + prompt_id=prompt_id, + prompt_spec=prompt_spec, + prompt_variables=prompt_variables, + dynamic_callback_params=self.standard_callback_dynamic_params, + litellm_logging_obj=self, + tools=tools, + prompt_label=prompt_label, + prompt_version=prompt_version, + ) + self.messages = messages + return model, messages, non_default_params + + def _auto_detect_prompt_management_logger( + self, + prompt_id: str, + prompt_spec: Optional[PromptSpec], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Optional[CustomLogger]: + """ + Auto-detect which prompt management system owns the given prompt_id. + + This allows a user to just pass prompt_id in the completion call and it will be auto-detected which system owns this prompt. + + Args: + prompt_id: The prompt ID to check + dynamic_callback_params: Dynamic callback parameters for should_run_prompt_management checks + + Returns: + A CustomLogger instance if a matching prompt management system is found, None otherwise + """ + prompt_management_loggers = ( + litellm.logging_callback_manager.get_custom_loggers_for_type( + callback_type=CustomPromptManagement + ) + ) + + for logger in prompt_management_loggers: + if isinstance(logger, CustomPromptManagement): + try: + if logger.should_run_prompt_management( + prompt_id=prompt_id, + prompt_spec=prompt_spec, + dynamic_callback_params=dynamic_callback_params, + ): + self.model_call_details[ + "prompt_integration" + ] = logger.__class__.__name__ + return logger + except Exception: + # If check fails, continue to next logger + continue + + return None + + def get_custom_logger_for_prompt_management( + self, + model: str, + non_default_params: Dict, + tools: Optional[List[Dict]] = None, + prompt_id: Optional[str] = None, + prompt_spec: Optional[PromptSpec] = None, + dynamic_callback_params: Optional[StandardCallbackDynamicParams] = None, + ) -> Optional[CustomLogger]: + """ + Get a custom logger for prompt management based on model name or available callbacks. + + Args: + model: The model name to check for prompt management integration + non_default_params: Non-default parameters passed to the completion call + tools: Optional tools passed to the completion call + prompt_id: Optional prompt ID to auto-detect which system owns this prompt + dynamic_callback_params: Dynamic callback parameters for should_run_prompt_management checks + + Returns: + A CustomLogger instance if one is found, None otherwise + """ + # First check if model starts with a known custom logger compatible callback + # This takes precedence for backward compatibility + for callback_name in litellm._known_custom_logger_compatible_callbacks: + if model.startswith(callback_name): + custom_logger = _init_custom_logger_compatible_class( + logging_integration=callback_name, + internal_usage_cache=None, + llm_router=None, + ) + if custom_logger is not None: + self.model_call_details["prompt_integration"] = model.split("/")[0] + return custom_logger + + # If prompt_id is provided, try to auto-detect which system has this prompt + if prompt_id and dynamic_callback_params is not None: + auto_detected_logger = self._auto_detect_prompt_management_logger( + prompt_id=prompt_id, + prompt_spec=prompt_spec, + dynamic_callback_params=dynamic_callback_params, + ) + if auto_detected_logger is not None: + return auto_detected_logger + + # Then check for any registered CustomPromptManagement loggers (fallback) + prompt_management_loggers = ( + litellm.logging_callback_manager.get_custom_loggers_for_type( + callback_type=CustomPromptManagement + ) + ) + + if prompt_management_loggers: + logger = prompt_management_loggers[0] + self.model_call_details["prompt_integration"] = logger.__class__.__name__ + return logger + + if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook( + non_default_params + ): + self.model_call_details[ + "prompt_integration" + ] = anthropic_cache_control_logger.__class__.__name__ + return anthropic_cache_control_logger + + ######################################################### + # Vector Store / Knowledge Base hooks + ######################################################### + if litellm.vector_store_registry is not None: + vector_store_custom_logger = _init_custom_logger_compatible_class( + logging_integration="vector_store_pre_call_hook", + internal_usage_cache=None, + llm_router=None, + ) + self.model_call_details[ + "prompt_integration" + ] = vector_store_custom_logger.__class__.__name__ + # Add to global callbacks so post-call hooks are invoked + if ( + vector_store_custom_logger + and vector_store_custom_logger not in litellm.callbacks + ): + litellm.logging_callback_manager.add_litellm_callback( + vector_store_custom_logger + ) + return vector_store_custom_logger + + return None + + def get_custom_logger_for_anthropic_cache_control_hook( + self, non_default_params: Dict + ) -> Optional[CustomLogger]: + if non_default_params.get("cache_control_injection_points", None): + custom_logger = _init_custom_logger_compatible_class( + logging_integration="anthropic_cache_control_hook", + internal_usage_cache=None, + llm_router=None, + ) + return custom_logger + return None + + def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict: + if data is None: + return {"error": "Received empty dictionary for raw request body"} + if isinstance(data, str): + try: + return json.loads(data) + except Exception: + return { + "error": "Unable to parse raw request body. Got - {}".format(data) + } + return data + + def _get_masked_api_base(self, api_base: str) -> str: + if "key=" in api_base: + # Find the position of "key=" in the string + key_index = api_base.find("key=") + 4 + # Mask the last 5 characters after "key=" + masked_api_base = api_base[:key_index] + "*" * 5 + api_base[-4:] + else: + masked_api_base = api_base + return str(masked_api_base) + + def _pre_call(self, input, api_key, model=None, additional_args={}): + """ + Common helper function across the sync + async pre-call function + """ + + self.model_call_details["input"] = input + self.model_call_details["api_key"] = api_key + self.model_call_details["additional_args"] = additional_args + self.model_call_details["log_event_type"] = "pre_api_call" + if ( + model + ): # if model name was changes pre-call, overwrite the initial model call name with the new one + self.model_call_details["model"] = model + self.model_call_details["litellm_params"][ + "api_base" + ] = self._get_masked_api_base(additional_args.get("api_base", "")) + + def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 + # Log the exact input to the LLM API + litellm.error_logs["PRE_CALL"] = locals() + try: + self._pre_call( + input=input, + api_key=api_key, + model=model, + additional_args=additional_args, + ) + + # User Logging -> if you pass in a custom logging function + self._print_llm_call_debugging_log( + api_base=additional_args.get("api_base", ""), + headers=additional_args.get("headers", {}), + additional_args=additional_args, + ) + # log raw request to provider (like LangFuse) -- if opted in. + if ( + self.log_raw_request_response is True + or log_raw_request_response is True + ): + _litellm_params = self.model_call_details.get("litellm_params", {}) + _metadata = _litellm_params.get("metadata", {}) or {} + try: + # [Non-blocking Extra Debug Information in metadata] + if turn_off_message_logging is True: + _metadata[ + "raw_request" + ] = "redacted by litellm. \ + 'litellm.turn_off_message_logging=True'" + else: + curl_command = self._get_request_curl_command( + api_base=additional_args.get("api_base", ""), + headers=additional_args.get("headers", {}), + additional_args=additional_args, + data=additional_args.get("complete_input_dict", {}), + ) + + _metadata["raw_request"] = str(curl_command) + # split up, so it's easier to parse in the UI + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + # NOTE: setting ignore_sensitive_headers to True will cause + # the Authorization header to be leaked when calls to the health + # endpoint are made and fail. + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ), + error=None, + ) + except Exception as e: + self.model_call_details[ + "raw_request_typed_dict" + ] = RawRequestTypedDict( + error=str(e), + ) + _metadata[ + "raw_request" + ] = "Unable to Log \ + raw request: {}".format( + str(e) + ) + if getattr(self, "logger_fn", None) and callable(self.logger_fn): + try: + self.logger_fn( + self.model_call_details + ) # Expectation: any logger function passed in by the user should accept a dict object + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( + str(e) + ) + ) + + self.model_call_details["api_call_start_time"] = datetime.datetime.now() + # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made + callbacks = litellm.input_callback + (self.dynamic_input_callbacks or []) + for callback in callbacks: + try: + if callback == "supabase" and supabaseClient is not None: + verbose_logger.debug("reaches supabase for logging!") + model = self.model_call_details["model"] + messages = self.model_call_details["input"] + verbose_logger.debug(f"supabaseClient: {supabaseClient}") + supabaseClient.input_log_event( + model=model, + messages=messages, + end_user=self.model_call_details.get("user", "default"), + litellm_call_id=self.litellm_params["litellm_call_id"], + print_verbose=print_verbose, + ) + elif callback == "sentry" and add_breadcrumb: + try: + details_to_log = copy.deepcopy(self.model_call_details) + except Exception: + details_to_log = self.model_call_details + if litellm.turn_off_message_logging: + # make a copy of the _model_Call_details and log it + details_to_log.pop("messages", None) + details_to_log.pop("input", None) + details_to_log.pop("prompt", None) + + add_breadcrumb( + category="litellm.llm_call", + message=f"Model Call Details pre-call: {details_to_log}", + level="info", + ) + + elif isinstance(callback, CustomLogger): # custom logger class + callback.log_pre_api_call( + model=self.model, + messages=self.messages, + kwargs=self.model_call_details, + ) + elif ( + callable(callback) and customLogger is not None + ): # custom logger functions + customLogger.log_input_event( + model=self.model, + messages=self.messages, + kwargs=self.model_call_details, + print_verbose=print_verbose, + callback_func=callback, + ) + except Exception as e: + verbose_logger.exception( + "litellm.Logging.pre_call(): Exception occured - {}".format( + str(e) + ) + ) + verbose_logger.debug( + f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" + ) + if capture_exception: # log this error to sentry for debugging + capture_exception(e) + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( + str(e) + ) + ) + verbose_logger.error( + f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" + ) + if capture_exception: # log this error to sentry for debugging + capture_exception(e) + + def _print_llm_call_debugging_log( + self, + api_base: str, + headers: dict, + additional_args: dict, + ): + """ + Internal debugging helper function + + Prints the RAW curl command sent from LiteLLM + """ + if _is_debugging_on() or self.litellm_request_debug: + if json_logs: + masked_headers = self._get_masked_headers(headers) + if self.litellm_request_debug: + verbose_logger.warning( # .warning ensures this shows up in all environments + "POST Request Sent from LiteLLM", + extra={"api_base": {api_base}, **masked_headers}, + ) + else: + verbose_logger.debug( + "POST Request Sent from LiteLLM", + extra={"api_base": {api_base}, **masked_headers}, + ) + else: + headers = additional_args.get("headers", {}) + if headers is None: + headers = {} + data = additional_args.get("complete_input_dict", {}) + api_base = str(additional_args.get("api_base", "")) + curl_command = self._get_request_curl_command( + api_base=api_base, + headers=headers, + additional_args=additional_args, + data=data, + ) + if self.litellm_request_debug: + verbose_logger.warning( + f"\033[92m{curl_command}\033[0m\n" + ) # .warning ensures this shows up in all environments + else: + verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n") + + def _get_request_body(self, data: dict) -> str: + return str(data) + + def _get_request_curl_command( + self, api_base: str, headers: Optional[dict], additional_args: dict, data: dict + ) -> str: + masked_api_base = self._get_masked_api_base(api_base) + if headers is None: + headers = {} + curl_command = "\n\nPOST Request Sent from LiteLLM:\n" + curl_command += "curl -X POST \\\n" + curl_command += f"{masked_api_base} \\\n" + masked_headers = self._get_masked_headers(headers) + formatted_headers = " ".join( + [f"-H '{k}: {v}'" for k, v in masked_headers.items()] + ) + curl_command += ( + f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" + ) + curl_command += f"-d '{self._get_request_body(data)}'\n" + if additional_args.get("request_str", None) is not None: + # print the sagemaker / bedrock client request + curl_command = "\nRequest Sent from LiteLLM:\n" + request_str = additional_args.get("request_str", "") + curl_command += request_str + elif api_base == "": + curl_command = str(self.model_call_details) + return curl_command + + def _get_masked_headers( + self, headers: dict, ignore_sensitive_headers: bool = False + ) -> dict: + """ + Internal debugging helper function + + Masks the headers of the request sent from LiteLLM + """ + return _get_masked_values( + headers, ignore_sensitive_values=ignore_sensitive_headers + ) + + def post_call( + self, original_response, input=None, api_key=None, additional_args={} + ): + # Log the exact result from the LLM API, for streaming - log the type of response received + litellm.error_logs["POST_CALL"] = locals() + if isinstance(original_response, dict): + original_response = json.dumps(original_response) + try: + self.model_call_details["input"] = input + self.model_call_details["api_key"] = api_key + self.model_call_details["original_response"] = original_response + self.model_call_details["additional_args"] = additional_args + self.model_call_details["log_event_type"] = "post_api_call" + + if self.litellm_request_debug: + attr = "warning" + else: + attr = "debug" + + if json_logs: + callattr = getattr(verbose_logger, attr) + callattr( + "RAW RESPONSE:\n{}\n\n".format( + self.model_call_details.get( + "original_response", self.model_call_details + ) + ), + ) + else: + callattr = getattr(verbose_logger, attr) + callattr( + "RAW RESPONSE:\n{}\n\n".format( + self.model_call_details.get( + "original_response", self.model_call_details + ) + ) + ) + if getattr(self, "logger_fn", None) and callable(self.logger_fn): + try: + self.logger_fn( + self.model_call_details + ) # Expectation: any logger function passed in by the user should accept a dict object + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( + str(e) + ) + ) + original_response = redact_message_input_output_from_logging( + model_call_details=( + self.model_call_details + if hasattr(self, "model_call_details") + else {} + ), + result=original_response, + ) + # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made + + callbacks = litellm.input_callback + (self.dynamic_input_callbacks or []) + for callback in callbacks: + try: + if callback == "sentry" and add_breadcrumb: + verbose_logger.debug("reaches sentry breadcrumbing") + try: + details_to_log = copy.deepcopy(self.model_call_details) + except Exception: + details_to_log = self.model_call_details + if litellm.turn_off_message_logging: + # make a copy of the _model_Call_details and log it + details_to_log.pop("messages", None) + details_to_log.pop("input", None) + details_to_log.pop("prompt", None) + + add_breadcrumb( + category="litellm.llm_call", + message=f"Model Call Details post-call: {details_to_log}", + level="info", + ) + elif isinstance(callback, CustomLogger): # custom logger class + callback.log_post_api_call( + kwargs=self.model_call_details, + response_obj=None, + start_time=self.start_time, + end_time=None, + ) + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while post-call logging with integrations {}".format( + str(e) + ) + ) + verbose_logger.debug( + f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" + ) + if capture_exception: # log this error to sentry for debugging + capture_exception(e) + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( + str(e) + ) + ) + + async def async_post_mcp_tool_call_hook( + self, + kwargs: dict, + response_obj: Any, + start_time: datetime.datetime, + end_time: datetime.datetime, + ): + """ + Post MCP Tool Call Hook + + Use this to modify the MCP tool call response before it is returned to the user. + """ + from litellm.types.llms.base import HiddenParams + from litellm.types.mcp import MCPPostCallResponseObject + + callbacks = self.get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) + post_mcp_tool_call_response_obj: MCPPostCallResponseObject = ( + MCPPostCallResponseObject( + mcp_tool_call_response=response_obj, hidden_params=HiddenParams() + ) + ) + for callback in callbacks: + try: + if isinstance(callback, CustomLogger): + response: Optional[ + MCPPostCallResponseObject + ] = await callback.async_post_mcp_tool_call_hook( + kwargs=kwargs, + response_obj=post_mcp_tool_call_response_obj, + start_time=start_time, + end_time=end_time, + ) + ###################################################################### + # if any of the callbacks modify the response, use the modified response + # current implementation returns the first modified response + ###################################################################### + if response is not None: + response_obj = self._parse_post_mcp_call_hook_response( + response=response + ) + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( + str(e) + ) + ) + return response_obj + + def _parse_post_mcp_call_hook_response( + self, response: Optional[MCPPostCallResponseObject] + ) -> Any: + """ + Parse the response from the post_mcp_tool_call_hook + + 1. Unpack the mcp_tool_call_response + 2. save the updated response_cost to the model_call_details + """ + if response is None: + return None + self.model_call_details["response_cost"] = response.hidden_params.response_cost + return response.mcp_tool_call_response + + def get_response_ms(self) -> float: + return ( + self.model_call_details.get("end_time", datetime.datetime.now()) + - self.model_call_details.get("start_time", datetime.datetime.now()) + ).total_seconds() * 1000 + + def set_cost_breakdown( + self, + input_cost: float, + output_cost: float, + total_cost: float, + cost_for_built_in_tools_cost_usd_dollar: float, + additional_costs: Optional[dict] = None, + original_cost: Optional[float] = None, + discount_percent: Optional[float] = None, + discount_amount: Optional[float] = None, + margin_percent: Optional[float] = None, + margin_fixed_amount: Optional[float] = None, + margin_total_amount: Optional[float] = None, + ) -> None: + """ + Helper method to store cost breakdown in the logging object. + + Args: + input_cost: Cost of input/prompt tokens + output_cost: Cost of output/completion tokens + cost_for_built_in_tools_cost_usd_dollar: Cost of built-in tools + total_cost: Total cost of request + additional_costs: Free-form additional costs dict (e.g., {"azure_model_router_flat_cost": 0.00014}) + original_cost: Cost before discount + discount_percent: Discount percentage (0.05 = 5%) + discount_amount: Discount amount in USD + margin_percent: Margin percentage applied (0.10 = 10%) + margin_fixed_amount: Fixed margin amount in USD + margin_total_amount: Total margin added in USD + """ + + self.cost_breakdown = CostBreakdown( + input_cost=input_cost, + output_cost=output_cost, + total_cost=total_cost, + tool_usage_cost=cost_for_built_in_tools_cost_usd_dollar, + ) + + # Store additional costs if provided (free-form dict for extensibility) + if ( + additional_costs + and isinstance(additional_costs, dict) + and len(additional_costs) > 0 + ): + self.cost_breakdown["additional_costs"] = additional_costs + + # Store discount information if provided + if original_cost is not None: + self.cost_breakdown["original_cost"] = original_cost + if discount_percent is not None: + self.cost_breakdown["discount_percent"] = discount_percent + if discount_amount is not None: + self.cost_breakdown["discount_amount"] = discount_amount + + # Store margin information if provided + if margin_percent is not None: + self.cost_breakdown["margin_percent"] = margin_percent + if margin_fixed_amount is not None: + self.cost_breakdown["margin_fixed_amount"] = margin_fixed_amount + if margin_total_amount is not None: + self.cost_breakdown["margin_total_amount"] = margin_total_amount + + def _response_cost_calculator( + self, + result: Union[ + ModelResponse, + ModelResponseStream, + EmbeddingResponse, + ImageResponse, + TranscriptionResponse, + TextCompletionResponse, + HttpxBinaryResponseContent, + RerankResponse, + Batch, + FineTuningJob, + ResponsesAPIResponse, + ResponseCompletedEvent, + OpenAIFileObject, + LiteLLMRealtimeStreamLoggingObject, + OpenAIModerationResponse, + "SearchResponse", + ], + cache_hit: Optional[bool] = None, + litellm_model_name: Optional[str] = None, + router_model_id: Optional[str] = None, + ) -> Optional[float]: + """ + Calculate response cost using result + logging object variables. + + used for consistent cost calculation across response headers + logging integrations. + """ + + if isinstance(result, BaseModel) and hasattr(result, "_hidden_params"): + hidden_params = getattr(result, "_hidden_params", {}) + if ( + "response_cost" in hidden_params + and hidden_params["response_cost"] is not None + ): # use cost if already calculated + return hidden_params["response_cost"] + elif ( + router_model_id is None and "model_id" in hidden_params + ): # use model_id if not already set + router_model_id = hidden_params["model_id"] + + ## RESPONSE COST ## + custom_pricing = use_custom_pricing_for_model( + litellm_params=( + self.litellm_params if hasattr(self, "litellm_params") else None + ) + ) + + prompt = "" # use for tts cost calc + _input = self.model_call_details.get("input", None) + if _input is not None and isinstance(_input, str): + prompt = _input + + if cache_hit is None: + cache_hit = self.model_call_details.get("cache_hit", False) + + try: + response_cost_calculator_kwargs = { + "response_object": result, + "model": litellm_model_name or self.model, + "cache_hit": cache_hit, + "custom_llm_provider": self.model_call_details.get( + "custom_llm_provider", None + ), + "base_model": _get_base_model_from_metadata( + model_call_details=self.model_call_details + ), + "call_type": self.call_type, + "optional_params": self.optional_params, + "custom_pricing": custom_pricing, + "prompt": prompt, + "standard_built_in_tools_params": self.standard_built_in_tools_params, + "router_model_id": router_model_id, + "litellm_logging_obj": self, + "service_tier": ( + self.optional_params.get("service_tier") + if self.optional_params + else None + ), + } + except Exception as e: # error creating kwargs for cost calculation + debug_info = StandardLoggingModelCostFailureDebugInformation( + error_str=str(e), + traceback_str=_get_traceback_str_for_error(str(e)), + ) + verbose_logger.debug( + f"response_cost_failure_debug_information: {debug_info}" + ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info + return None + + try: + response_cost = litellm.response_cost_calculator( + **response_cost_calculator_kwargs + ) + + verbose_logger.debug(f"response_cost: {response_cost}") + return response_cost + except Exception as e: # error calculating cost + debug_info = StandardLoggingModelCostFailureDebugInformation( + error_str=str(e), + traceback_str=_get_traceback_str_for_error(str(e)), + model=response_cost_calculator_kwargs["model"], + cache_hit=response_cost_calculator_kwargs["cache_hit"], + custom_llm_provider=response_cost_calculator_kwargs[ + "custom_llm_provider" + ], + base_model=response_cost_calculator_kwargs["base_model"], + call_type=response_cost_calculator_kwargs["call_type"], + custom_pricing=response_cost_calculator_kwargs["custom_pricing"], + ) + verbose_logger.debug( + f"response_cost_failure_debug_information: {debug_info}" + ) + self.model_call_details[ + "response_cost_failure_debug_information" + ] = debug_info + + return None + + async def _response_cost_calculator_async( + self, + result: Union[ + ModelResponse, + ModelResponseStream, + EmbeddingResponse, + ImageResponse, + TranscriptionResponse, + TextCompletionResponse, + HttpxBinaryResponseContent, + RerankResponse, + Batch, + FineTuningJob, + ], + cache_hit: Optional[bool] = None, + ) -> Optional[float]: + return self._response_cost_calculator(result=result, cache_hit=cache_hit) + + def should_run_logging( + self, + event_type: Literal[ + "async_success", "sync_success", "async_failure", "sync_failure" + ], + stream: bool = False, + ) -> bool: + try: + if self.model_call_details.get(f"has_logged_{event_type}", False) is True: + return False + + return True + except Exception: + return True + + def has_run_logging( + self, + event_type: Literal[ + "async_success", "sync_success", "async_failure", "sync_failure" + ], + ) -> None: + if self.stream is not None and self.stream is True: + """ + Ignore check on stream, as there can be multiple chunks + """ + return + self.model_call_details[f"has_logged_{event_type}"] = True + return + + def should_run_callback( + self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str + ) -> bool: + if litellm.global_disable_no_log_param: + return True + + if litellm_params.get("no-log", False) is True: + # proxy cost tracking cal backs should run + + if not ( + isinstance(callback, CustomLogger) + and "_PROXY_" in callback.__class__.__name__ + ): + verbose_logger.debug( + f"no-log request, skipping logging for {event_hook} event" + ) + return False + + # Check for dynamically disabled callbacks via headers + if ( + EnterpriseCallbackControls is not None + and EnterpriseCallbackControls.is_callback_disabled_dynamically( + callback=callback, + litellm_params=litellm_params, + standard_callback_dynamic_params=self.standard_callback_dynamic_params, + ) + ): + verbose_logger.debug( + f"Callback {callback} disabled via x-litellm-disable-callbacks header for {event_hook} event" + ) + return False + + return True + + def _update_completion_start_time(self, completion_start_time: datetime.datetime): + self.completion_start_time = completion_start_time + self.model_call_details["completion_start_time"] = self.completion_start_time + + def normalize_logging_result(self, result: Any) -> Any: + """ + Some endpoints return a different type of result than what is expected by the logging system. + This function is used to normalize the result to the expected type. + """ + logging_result = result + if self.call_type == CallTypes.arealtime.value and isinstance(result, list): + combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results( + results=result + ) + logging_result = ( + RealtimeAPITokenUsageProcessor.create_logging_realtime_object( + usage=combined_usage_object, + results=result, + ) + ) + + elif ( + self.call_type == CallTypes.llm_passthrough_route.value + or self.call_type == CallTypes.allm_passthrough_route.value + ) and isinstance(result, Response): + from litellm.utils import ProviderConfigManager + + provider_config = ProviderConfigManager.get_provider_passthrough_config( + provider=self.model_call_details.get("custom_llm_provider", ""), + model=self.model, + ) + if provider_config is not None: + logging_result = provider_config.logging_non_streaming_response( + model=self.model, + custom_llm_provider=self.model_call_details.get( + "custom_llm_provider", "" + ), + httpx_response=result, + request_data=self.model_call_details.get("request_data", {}), + logging_obj=self, + endpoint=self.model_call_details.get("endpoint", ""), + ) + return logging_result + + def _process_hidden_params_and_response_cost( + self, + logging_result, + start_time, + end_time, + ): + hidden_params = getattr(logging_result, "_hidden_params", {}) + if hidden_params: + if self.model_call_details.get("litellm_params") is not None: + self.model_call_details["litellm_params"].setdefault("metadata", {}) + if self.model_call_details["litellm_params"]["metadata"] is None: + self.model_call_details["litellm_params"]["metadata"] = {} + self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = getattr(logging_result, "_hidden_params", {}) # type: ignore + + if "response_cost" in hidden_params: + self.model_call_details["response_cost"] = hidden_params["response_cost"] + else: + self.model_call_details["response_cost"] = self._response_cost_calculator( + result=logging_result + ) + + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=logging_result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) + + def _transform_usage_objects(self, result): + if isinstance(result, ResponsesAPIResponse): + result = result.model_copy() + transformed_usage = ( + ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( + result.usage + ) + ) + setattr(result, "usage", transformed_usage) + if ( + standard_logging_payload := self.model_call_details.get( + "standard_logging_object" + ) + ) is not None: + response_dict = ( + result.model_dump() + if hasattr(result, "model_dump") + else dict(result) + ) + # Ensure usage is properly included with transformed chat format + if transformed_usage is not None: + response_dict["usage"] = ( + transformed_usage.model_dump() + if hasattr(transformed_usage, "model_dump") + else dict(transformed_usage) + ) + standard_logging_payload["response"] = response_dict + elif isinstance(result, TranscriptionResponse): + from litellm.litellm_core_utils.llm_cost_calc.usage_object_transformation import ( + TranscriptionUsageObjectTransformation, + ) + + result = result.model_copy() + transformed_usage = TranscriptionUsageObjectTransformation.transform_transcription_usage_object(result.usage) # type: ignore + setattr(result, "usage", transformed_usage) + return result + + def _success_handler_helper_fn( + self, + result=None, + start_time=None, + end_time=None, + cache_hit=None, + standard_logging_object: Optional[StandardLoggingPayload] = None, + ): + try: + if start_time is None: + start_time = self.start_time + if end_time is None: + end_time = datetime.datetime.now() + if self.completion_start_time is None: + self.completion_start_time = end_time + self.model_call_details[ + "completion_start_time" + ] = self.completion_start_time + + self.model_call_details["log_event_type"] = "successful_api_call" + self.model_call_details["end_time"] = end_time + self.model_call_details["cache_hit"] = cache_hit + + if self.call_type == CallTypes.anthropic_messages.value: + result = self._handle_anthropic_messages_response_logging(result=result) + elif ( + self.call_type == CallTypes.generate_content.value + or self.call_type == CallTypes.agenerate_content.value + ): + result = self._handle_non_streaming_google_genai_generate_content_response_logging( + result=result + ) + elif ( + self.call_type == CallTypes.asend_message.value + or self.call_type == CallTypes.send_message.value + ): + result = self._handle_a2a_response_logging(result=result) + + logging_result = self.normalize_logging_result(result=result) + + if ( + standard_logging_object is None + and result is not None + and self.stream is not True + ): + if self._is_recognized_call_type_for_logging( + logging_result=logging_result + ): + self._process_hidden_params_and_response_cost( + logging_result=logging_result, + start_time=start_time, + end_time=end_time, + ) + elif isinstance(result, dict) or isinstance(result, list): + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) + elif standard_logging_object is not None: + self.model_call_details[ + "standard_logging_object" + ] = standard_logging_object + else: + self.model_call_details["response_cost"] = None + + result = self._transform_usage_objects(result=result) + + if ( + litellm.max_budget + and self.stream is False + and result is not None + and isinstance(result, dict) + and "content" in result + ): + time_diff = (end_time - start_time).total_seconds() + float_diff = float(time_diff) + litellm._current_cost += litellm.completion_cost( + model=self.model, + prompt="", + completion=getattr(result, "content", ""), + total_time=float_diff, + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) + + return start_time, end_time, result + except Exception as e: + raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") + + def _is_recognized_call_type_for_logging( + self, + logging_result: Any, + ): + """ + Returns True if the call type is recognized for logging (eg. ModelResponse, ModelResponseStream, etc.) + """ + if ( + isinstance(logging_result, ModelResponse) + or isinstance(logging_result, ModelResponseStream) + or isinstance(logging_result, EmbeddingResponse) + or isinstance(logging_result, ImageResponse) + or isinstance(logging_result, TranscriptionResponse) + or isinstance(logging_result, TextCompletionResponse) + or isinstance(logging_result, HttpxBinaryResponseContent) # tts + or isinstance(logging_result, RerankResponse) + or isinstance(logging_result, FineTuningJob) + or isinstance(logging_result, LiteLLMBatch) + or isinstance(logging_result, ResponsesAPIResponse) + or isinstance(logging_result, OpenAIFileObject) + or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject) + or isinstance(logging_result, OpenAIModerationResponse) + or isinstance(logging_result, OCRResponse) # OCR + or isinstance(logging_result, SearchResponse) # Search API + or isinstance(logging_result, dict) + and logging_result.get("object") == "vector_store.search_results.page" + or isinstance(logging_result, dict) + and logging_result.get("object") == "search" # Search API (dict format) + or isinstance(logging_result, VideoObject) + or isinstance(logging_result, ContainerObject) + or isinstance(logging_result, LiteLLMSendMessageResponse) # A2A + or (self.call_type == CallTypes.call_mcp_tool.value) + ): + return True + return False + + def _flush_passthrough_collected_chunks_helper( + self, + raw_bytes: List[bytes], + provider_config: "BasePassthroughConfig", + ) -> Optional["CostResponseTypes"]: + all_chunks = provider_config._convert_raw_bytes_to_str_lines(raw_bytes) + complete_streaming_response = provider_config.handle_logging_collected_chunks( + all_chunks=all_chunks, + litellm_logging_obj=self, + model=self.model, + custom_llm_provider=self.model_call_details.get("custom_llm_provider", ""), + endpoint=self.model_call_details.get("endpoint", ""), + ) + return complete_streaming_response + + def flush_passthrough_collected_chunks( + self, + raw_bytes: List[bytes], + provider_config: "BasePassthroughConfig", + ): + """ + Flush collected chunks from the logging object + This is used to log the collected chunks once streaming is done on passthrough endpoints + + 1. Decode the raw bytes to string lines + 2. Get the complete streaming response from the provider config + 3. Log the complete streaming response (trigger success handler) + This is used for passthrough endpoints + """ + complete_streaming_response = self._flush_passthrough_collected_chunks_helper( + raw_bytes=raw_bytes, + provider_config=provider_config, + ) + + if complete_streaming_response is not None: + self.success_handler(result=complete_streaming_response) + return + + async def async_flush_passthrough_collected_chunks( + self, + raw_bytes: List[bytes], + provider_config: "BasePassthroughConfig", + ): + complete_streaming_response = self._flush_passthrough_collected_chunks_helper( + raw_bytes=raw_bytes, + provider_config=provider_config, + ) + + if complete_streaming_response is not None: + await self.async_success_handler(result=complete_streaming_response) + return + + def success_handler( # noqa: PLR0915 + self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs + ): + verbose_logger.debug( + f"Logging Details LiteLLM-Success Call: Cache_hit={cache_hit}" + ) + if not self.should_run_logging( + event_type="sync_success" + ): # prevent double logging + return + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, + end_time=end_time, + result=result, + cache_hit=cache_hit, + standard_logging_object=kwargs.get("standard_logging_object", None), + ) + litellm_params = self.model_call_details.get("litellm_params", {}) + is_sync_request = ( + litellm_params.get(CallTypes.acompletion.value, False) is not True + and litellm_params.get(CallTypes.aresponses.value, False) is not True + and litellm_params.get(CallTypes.aembedding.value, False) is not True + and litellm_params.get(CallTypes.aimage_generation.value, False) is not True + and litellm_params.get(CallTypes.atranscription.value, False) is not True + ) + try: + ## BUILD COMPLETE STREAMED RESPONSE + complete_streaming_response: Optional[ + Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] + ] = None + if "complete_streaming_response" in self.model_call_details: + return # break out of this. + complete_streaming_response = self._get_assembled_streaming_response( + result=result, + start_time=start_time, + end_time=end_time, + is_async=False, + streaming_chunks=self.sync_streaming_chunks, + ) + if complete_streaming_response is not None: + verbose_logger.debug( + "Logging Details LiteLLM-Success Call streaming complete" + ) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator(result=complete_streaming_response) + ## STANDARDIZED LOGGING PAYLOAD + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) + if ( + standard_logging_payload := self.model_call_details.get( + "standard_logging_object" + ) + ) is not None: + # Only emit for sync requests (async_success_handler handles async) + if is_sync_request: + emit_standard_logging_payload(standard_logging_payload) + callbacks = self.get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) + + ## REDACT MESSAGES ## + result = redact_message_input_output_from_logging( + model_call_details=( + self.model_call_details + if hasattr(self, "model_call_details") + else {} + ), + result=result, + ) + ## LOGGING HOOK ## + for callback in callbacks: + if isinstance(callback, CustomLogger): + self.model_call_details, result = callback.logging_hook( + kwargs=self.model_call_details, + result=result, + call_type=self.call_type, + ) + + self.has_run_logging(event_type="sync_success") + for callback in callbacks: + try: + should_run = self.should_run_callback( + callback=callback, + litellm_params=litellm_params, + event_hook="success_handler", + ) + if not should_run: + continue + if callback == "promptlayer" and promptLayerLogger is not None: + print_verbose("reaches promptlayer for logging!") + promptLayerLogger.log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + if callback == "supabase" and supabaseClient is not None: + print_verbose("reaches supabase for logging!") + kwargs = self.model_call_details + + # this only logs streaming once, complete_streaming_response exists i.e when stream ends + if self.stream: + if "complete_streaming_response" not in kwargs: + continue + else: + print_verbose("reaches supabase for streaming logging!") + result = kwargs["complete_streaming_response"] + + model = kwargs["model"] + messages = kwargs["messages"] + optional_params = kwargs.get("optional_params", {}) + litellm_params = kwargs.get("litellm_params", {}) + supabaseClient.log_event( + model=model, + messages=messages, + end_user=optional_params.get("user", "default"), + response_obj=result, + start_time=start_time, + end_time=end_time, + litellm_call_id=( + current_call_id + if ( + current_call_id := litellm_params.get( + "litellm_call_id" + ) + ) + is not None + else str(uuid.uuid4()) + ), + print_verbose=print_verbose, + ) + if callback == "wandb" and weightsBiasesLogger is not None: + print_verbose("reaches wandb for logging!") + weightsBiasesLogger.log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + if callback == "logfire" and logfireLogger is not None: + verbose_logger.debug("reaches logfire for success logging!") + kwargs = {} + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine + kwargs[k] = v + + # this only logs streaming once, complete_streaming_response exists i.e when stream ends + if self.stream: + if "complete_streaming_response" not in kwargs: + continue + else: + print_verbose("reaches logfire for streaming logging!") + result = kwargs["complete_streaming_response"] + + logfireLogger.log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + level=LogfireLevel.INFO.value, # type: ignore + ) + + if callback == "lunary" and lunaryLogger is not None: + print_verbose("reaches lunary for logging!") + model = self.model + kwargs = self.model_call_details + + input = kwargs.get("messages", kwargs.get("input", None)) + + type = ( + "embed" + if self.call_type == CallTypes.embedding.value + else "llm" + ) + + # this only logs streaming once, complete_streaming_response exists i.e when stream ends + if self.stream: + if "complete_streaming_response" not in kwargs: + continue + else: + result = kwargs["complete_streaming_response"] + + lunaryLogger.log_event( + type=type, + kwargs=kwargs, + event="end", + model=model, + input=input, + user_id=kwargs.get("user", None), + # user_props=self.model_call_details.get("user_props", None), + extra=kwargs.get("optional_params", {}), + response_obj=result, + start_time=start_time, + end_time=end_time, + run_id=self.litellm_call_id, + print_verbose=print_verbose, + ) + if callback == "helicone" and heliconeLogger is not None: + print_verbose("reaches helicone for logging!") + model = self.model + messages = self.model_call_details["input"] + kwargs = self.model_call_details + + # this only logs streaming once, complete_streaming_response exists i.e when stream ends + if self.stream: + if "complete_streaming_response" not in kwargs: + continue + else: + print_verbose("reaches helicone for streaming logging!") + result = kwargs["complete_streaming_response"] + + heliconeLogger.log_success( + model=model, + messages=messages, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + kwargs=kwargs, + ) + if callback == "langfuse": + global langFuseLogger + print_verbose("reaches langfuse for success logging!") + kwargs = {} + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine + kwargs[k] = v + # this only logs streaming once, complete_streaming_response exists i.e when stream ends + if self.stream: + verbose_logger.debug( + f"is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}" + ) + if complete_streaming_response is None: + continue + else: + print_verbose("reaches langfuse for streaming logging!") + result = kwargs["complete_streaming_response"] + + langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request( + globalLangfuseLogger=langFuseLogger, + standard_callback_dynamic_params=self.standard_callback_dynamic_params, + in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, + ) + if langfuse_logger_to_use is not None: + _response = langfuse_logger_to_use.log_event_on_langfuse( + kwargs=kwargs, + response_obj=result, + start_time=start_time, + end_time=end_time, + user_id=kwargs.get("user", None), + ) + if _response is not None and isinstance(_response, dict): + _trace_id = _response.get("trace_id", None) + if _trace_id is not None: + in_memory_trace_id_cache.set_cache( + litellm_call_id=self.litellm_call_id, + service_name="langfuse", + trace_id=_trace_id, + ) + if callback == "greenscale" and greenscaleLogger is not None: + kwargs = {} + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine + kwargs[k] = v + # this only logs streaming once, complete_streaming_response exists i.e when stream ends + if self.stream: + verbose_logger.debug( + f"is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}" + ) + if complete_streaming_response is None: + continue + else: + print_verbose( + "reaches greenscale for streaming logging!" + ) + result = kwargs["complete_streaming_response"] + + greenscaleLogger.log_event( + kwargs=kwargs, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + if callback == "athina" and athinaLogger is not None: + deep_copy = {} + for k, v in self.model_call_details.items(): + deep_copy[k] = v + athinaLogger.log_event( + kwargs=deep_copy, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + if callback == "traceloop": + deep_copy = {} + for k, v in self.model_call_details.items(): + if k != "original_response": + deep_copy[k] = v + traceloopLogger.log_event( + kwargs=deep_copy, + response_obj=result, + start_time=start_time, + end_time=end_time, + user_id=kwargs.get("user", None), + print_verbose=print_verbose, + ) + if callback == "s3": + global s3Logger + if s3Logger is None: + s3Logger = S3Logger() + if self.stream: + if "complete_streaming_response" in self.model_call_details: + print_verbose( + "S3Logger Logger: Got Stream Event - Completed Stream Response" + ) + s3Logger.log_event( + kwargs=self.model_call_details, + response_obj=self.model_call_details[ + "complete_streaming_response" + ], + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + else: + print_verbose( + "S3Logger Logger: Got Stream Event - No complete stream response as yet" + ) + else: + s3Logger.log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + + if callback == "openmeter" and is_sync_request: + global openMeterLogger + if openMeterLogger is None: + print_verbose("Instantiates openmeter client") + openMeterLogger = OpenMeterLogger() + if self.stream and complete_streaming_response is None: + openMeterLogger.log_stream_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + else: + if self.stream and complete_streaming_response: + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} + ) + result = self.model_call_details["complete_response"] + openMeterLogger.log_success_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + if ( + isinstance(callback, CustomLogger) + and is_sync_request + and self.call_type + != CallTypes.pass_through.value # pass-through endpoints call async_log_success_event + ): # custom logger class + if self.stream and complete_streaming_response is None: + callback.log_stream_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + else: + if self.stream and complete_streaming_response: + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} + ) + result = self.model_call_details["complete_response"] + + callback.log_success_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + if ( + callable(callback) is True + and is_sync_request + and customLogger is not None + ): # custom logger functions + print_verbose( + "success callbacks: Running Custom Callback Function - {}".format( + callback + ) + ) + + customLogger.log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback, + ) + + except Exception as e: + print_verbose( + f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging with integrations {traceback.format_exc()}" + ) + print_verbose( + f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" + ) + if capture_exception: # log this error to sentry for debugging + capture_exception(e) + # Track callback logging failures in Prometheus + try: + self._handle_callback_failure(callback=callback) + except Exception: + pass + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {}".format( + str(e) + ), + ) + + async def async_success_handler( # noqa: PLR0915 + self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs + ): + """ + Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. + """ + print_verbose( + "Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit) + ) + if not self.should_run_logging( + event_type="async_success" + ): # prevent double logging + return + + ## CALCULATE COST FOR BATCH JOBS + if self.call_type == CallTypes.aretrieve_batch.value and isinstance( + result, LiteLLMBatch + ): + litellm_params = self.litellm_params or {} + litellm_metadata = litellm_params.get("litellm_metadata") or {} + if ( + litellm_metadata.get("batch_ignore_default_logging", False) is True + ): # polling job will query these frequently, don't spam db logs + return + + from litellm.proxy.openai_files_endpoints.common_utils import ( + _is_base64_encoded_unified_file_id, + ) + + # check if file id is a unified file id + is_base64_unified_file_id = _is_base64_encoded_unified_file_id(result.id) + + batch_cost = kwargs.get("batch_cost", None) + batch_usage = kwargs.get("batch_usage", None) + batch_models = kwargs.get("batch_models", None) + has_explicit_batch_data = all( + x is not None for x in (batch_cost, batch_usage, batch_models) + ) + + should_compute_batch_data = ( + not is_base64_unified_file_id + or not has_explicit_batch_data + and result.status == "completed" + ) + if has_explicit_batch_data: + result._hidden_params["response_cost"] = batch_cost + result._hidden_params["batch_models"] = batch_models + result.usage = batch_usage + + elif should_compute_batch_data: + ( + response_cost, + batch_usage, + batch_models, + ) = await _handle_completed_batch( + batch=result, + custom_llm_provider=self.custom_llm_provider, + litellm_params=self.litellm_params, + ) + + result._hidden_params["response_cost"] = response_cost + result._hidden_params["batch_models"] = batch_models + result.usage = batch_usage + + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, + end_time=end_time, + result=result, + cache_hit=cache_hit, + standard_logging_object=kwargs.get("standard_logging_object", None), + ) + + ## BUILD COMPLETE STREAMED RESPONSE + if "async_complete_streaming_response" in self.model_call_details: + return # break out of this. + complete_streaming_response: Optional[ + Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] + ] = self._get_assembled_streaming_response( + result=result, + start_time=start_time, + end_time=end_time, + is_async=True, + streaming_chunks=self.streaming_chunks, + ) + + if complete_streaming_response is not None: + print_verbose("Async success callbacks: Got a complete streaming response") + + self.model_call_details[ + "async_complete_streaming_response" + ] = complete_streaming_response + + try: + if self.model_call_details.get("cache_hit", False) is True: + self.model_call_details["response_cost"] = 0.0 + else: + # check if base_model set on azure + _get_base_model_from_metadata( + model_call_details=self.model_call_details + ) + # base_model defaults to None if not set on model_info + self.model_call_details[ + "response_cost" + ] = self._response_cost_calculator( + result=complete_streaming_response + ) + + verbose_logger.debug( + f"Model={self.model}; cost={self.model_call_details['response_cost']}" + ) + except litellm.NotFoundError: + verbose_logger.warning( + f"Model={self.model} not found in completion cost map. Setting 'response_cost' to None" + ) + self.model_call_details["response_cost"] = None + + ## STANDARDIZED LOGGING PAYLOAD + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=complete_streaming_response, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) + + # print standard logging payload + if ( + standard_logging_payload := self.model_call_details.get( + "standard_logging_object" + ) + ) is not None: + emit_standard_logging_payload(standard_logging_payload) + elif self.call_type == "pass_through_endpoint": + print_verbose( + "Async success callbacks: Got a pass-through endpoint response" + ) + + self.model_call_details["async_complete_streaming_response"] = result + + # cost calculation not possible for pass-through + self.model_call_details["response_cost"] = None + + ## STANDARDIZED LOGGING PAYLOAD + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj=result, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="success", + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) + + # print standard logging payload + if ( + standard_logging_payload := self.model_call_details.get( + "standard_logging_object" + ) + ) is not None: + emit_standard_logging_payload(standard_logging_payload) + callbacks = self.get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_success_callbacks, + global_callbacks=litellm._async_success_callback, + ) + + result = redact_message_input_output_from_logging( + model_call_details=( + self.model_call_details if hasattr(self, "model_call_details") else {} + ), + result=result, + ) + + ## LOGGING HOOK ## + + for callback in callbacks: + if isinstance(callback, CustomGuardrail): + from litellm.types.guardrails import GuardrailEventHooks + + if ( + callback.should_run_guardrail( + data=self.model_call_details, + event_type=GuardrailEventHooks.logging_only, + ) + is not True + ): + continue + + self.model_call_details, result = await callback.async_logging_hook( + kwargs=self.model_call_details, + result=result, + call_type=self.call_type, + ) + elif isinstance(callback, CustomLogger): + result = redact_message_input_output_from_custom_logger( + result=result, litellm_logging_obj=self, custom_logger=callback + ) + self.model_call_details, result = await callback.async_logging_hook( + kwargs=self.model_call_details, + result=result, + call_type=self.call_type, + ) + + self.has_run_logging(event_type="async_success") + + for callback in callbacks: + # check if callback can run for this request + litellm_params = self.model_call_details.get("litellm_params", {}) + should_run = self.should_run_callback( + callback=callback, + litellm_params=litellm_params, + event_hook="async_success_handler", + ) + if not should_run: + continue + try: + if callback == "openmeter" and openMeterLogger is not None: + if self.stream is True: + if ( + "async_complete_streaming_response" + in self.model_call_details + ): + await openMeterLogger.async_log_success_event( + kwargs=self.model_call_details, + response_obj=self.model_call_details[ + "async_complete_streaming_response" + ], + start_time=start_time, + end_time=end_time, + ) + else: + await openMeterLogger.async_log_stream_event( # [TODO]: move this to being an async log stream event function + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + else: + await openMeterLogger.async_log_success_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + + if isinstance(callback, CustomLogger): # custom logger class + model_call_details: Dict = self.model_call_details + ################################## + # call redaction hook for custom logger + model_call_details = callback.redact_standard_logging_payload_from_model_call_details( + model_call_details=model_call_details + ) + ################################## + if self.stream is True: + if "async_complete_streaming_response" in model_call_details: + await callback.async_log_success_event( + kwargs=model_call_details, + response_obj=model_call_details[ + "async_complete_streaming_response" + ], + start_time=start_time, + end_time=end_time, + ) + else: + await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function + kwargs=model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + else: + await callback.async_log_success_event( + kwargs=model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + if callable(callback): # custom logger functions + global customLogger + if customLogger is None: + customLogger = CustomLogger() + if self.stream: + if ( + "async_complete_streaming_response" + in self.model_call_details + ): + await customLogger.async_log_event( + kwargs=self.model_call_details, + response_obj=self.model_call_details[ + "async_complete_streaming_response" + ], + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback, + ) + else: + await customLogger.async_log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback, + ) + if callback == "dynamodb": + global dynamoLogger + if dynamoLogger is None: + dynamoLogger = DyanmoDBLogger() + if self.stream: + if ( + "async_complete_streaming_response" + in self.model_call_details + ): + print_verbose( + "DynamoDB Logger: Got Stream Event - Completed Stream Response" + ) + await dynamoLogger._async_log_event( + kwargs=self.model_call_details, + response_obj=self.model_call_details[ + "async_complete_streaming_response" + ], + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + else: + print_verbose( + "DynamoDB Logger: Got Stream Event - No complete stream response as yet" + ) + else: + await dynamoLogger._async_log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + except Exception: + verbose_logger.error( + f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" + ) + self._handle_callback_failure(callback=callback) + pass + + def _handle_callback_failure(self, callback: Any): + """ + Handle callback logging failures by incrementing Prometheus metrics. + + Works for both sync and async contexts since Prometheus counter increment is synchronous. + + Args: + callback: The callback that failed + """ + try: + callback_name = self._get_callback_name(callback) + + all_callbacks = litellm.logging_callback_manager._get_all_callbacks() + + for callback_obj in all_callbacks: + if hasattr(callback_obj, "increment_callback_logging_failure"): + callback_obj.increment_callback_logging_failure(callback_name=callback_name) # type: ignore + break # Only increment once + + except Exception as e: + verbose_logger.debug(f"Error in _handle_callback_failure: {str(e)}") + + def _failure_handler_helper_fn( + self, exception, traceback_exception, start_time=None, end_time=None + ): + if start_time is None: + start_time = self.start_time + if end_time is None: + end_time = datetime.datetime.now() + + # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions + if not hasattr(self, "model_call_details"): + self.model_call_details = {} + + self.model_call_details["log_event_type"] = "failed_api_call" + self.model_call_details["exception"] = exception + self.model_call_details["traceback_exception"] = traceback_exception + self.model_call_details["end_time"] = end_time + self.model_call_details.setdefault("original_response", None) + self.model_call_details["response_cost"] = 0 + + if hasattr(exception, "headers") and isinstance(exception.headers, dict): + self.model_call_details.setdefault("litellm_params", {}) + metadata = ( + self.model_call_details["litellm_params"].get("metadata", {}) or {} + ) + metadata.update(exception.headers) + + ## STANDARDIZED LOGGING PAYLOAD + + self.model_call_details[ + "standard_logging_object" + ] = get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) + return start_time, end_time + + async def special_failure_handlers(self, exception: Exception): + """ + Custom events, emitted for specific failures. + + Currently just for router model group rate limit error + """ + from litellm.types.router import RouterErrors + + litellm_params: dict = self.model_call_details.get("litellm_params") or {} + metadata = litellm_params.get("metadata") or {} + + ## BASE CASE ## check if rate limit error for model group size 1 + is_base_case = False + if metadata.get("model_group_size") is not None: + model_group_size = metadata.get("model_group_size") + if isinstance(model_group_size, int) and model_group_size == 1: + is_base_case = True + ## check if special error ## + if ( + RouterErrors.no_deployments_available.value not in str(exception) + and is_base_case is False + ): + return + + ## get original model group ## + + model_group = metadata.get("model_group") or None + for callback in litellm._async_failure_callback: + if isinstance(callback, CustomLogger): # custom logger class + await callback.log_model_group_rate_limit_error( + exception=exception, + original_model_group=model_group, + kwargs=self.model_call_details, + ) # type: ignore + + def failure_handler( # noqa: PLR0915 + self, exception, traceback_exception, start_time=None, end_time=None + ): + verbose_logger.debug( + f"Logging Details LiteLLM-Failure Call: {litellm.failure_callback}" + ) + if not self.should_run_logging( + event_type="sync_failure" + ): # prevent double logging + return + litellm_params = self.model_call_details.get("litellm_params", {}) + is_sync_request = ( + litellm_params.get(CallTypes.acompletion.value, False) is not True + and litellm_params.get(CallTypes.aresponses.value, False) is not True + and litellm_params.get(CallTypes.aembedding.value, False) is not True + and litellm_params.get(CallTypes.aimage_generation.value, False) is not True + and litellm_params.get(CallTypes.atranscription.value, False) is not True + ) + + try: + start_time, end_time = self._failure_handler_helper_fn( + exception=exception, + traceback_exception=traceback_exception, + start_time=start_time, + end_time=end_time, + ) + callbacks = self.get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_failure_callbacks, + global_callbacks=litellm.failure_callback, + ) + + result = None # result sent to all loggers, init this to None incase it's not created + + result = redact_message_input_output_from_logging( + model_call_details=( + self.model_call_details + if hasattr(self, "model_call_details") + else {} + ), + result=result, + ) + self.has_run_logging(event_type="sync_failure") + for callback in callbacks: + try: + should_run = self.should_run_callback( + callback=callback, + litellm_params=litellm_params, + event_hook="failure_handler", + ) + if not should_run: + continue + if callback == "lunary" and lunaryLogger is not None: + print_verbose("reaches lunary for logging error!") + + model = self.model + + input = self.model_call_details["input"] + + _type = ( + "embed" + if self.call_type == CallTypes.embedding.value + else "llm" + ) + + lunaryLogger.log_event( + kwargs=self.model_call_details, + type=_type, + event="error", + user_id=self.model_call_details.get("user", "default"), + model=model, + input=input, + error=traceback_exception, + run_id=self.litellm_call_id, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + ) + if callback == "sentry": + print_verbose("sending exception to sentry") + if capture_exception: + capture_exception(exception) + else: + print_verbose( + f"capture exception not initialized: {capture_exception}" + ) + elif callback == "supabase" and supabaseClient is not None: + print_verbose("reaches supabase for logging!") + print_verbose(f"supabaseClient: {supabaseClient}") + supabaseClient.log_event( + model=self.model if hasattr(self, "model") else "", + messages=self.messages, + end_user=self.model_call_details.get("user", "default"), + response_obj=result, + start_time=start_time, + end_time=end_time, + litellm_call_id=self.model_call_details["litellm_call_id"], + print_verbose=print_verbose, + ) + if ( + callable(callback) and customLogger is not None + ): # custom logger functions + customLogger.log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback, + ) + if ( + isinstance(callback, CustomLogger) and is_sync_request + ): # custom logger class + callback.log_failure_event( + start_time=start_time, + end_time=end_time, + response_obj=result, + kwargs=self.model_call_details, + ) + if callback == "langfuse": + global langFuseLogger + verbose_logger.debug("reaches langfuse for logging failure") + kwargs = {} + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine + kwargs[k] = v + # this only logs streaming once, complete_streaming_response exists i.e when stream ends + langfuse_logger_to_use = LangFuseHandler.get_langfuse_logger_for_request( + globalLangfuseLogger=langFuseLogger, + standard_callback_dynamic_params=self.standard_callback_dynamic_params, + in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache, + ) + _response = langfuse_logger_to_use.log_event_on_langfuse( + start_time=start_time, + end_time=end_time, + response_obj=None, + user_id=kwargs.get("user", None), + status_message=str(exception), + level="ERROR", + kwargs=self.model_call_details, + ) + if _response is not None and isinstance(_response, dict): + _trace_id = _response.get("trace_id", None) + if _trace_id is not None: + in_memory_trace_id_cache.set_cache( + litellm_call_id=self.litellm_call_id, + service_name="langfuse", + trace_id=_trace_id, + ) + if callback == "traceloop": + traceloopLogger.log_event( + start_time=start_time, + end_time=end_time, + response_obj=None, + user_id=self.model_call_details.get("user", None), + print_verbose=print_verbose, + status_message=str(exception), + level="ERROR", + kwargs=self.model_call_details, + ) + if callback == "logfire" and logfireLogger is not None: + verbose_logger.debug("reaches logfire for failure logging!") + kwargs = {} + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine + kwargs[k] = v + kwargs["exception"] = exception + + logfireLogger.log_event( + kwargs=kwargs, + response_obj=result, + start_time=start_time, + end_time=end_time, + level=LogfireLevel.ERROR.value, # type: ignore + print_verbose=print_verbose, + ) + + except Exception as e: + print_verbose( + f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging with integrations {str(e)}" + ) + print_verbose( + f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" + ) + if capture_exception: # log this error to sentry for debugging + capture_exception(e) + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging {}".format( + str(e) + ) + ) + + async def async_failure_handler( + self, exception, traceback_exception, start_time=None, end_time=None + ): + """ + Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. + """ + await self.special_failure_handlers(exception=exception) + if not self.should_run_logging( + event_type="async_failure" + ): # prevent double logging + return + start_time, end_time = self._failure_handler_helper_fn( + exception=exception, + traceback_exception=traceback_exception, + start_time=start_time, + end_time=end_time, + ) + + callbacks = self.get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_failure_callbacks, + global_callbacks=litellm._async_failure_callback, + ) + + result = None # result sent to all loggers, init this to None incase it's not created + + self.has_run_logging(event_type="async_failure") + for callback in callbacks: + try: + litellm_params = self.model_call_details.get("litellm_params", {}) + should_run = self.should_run_callback( + callback=callback, + litellm_params=litellm_params, + event_hook="async_failure_handler", + ) + if not should_run: + continue + if isinstance(callback, CustomLogger): # custom logger class + await callback.async_log_failure_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) # type: ignore + if ( + callable(callback) and customLogger is not None + ): # custom logger functions + await customLogger.async_log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback, + ) + except Exception as e: + verbose_logger.exception( + "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure \ + logging {}\nCallback={}".format( + str(e), callback + ) + ) + # Track callback logging failures in Prometheus + self._handle_callback_failure(callback=callback) + + def _get_trace_id(self, service_name: Literal["langfuse"]) -> Optional[str]: + """ + For the given service (e.g. langfuse), return the trace_id actually logged. + + Used for constructing the url in slack alerting. + + Returns: + - str: The logged trace id + - None: If trace id not yet emitted. + """ + trace_id: Optional[str] = None + if service_name == "langfuse": + trace_id = in_memory_trace_id_cache.get_cache( + litellm_call_id=self.litellm_call_id, service_name=service_name + ) + + return trace_id + + def _get_callback_object(self, service_name: Literal["langfuse"]) -> Optional[Any]: + """ + Return dynamic callback object. + + Meant to solve issue when doing key-based/team-based logging + """ + global langFuseLogger + + if service_name == "langfuse": + if langFuseLogger is None or ( + ( + self.standard_callback_dynamic_params.get("langfuse_public_key") + is not None + and self.standard_callback_dynamic_params.get("langfuse_public_key") + != langFuseLogger.public_key + ) + or ( + self.standard_callback_dynamic_params.get("langfuse_public_key") + is not None + and self.standard_callback_dynamic_params.get("langfuse_public_key") + != langFuseLogger.public_key + ) + or ( + self.standard_callback_dynamic_params.get("langfuse_host") + is not None + and self.standard_callback_dynamic_params.get("langfuse_host") + != langFuseLogger.langfuse_host + ) + ): + return LangFuseLogger( + langfuse_public_key=self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ), + langfuse_secret=self.standard_callback_dynamic_params.get( + "langfuse_secret" + ), + langfuse_host=self.standard_callback_dynamic_params.get( + "langfuse_host" + ), + ) + return langFuseLogger + + return None + + def handle_sync_success_callbacks_for_async_calls( + self, + result: Any, + start_time: datetime.datetime, + end_time: datetime.datetime, + cache_hit: Optional[Any] = None, + ) -> None: + """ + Handles calling success callbacks for Async calls. + + Why: Some callbacks - `langfuse`, `s3` are sync callbacks. We need to call them in the executor. + """ + if self._should_run_sync_callbacks_for_async_calls() is False: + return + + executor.submit( + self.success_handler, + result, + start_time, + end_time, + cache_hit, + ) + + def _should_run_sync_callbacks_for_async_calls(self) -> bool: + """ + Returns: + - bool: True if sync callbacks should be run for async calls. eg. `langfuse`, `s3` + """ + _combined_sync_callbacks = self.get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) + _filtered_success_callbacks = self._remove_internal_custom_logger_callbacks( + _combined_sync_callbacks + ) + _filtered_success_callbacks = self._remove_internal_litellm_callbacks( + _filtered_success_callbacks + ) + return len(_filtered_success_callbacks) > 0 + + def get_combined_callback_list( + self, dynamic_success_callbacks: Optional[List], global_callbacks: List + ) -> List: + if dynamic_success_callbacks is None: + return list(global_callbacks) + return list(set(dynamic_success_callbacks + global_callbacks)) + + def _remove_internal_litellm_callbacks(self, callbacks: List) -> List: + """ + Creates a filtered list of callbacks, excluding internal LiteLLM callbacks. + + Args: + callbacks: List of callback functions/strings to filter + + Returns: + List of filtered callbacks with internal ones removed + """ + filtered = [ + cb for cb in callbacks if not self._is_internal_litellm_proxy_callback(cb) + ] + + verbose_logger.debug(f"Filtered callbacks: {filtered}") + return filtered + + def _get_callback_name(self, cb) -> str: + """ + Helper to get the name of a callback function + + Args: + cb: The callback object/function/string to get the name of + + Returns: + The name of the callback + """ + if isinstance(cb, str): + return cb + if hasattr(cb, "__name__"): + return cb.__name__ + if hasattr(cb, "__func__"): + return cb.__func__.__name__ + if hasattr(cb, "__class__"): + return cb.__class__.__name__ + return str(cb) + + def _is_internal_litellm_proxy_callback(self, cb) -> bool: + """Helper to check if a callback is internal""" + INTERNAL_PREFIXES = [ + "_PROXY", + "_service_logger.ServiceLogging", + "sync_deployment_callback_on_success", + ] + if isinstance(cb, str): + return False + + if not callable(cb): + return True + + cb_name = self._get_callback_name(cb) + return any(prefix in cb_name for prefix in INTERNAL_PREFIXES) + + def _remove_internal_custom_logger_callbacks(self, callbacks: List) -> List: + """ + Removes internal custom logger callbacks from the list. + """ + _new_callbacks = [] + for _c in callbacks: + if isinstance(_c, CustomLogger): + continue + elif ( + isinstance(_c, str) + and _c in litellm._known_custom_logger_compatible_callbacks + ): + continue + _new_callbacks.append(_c) + return _new_callbacks + + def _get_assembled_streaming_response( + self, + result: Union[ + ModelResponse, + TextCompletionResponse, + ModelResponseStream, + ResponseCompletedEvent, + Any, + ], + start_time: datetime.datetime, + end_time: datetime.datetime, + is_async: bool, + streaming_chunks: List[Any], + ) -> Optional[Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]]: + if isinstance(result, ModelResponse): + return result + elif isinstance(result, TextCompletionResponse): + return result + elif isinstance(result, ResponseCompletedEvent): + ## return unified Usage object + if isinstance(result.response.usage, ResponseAPIUsage): + transformed_usage = ( + ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( + result.response.usage + ) + ) + # Set as dict instead of Usage object so model_dump() serializes it correctly + setattr( + result.response, + "usage", + ( + transformed_usage.model_dump() + if hasattr(transformed_usage, "model_dump") + else dict(transformed_usage) + ), + ) + return result.response + else: + return None + return None + + def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse: + """ + Handles logging for Anthropic messages responses. + + Args: + result: The response object from the model call + + Returns: + The the response object from the model call + + - For Non-streaming responses, we need to transform the response to a ModelResponse object. + - For streaming responses, anthropic_messages handler calls success_handler with a assembled ModelResponse. + """ + import httpx + + if self.stream and isinstance(result, ModelResponse): + return result + elif isinstance(result, ModelResponse): + return result + + httpx_response = self.model_call_details.get("httpx_response", None) + if httpx_response and isinstance(httpx_response, httpx.Response): + result = litellm.AnthropicConfig().transform_response( + raw_response=httpx_response, + model_response=litellm.ModelResponse(), + model=self.model, + messages=[], + logging_obj=self, + optional_params={}, + api_key="", + request_data={}, + encoding=litellm.encoding, + json_mode=False, + litellm_params={}, + ) + else: + from litellm.types.llms.anthropic import AnthropicResponse + + pydantic_result = AnthropicResponse.model_validate(result) + import httpx + + result = litellm.AnthropicConfig().transform_parsed_response( + completion_response=pydantic_result.model_dump(), + raw_response=httpx.Response( + status_code=200, + headers={}, + ), + model_response=litellm.ModelResponse(), + json_mode=None, + ) + return result + + def _handle_non_streaming_google_genai_generate_content_response_logging( + self, result: Any + ) -> ModelResponse: + """ + Handles logging for Google GenAI generate content responses. + """ + import httpx + + httpx_response = self.model_call_details.get("httpx_response", None) + if httpx_response is None: + raise ValueError("Google GenAI Generate Content: httpx_response is None") + dict_result = httpx_response.json() + result = litellm.VertexGeminiConfig()._transform_google_generate_content_to_openai_model_response( + completion_response=dict_result, + model_response=litellm.ModelResponse(), + model=self.model, + logging_obj=self, + raw_response=httpx.Response( + status_code=200, + headers={}, + ), + ) + return result + + def _handle_a2a_response_logging(self, result: Any) -> Any: + """ + Handles logging for A2A (Agent-to-Agent) responses. + + Adds usage from model_call_details to the result if available. + Uses Pydantic's model_copy to avoid modifying the original response. + + Args: + result: The LiteLLMSendMessageResponse from the A2A call + + Returns: + The response object with usage added if available + """ + # Get usage from model_call_details (set by asend_message) + usage = self.model_call_details.get("usage") + if usage is None: + return result + + # Deep copy result and add usage + result_copy = result.model_copy(deep=True) + result_copy.usage = ( + usage.model_dump() if hasattr(usage, "model_dump") else dict(usage) + ) + return result_copy + + +def _get_masked_values( + sensitive_object: dict, + ignore_sensitive_values: bool = False, + mask_all_values: bool = False, + unmasked_length: int = 4, + number_of_asterisks: Optional[int] = 4, +) -> dict: + """ + Internal debugging helper function + + Masks the headers of the request sent from LiteLLM + + Args: + masked_length: Optional length for the masked portion (number of *). If set, will use exactly this many * + regardless of original string length. The total length will be unmasked_length + masked_length. + """ + sensitive_keywords = [ + "authorization", + "token", + "key", + "secret", + "vertex_credentials", + ] + return { + k: ( + # If ignore_sensitive_values is True, or if this key doesn't contain sensitive keywords, return original value + v + if ignore_sensitive_values + or not any( + sensitive_keyword in k.lower() + for sensitive_keyword in sensitive_keywords + ) + else ( + # Apply masking to sensitive keys + ( + v[: unmasked_length // 2] + + "*" * number_of_asterisks + + v[-unmasked_length // 2 :] + ) + if ( + isinstance(v, str) + and len(v) > unmasked_length + and number_of_asterisks is not None + ) + else ( + ( + v[: unmasked_length // 2] + + "*" * (len(v) - unmasked_length) + + v[-unmasked_length // 2 :] + ) + if (isinstance(v, str) and len(v) > unmasked_length) + else ("*****" if isinstance(v, str) else v) + ) + ) + ) + for k, v in sensitive_object.items() + } + + +def set_callbacks(callback_list, function_id=None): # noqa: PLR0915 + """ + Globally sets the callback client + """ + global sentry_sdk_instance, capture_exception, add_breadcrumb, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, supabaseClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger, deepevalLogger + + try: + for callback in callback_list: + if callback == "sentry": + try: + import sentry_sdk + except ImportError: + print_verbose("Package 'sentry_sdk' is missing. Installing it...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "sentry_sdk"] + ) + import sentry_sdk + from sentry_sdk.scrubber import EventScrubber + + sentry_sdk_instance = sentry_sdk + sentry_trace_rate = ( + os.environ.get("SENTRY_API_TRACE_RATE") + if "SENTRY_API_TRACE_RATE" in os.environ + else "1.0" + ) + sentry_sample_rate = ( + os.environ.get("SENTRY_API_SAMPLE_RATE") + if "SENTRY_API_SAMPLE_RATE" in os.environ + else "1.0" + ) + sentry_sdk_instance.init( + dsn=os.environ.get("SENTRY_DSN"), + traces_sample_rate=float(sentry_trace_rate), # type: ignore + sample_rate=float( + sentry_sample_rate if sentry_sample_rate else 1.0 + ), + send_default_pii=False, # Prevent sending Personal Identifiable Information + event_scrubber=EventScrubber( + denylist=SENTRY_DENYLIST, pii_denylist=SENTRY_PII_DENYLIST + ), + environment=os.environ.get("SENTRY_ENVIRONMENT", "production"), + ) + capture_exception = sentry_sdk_instance.capture_exception + add_breadcrumb = sentry_sdk_instance.add_breadcrumb + elif callback == "slack": + try: + from slack_bolt import App + except ImportError: + print_verbose("Package 'slack_bolt' is missing. Installing it...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "slack_bolt"] + ) + from slack_bolt import App + slack_app = App( + token=os.environ.get("SLACK_API_TOKEN"), + signing_secret=os.environ.get("SLACK_API_SECRET"), + ) + alerts_channel = os.environ["SLACK_API_CHANNEL"] + print_verbose(f"Initialized Slack App: {slack_app}") + elif callback == "traceloop": + traceloopLogger = TraceloopLogger() + elif callback == "athina": + athinaLogger = AthinaLogger() + print_verbose("Initialized Athina Logger") + elif callback == "helicone": + heliconeLogger = HeliconeLogger() + elif callback == "lunary": + lunaryLogger = LunaryLogger() + elif callback == "promptlayer": + promptLayerLogger = PromptLayerLogger() + elif callback == "langfuse": + langFuseLogger = LangFuseLogger( + langfuse_public_key=None, langfuse_secret=None, langfuse_host=None + ) + elif callback == "openmeter": + openMeterLogger = OpenMeterLogger() + elif callback == "datadog": + dataDogLogger = DataDogLogger() + elif callback == "dynamodb": + dynamoLogger = DyanmoDBLogger() + elif callback == "s3": + s3Logger = S3Logger() + elif callback == "wandb": + from litellm.integrations.weights_biases import WeightsBiasesLogger + + weightsBiasesLogger = WeightsBiasesLogger() + elif callback == "logfire": + logfireLogger = LogfireLogger() + elif callback == "supabase": + print_verbose("instantiating supabase") + supabaseClient = Supabase() + elif callback == "greenscale": + greenscaleLogger = GreenscaleLogger() + print_verbose("Initialized Greenscale Logger") + elif callable(callback): + customLogger = CustomLogger() + except Exception as e: + raise e + return None + + +def _init_custom_logger_compatible_class( # noqa: PLR0915 + logging_integration: _custom_logger_compatible_callbacks_literal, + internal_usage_cache: Optional[DualCache], + llm_router: Optional[ + Any + ], # expect litellm.Router, but typing errors due to circular import + custom_logger_init_args: Optional[dict] = {}, +) -> Optional[CustomLogger]: + """ + Initialize a custom logger compatible class + """ + try: + custom_logger_init_args = custom_logger_init_args or {} + if logging_integration == "agentops": # Add AgentOps initialization + for callback in _in_memory_loggers: + if isinstance(callback, AgentOps): + return callback # type: ignore + + agentops_logger = AgentOps() + _in_memory_loggers.append(agentops_logger) + return agentops_logger # type: ignore + elif logging_integration == "lago": + for callback in _in_memory_loggers: + if isinstance(callback, LagoLogger): + return callback # type: ignore + + lago_logger = LagoLogger() + _in_memory_loggers.append(lago_logger) + return lago_logger # type: ignore + elif logging_integration == "openmeter": + for callback in _in_memory_loggers: + if isinstance(callback, OpenMeterLogger): + return callback # type: ignore + + _openmeter_logger = OpenMeterLogger() + _in_memory_loggers.append(_openmeter_logger) + return _openmeter_logger # type: ignore + elif logging_integration == "posthog": + for callback in _in_memory_loggers: + if isinstance(callback, PostHogLogger): + return callback # type: ignore + + _posthog_logger = PostHogLogger() + _in_memory_loggers.append(_posthog_logger) + return _posthog_logger # type: ignore + elif logging_integration == "braintrust": + from litellm.integrations.braintrust_logging import BraintrustLogger + + for callback in _in_memory_loggers: + if isinstance(callback, BraintrustLogger): + return callback # type: ignore + + braintrust_logger = BraintrustLogger() + _in_memory_loggers.append(braintrust_logger) + return braintrust_logger # type: ignore + elif logging_integration == "langsmith": + for callback in _in_memory_loggers: + if isinstance(callback, LangsmithLogger): + return callback # type: ignore + + _langsmith_logger = LangsmithLogger() + _in_memory_loggers.append(_langsmith_logger) + return _langsmith_logger # type: ignore + elif logging_integration == "argilla": + for callback in _in_memory_loggers: + if isinstance(callback, ArgillaLogger): + return callback # type: ignore + + _argilla_logger = ArgillaLogger() + _in_memory_loggers.append(_argilla_logger) + return _argilla_logger # type: ignore + elif logging_integration == "literalai": + for callback in _in_memory_loggers: + if isinstance(callback, LiteralAILogger): + return callback # type: ignore + + _literalai_logger = LiteralAILogger() + _in_memory_loggers.append(_literalai_logger) + return _literalai_logger # type: ignore + elif logging_integration == "prometheus": + PrometheusLogger = _get_cached_prometheus_logger() + + for callback in _in_memory_loggers: + if isinstance(callback, PrometheusLogger): + return callback # type: ignore + + _prometheus_logger = PrometheusLogger() + _in_memory_loggers.append(_prometheus_logger) + return _prometheus_logger # type: ignore + elif logging_integration == "datadog": + for callback in _in_memory_loggers: + if isinstance(callback, DataDogLogger): + return callback # type: ignore + + _datadog_logger = DataDogLogger() + _in_memory_loggers.append(_datadog_logger) + return _datadog_logger # type: ignore + elif logging_integration == "datadog_llm_observability": + _datadog_llm_obs_logger = DataDogLLMObsLogger() + _in_memory_loggers.append(_datadog_llm_obs_logger) + return _datadog_llm_obs_logger # type: ignore + elif logging_integration == "azure_sentinel": + for callback in _in_memory_loggers: + if isinstance(callback, AzureSentinelLogger): + return callback # type: ignore + + _azure_sentinel_logger = AzureSentinelLogger() + _in_memory_loggers.append(_azure_sentinel_logger) + return _azure_sentinel_logger # type: ignore + elif logging_integration == "gcs_bucket": + for callback in _in_memory_loggers: + if isinstance(callback, GCSBucketLogger): + return callback # type: ignore + + _gcs_bucket_logger = GCSBucketLogger() + _in_memory_loggers.append(_gcs_bucket_logger) + return _gcs_bucket_logger # type: ignore + elif logging_integration == "s3_v2": + for callback in _in_memory_loggers: + if isinstance(callback, S3V2Logger): + return callback # type: ignore + + _s3_v2_logger = S3V2Logger() + _in_memory_loggers.append(_s3_v2_logger) + return _s3_v2_logger # type: ignore + elif logging_integration == "aws_sqs": + for callback in _in_memory_loggers: + if isinstance(callback, SQSLogger): + return callback # type: ignore + + _aws_sqs_logger = SQSLogger() + _in_memory_loggers.append(_aws_sqs_logger) + return _aws_sqs_logger # type: ignore + elif logging_integration == "azure_storage": + for callback in _in_memory_loggers: + if isinstance(callback, AzureBlobStorageLogger): + return callback # type: ignore + + _azure_storage_logger = AzureBlobStorageLogger() + _in_memory_loggers.append(_azure_storage_logger) + return _azure_storage_logger # type: ignore + elif logging_integration == "opik": + for callback in _in_memory_loggers: + if isinstance(callback, OpikLogger): + return callback # type: ignore + + _opik_logger = OpikLogger() + _in_memory_loggers.append(_opik_logger) + return _opik_logger # type: ignore + elif logging_integration == "arize": + from litellm.integrations.opentelemetry import ( + OpenTelemetry, + OpenTelemetryConfig, + ) + + arize_config = ArizeLogger.get_arize_config() + if arize_config.endpoint is None: + raise ValueError( + "No valid endpoint found for Arize, please set 'ARIZE_ENDPOINT' to your GRPC endpoint or 'ARIZE_HTTP_ENDPOINT' to your HTTP endpoint" + ) + otel_config = OpenTelemetryConfig( + exporter=arize_config.protocol, + endpoint=arize_config.endpoint, + service_name=arize_config.project_name, + ) + + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" + for callback in _in_memory_loggers: + if ( + isinstance(callback, ArizeLogger) + and callback.callback_name == "arize" + ): + return callback # type: ignore + _arize_otel_logger = ArizeLogger(config=otel_config, callback_name="arize") + _in_memory_loggers.append(_arize_otel_logger) + return _arize_otel_logger # type: ignore + elif logging_integration == "arize_phoenix": + from litellm.integrations.opentelemetry import ( + OpenTelemetry, + OpenTelemetryConfig, + ) + + arize_phoenix_config = ArizePhoenixLogger.get_arize_phoenix_config() + otel_config = OpenTelemetryConfig( + exporter=arize_phoenix_config.protocol, + endpoint=arize_phoenix_config.endpoint, + headers=arize_phoenix_config.otlp_auth_headers, + ) + if arize_phoenix_config.project_name: + existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") + # Add openinference.project.name attribute + if existing_attrs: + os.environ[ + "OTEL_RESOURCE_ATTRIBUTES" + ] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" + else: + os.environ[ + "OTEL_RESOURCE_ATTRIBUTES" + ] = f"openinference.project.name={arize_phoenix_config.project_name}" + + # Set Phoenix project name from environment variable + phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None) + if phoenix_project_name: + existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") + # Add openinference.project.name attribute + if existing_attrs: + os.environ[ + "OTEL_RESOURCE_ATTRIBUTES" + ] = f"{existing_attrs},openinference.project.name={phoenix_project_name}" + else: + os.environ[ + "OTEL_RESOURCE_ATTRIBUTES" + ] = f"openinference.project.name={phoenix_project_name}" + + # auth can be disabled on local deployments of arize phoenix + if arize_phoenix_config.otlp_auth_headers is not None: + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = arize_phoenix_config.otlp_auth_headers + + for callback in _in_memory_loggers: + if ( + isinstance(callback, ArizePhoenixLogger) + and callback.callback_name == "arize_phoenix" + ): + return callback # type: ignore + _arize_phoenix_otel_logger = ArizePhoenixLogger( + config=otel_config, callback_name="arize_phoenix" + ) + _in_memory_loggers.append(_arize_phoenix_otel_logger) + return _arize_phoenix_otel_logger # type: ignore + elif logging_integration == "levo": + from litellm.integrations.levo.levo import LevoLogger + from litellm.integrations.opentelemetry import ( + OpenTelemetry, + OpenTelemetryConfig, + ) + + levo_config = LevoLogger.get_levo_config() + otel_config = OpenTelemetryConfig( + exporter=levo_config.protocol, + endpoint=levo_config.endpoint, + headers=levo_config.otlp_auth_headers, + ) + + # Check if LevoLogger instance already exists + for callback in _in_memory_loggers: + if ( + isinstance(callback, LevoLogger) + and callback.callback_name == "levo" + ): + return callback # type: ignore + + _levo_otel_logger = LevoLogger(config=otel_config, callback_name="levo") + _in_memory_loggers.append(_levo_otel_logger) + return _levo_otel_logger # type: ignore + elif logging_integration == "otel": + from litellm.integrations.opentelemetry import OpenTelemetry + + for callback in _in_memory_loggers: + if type(callback) is OpenTelemetry: + return callback # type: ignore + otel_logger = OpenTelemetry( + **_get_custom_logger_settings_from_proxy_server( + callback_name=logging_integration + ) + ) + _in_memory_loggers.append(otel_logger) + return otel_logger # type: ignore + + elif logging_integration == "galileo": + for callback in _in_memory_loggers: + if isinstance(callback, GalileoObserve): + return callback # type: ignore + + galileo_logger = GalileoObserve() + _in_memory_loggers.append(galileo_logger) + return galileo_logger # type: ignore + elif logging_integration == "cloudzero": + from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger + + for callback in _in_memory_loggers: + if isinstance(callback, CloudZeroLogger): + return callback # type: ignore + cloudzero_logger = CloudZeroLogger() + _in_memory_loggers.append(cloudzero_logger) + return cloudzero_logger # type: ignore + elif logging_integration == "focus": + from litellm.integrations.focus.focus_logger import FocusLogger + + for callback in _in_memory_loggers: + if isinstance(callback, FocusLogger): + return callback # type: ignore + focus_logger = FocusLogger() + _in_memory_loggers.append(focus_logger) + return focus_logger # type: ignore + elif logging_integration == "deepeval": + for callback in _in_memory_loggers: + if isinstance(callback, DeepEvalLogger): + return callback # type: ignore + deepeval_logger = DeepEvalLogger() + _in_memory_loggers.append(deepeval_logger) + return deepeval_logger # type: ignore + + elif logging_integration == "logfire": + if "LOGFIRE_TOKEN" not in os.environ: + raise ValueError("LOGFIRE_TOKEN not found in environment variables") + from litellm.integrations.opentelemetry import ( + OpenTelemetry, + OpenTelemetryConfig, + ) + + logfire_base_url = os.getenv( + "LOGFIRE_BASE_URL", "https://logfire-api.pydantic.dev" + ) + otel_config = OpenTelemetryConfig( + exporter="otlp_http", + endpoint=f"{logfire_base_url.rstrip('/')}/v1/traces", + headers=f"Authorization={os.getenv('LOGFIRE_TOKEN')}", + ) + for callback in _in_memory_loggers: + if isinstance(callback, OpenTelemetry): + return callback # type: ignore + _otel_logger = OpenTelemetry(config=otel_config) + _in_memory_loggers.append(_otel_logger) + return _otel_logger # type: ignore + elif logging_integration == "dynamic_rate_limiter": + from litellm.proxy.hooks.dynamic_rate_limiter import ( + _PROXY_DynamicRateLimitHandler, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, _PROXY_DynamicRateLimitHandler): + return callback # type: ignore + + if internal_usage_cache is None: + raise Exception( + "Internal Error: Cache cannot be empty - internal_usage_cache={}".format( + internal_usage_cache + ) + ) + + dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler( + internal_usage_cache=internal_usage_cache + ) + + if llm_router is not None and isinstance(llm_router, litellm.Router): + dynamic_rate_limiter_obj.update_variables(llm_router=llm_router) + _in_memory_loggers.append(dynamic_rate_limiter_obj) + return dynamic_rate_limiter_obj # type: ignore + elif logging_integration == "dynamic_rate_limiter_v3": + from litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( + _PROXY_DynamicRateLimitHandlerV3, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): + return callback # type: ignore + + if internal_usage_cache is None: + raise Exception( + "Internal Error: Cache cannot be empty - internal_usage_cache={}".format( + internal_usage_cache + ) + ) + + dynamic_rate_limiter_obj_v3 = _PROXY_DynamicRateLimitHandlerV3( + internal_usage_cache=internal_usage_cache + ) + + if llm_router is not None and isinstance(llm_router, litellm.Router): + dynamic_rate_limiter_obj_v3.update_variables(llm_router=llm_router) + _in_memory_loggers.append(dynamic_rate_limiter_obj_v3) + return dynamic_rate_limiter_obj_v3 # type: ignore + elif logging_integration == "langtrace": + if "LANGTRACE_API_KEY" not in os.environ: + raise ValueError("LANGTRACE_API_KEY not found in environment variables") + + from litellm.integrations.opentelemetry import ( + OpenTelemetry, + OpenTelemetryConfig, + ) + + otel_config = OpenTelemetryConfig( + exporter="otlp_http", + endpoint="https://langtrace.ai/api/trace", + ) + os.environ[ + "OTEL_EXPORTER_OTLP_TRACES_HEADERS" + ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" + for callback in _in_memory_loggers: + if ( + isinstance(callback, OpenTelemetry) + and callback.callback_name == "langtrace" + ): + return callback # type: ignore + _otel_logger = OpenTelemetry(config=otel_config, callback_name="langtrace") + _in_memory_loggers.append(_otel_logger) + return _otel_logger # type: ignore + + elif logging_integration == "mlflow": + for callback in _in_memory_loggers: + if isinstance(callback, MlflowLogger): + return callback # type: ignore + + _mlflow_logger = MlflowLogger() + _in_memory_loggers.append(_mlflow_logger) + return _mlflow_logger # type: ignore + elif logging_integration == "langfuse": + for callback in _in_memory_loggers: + if isinstance(callback, LangfusePromptManagement): + return callback + + langfuse_logger = LangfusePromptManagement() + _in_memory_loggers.append(langfuse_logger) + return langfuse_logger # type: ignore + elif logging_integration == "langfuse_otel": + from litellm.integrations.langfuse.langfuse_otel import LangfuseOtelLogger + + for callback in _in_memory_loggers: + if ( + isinstance(callback, LangfuseOtelLogger) + and callback.callback_name == "langfuse_otel" + ): + return callback # type: ignore + # Allow LangfuseOtelLogger to initialize its own config safely + # This prevents startup crashes if LANGFUSE keys are not in env (e.g. for dynamic usage) + _otel_logger = LangfuseOtelLogger( + config=None, callback_name="langfuse_otel" + ) + _in_memory_loggers.append(_otel_logger) + return _otel_logger # type: ignore + elif logging_integration == "weave_otel": + from litellm.integrations.opentelemetry import OpenTelemetryConfig + from litellm.integrations.weave.weave_otel import ( + WeaveOtelLogger, + get_weave_otel_config, + ) + + weave_otel_config = get_weave_otel_config() + + otel_config = OpenTelemetryConfig( + exporter=weave_otel_config.protocol, + endpoint=weave_otel_config.endpoint, + headers=weave_otel_config.otlp_auth_headers, + ) + + for callback in _in_memory_loggers: + if ( + isinstance(callback, WeaveOtelLogger) + and callback.callback_name == "weave_otel" + ): + return callback # type: ignore + _otel_logger = WeaveOtelLogger( + config=otel_config, callback_name="weave_otel" + ) + _in_memory_loggers.append(_otel_logger) + return _otel_logger # type: ignore + elif logging_integration == "pagerduty": + for callback in _in_memory_loggers: + if isinstance(callback, PagerDutyAlerting): + return callback + pagerduty_logger = PagerDutyAlerting(**custom_logger_init_args) + _in_memory_loggers.append(pagerduty_logger) + return pagerduty_logger # type: ignore + elif logging_integration == "anthropic_cache_control_hook": + for callback in _in_memory_loggers: + if isinstance(callback, AnthropicCacheControlHook): + return callback + anthropic_cache_control_hook = AnthropicCacheControlHook() + _in_memory_loggers.append(anthropic_cache_control_hook) + return anthropic_cache_control_hook # type: ignore + elif logging_integration == "vector_store_pre_call_hook": + from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( + VectorStorePreCallHook, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, VectorStorePreCallHook): + return callback + vector_store_pre_call_hook = VectorStorePreCallHook() + _in_memory_loggers.append(vector_store_pre_call_hook) + return vector_store_pre_call_hook # type: ignore + elif logging_integration == "gcs_pubsub": + for callback in _in_memory_loggers: + if isinstance(callback, GcsPubSubLogger): + return callback + _gcs_pubsub_logger = GcsPubSubLogger() + _in_memory_loggers.append(_gcs_pubsub_logger) + return _gcs_pubsub_logger # type: ignore + elif logging_integration == "generic_api": + for callback in _in_memory_loggers: + if isinstance(callback, GenericAPILogger): + return callback + generic_api_logger = GenericAPILogger() + _in_memory_loggers.append(generic_api_logger) + return generic_api_logger # type: ignore + elif logging_integration == "resend_email": + for callback in _in_memory_loggers: + if isinstance(callback, ResendEmailLogger): + return callback + resend_email_logger = ResendEmailLogger() + _in_memory_loggers.append(resend_email_logger) + return resend_email_logger # type: ignore + elif logging_integration == "sendgrid_email": + for callback in _in_memory_loggers: + if isinstance(callback, SendGridEmailLogger): + return callback + sendgrid_email_logger = SendGridEmailLogger() + _in_memory_loggers.append(sendgrid_email_logger) + return sendgrid_email_logger # type: ignore + elif logging_integration == "smtp_email": + for callback in _in_memory_loggers: + if isinstance(callback, SMTPEmailLogger): + return callback + smtp_email_logger = SMTPEmailLogger() + _in_memory_loggers.append(smtp_email_logger) + return smtp_email_logger # type: ignore + elif logging_integration == "humanloop": + for callback in _in_memory_loggers: + if isinstance(callback, HumanloopLogger): + return callback + + humanloop_logger = HumanloopLogger() + _in_memory_loggers.append(humanloop_logger) + return humanloop_logger # type: ignore + elif logging_integration == "dotprompt": + for callback in _in_memory_loggers: + if isinstance(callback, DotpromptManager): + return callback + + dotprompt_logger = DotpromptManager() + _in_memory_loggers.append(dotprompt_logger) + return dotprompt_logger # type: ignore + elif logging_integration == "bitbucket": + from litellm.integrations.bitbucket.bitbucket_prompt_manager import ( + BitBucketPromptManager, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, BitBucketPromptManager): + return callback + + # Get global BitBucket config + bitbucket_config = getattr(litellm, "global_bitbucket_config", None) + if bitbucket_config is None: + raise ValueError( + "BitBucket configuration not found. Please set litellm.global_bitbucket_config first." + ) + + bitbucket_logger = BitBucketPromptManager(bitbucket_config=bitbucket_config) + _in_memory_loggers.append(bitbucket_logger) + return bitbucket_logger # type: ignore + elif logging_integration == "gitlab": + from litellm.integrations.gitlab.gitlab_prompt_manager import ( + GitLabPromptManager, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, GitLabPromptManager): + return callback + + # Get global BitBucket config + gitlab_config = getattr(litellm, "global_gitlab_config", None) + if gitlab_config is None: + raise ValueError( + "Gitlab configuration not found. Please set litellm.global_gitlab_config first." + ) + + gitlab_logger = GitLabPromptManager(gitlab_config=gitlab_config) + _in_memory_loggers.append(gitlab_logger) + return gitlab_logger # type: ignore + return None + except Exception as e: + verbose_logger.exception( + f"[Non-Blocking Error] Error initializing custom logger: {e}" + ) + return None + return None + + +def get_custom_logger_compatible_class( # noqa: PLR0915 + logging_integration: _custom_logger_compatible_callbacks_literal, +) -> Optional[CustomLogger]: + try: + if logging_integration == "lago": + for callback in _in_memory_loggers: + if isinstance(callback, LagoLogger): + return callback + elif logging_integration == "openmeter": + for callback in _in_memory_loggers: + if isinstance(callback, OpenMeterLogger): + return callback + elif logging_integration == "braintrust": + from litellm.integrations.braintrust_logging import BraintrustLogger + + for callback in _in_memory_loggers: + if isinstance(callback, BraintrustLogger): + return callback + elif logging_integration == "galileo": + for callback in _in_memory_loggers: + if isinstance(callback, GalileoObserve): + return callback + elif logging_integration == "cloudzero": + from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger + + for callback in _in_memory_loggers: + if isinstance(callback, CloudZeroLogger): + return callback + elif logging_integration == "focus": + from litellm.integrations.focus.focus_logger import FocusLogger + + for callback in _in_memory_loggers: + if isinstance(callback, FocusLogger): + return callback + elif logging_integration == "deepeval": + for callback in _in_memory_loggers: + if isinstance(callback, DeepEvalLogger): + return callback + elif logging_integration == "langsmith": + for callback in _in_memory_loggers: + if isinstance(callback, LangsmithLogger): + return callback + elif logging_integration == "argilla": + for callback in _in_memory_loggers: + if isinstance(callback, ArgillaLogger): + return callback + elif logging_integration == "literalai": + for callback in _in_memory_loggers: + if isinstance(callback, LiteralAILogger): + return callback + elif logging_integration == "prometheus": + PrometheusLogger = _get_cached_prometheus_logger() + for callback in _in_memory_loggers: + if isinstance(callback, PrometheusLogger): + return callback + elif logging_integration == "datadog": + for callback in _in_memory_loggers: + if isinstance(callback, DataDogLogger): + return callback + elif logging_integration == "datadog_llm_observability": + for callback in _in_memory_loggers: + if isinstance(callback, DataDogLLMObsLogger): + return callback + elif logging_integration == "azure_sentinel": + for callback in _in_memory_loggers: + if isinstance(callback, AzureSentinelLogger): + return callback + elif logging_integration == "gcs_bucket": + for callback in _in_memory_loggers: + if isinstance(callback, GCSBucketLogger): + return callback + elif logging_integration == "s3_v2": + for callback in _in_memory_loggers: + if isinstance(callback, S3V2Logger): + return callback + elif logging_integration == "aws_sqs": + for callback in _in_memory_loggers: + if isinstance(callback, SQSLogger): + return callback + _aws_sqs_logger = SQSLogger() + _in_memory_loggers.append(_aws_sqs_logger) + return _aws_sqs_logger # type: ignore + elif logging_integration == "azure_storage": + for callback in _in_memory_loggers: + if isinstance(callback, AzureBlobStorageLogger): + return callback + elif logging_integration == "opik": + for callback in _in_memory_loggers: + if isinstance(callback, OpikLogger): + return callback + elif logging_integration == "langfuse": + for callback in _in_memory_loggers: + if isinstance(callback, LangfusePromptManagement): + return callback + elif logging_integration == "otel": + from litellm.integrations.opentelemetry import OpenTelemetry + + for callback in _in_memory_loggers: + if isinstance(callback, OpenTelemetry): + return callback + elif logging_integration == "arize": + if "ARIZE_API_KEY" not in os.environ: + raise ValueError("ARIZE_API_KEY not found in environment variables") + for callback in _in_memory_loggers: + if ( + isinstance(callback, ArizeLogger) + and callback.callback_name == "arize" + ): + return callback + elif logging_integration == "logfire": + if "LOGFIRE_TOKEN" not in os.environ: + raise ValueError("LOGFIRE_TOKEN not found in environment variables") + from litellm.integrations.opentelemetry import OpenTelemetry + + for callback in _in_memory_loggers: + if isinstance(callback, OpenTelemetry): + return callback # type: ignore + + elif logging_integration == "dynamic_rate_limiter": + from litellm.proxy.hooks.dynamic_rate_limiter import ( + _PROXY_DynamicRateLimitHandler, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, _PROXY_DynamicRateLimitHandler): + return callback # type: ignore + elif logging_integration == "dynamic_rate_limiter_v3": + from litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( + _PROXY_DynamicRateLimitHandlerV3, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): + return callback # type: ignore + + elif logging_integration == "langtrace": + from litellm.integrations.opentelemetry import OpenTelemetry + + if "LANGTRACE_API_KEY" not in os.environ: + raise ValueError("LANGTRACE_API_KEY not found in environment variables") + + for callback in _in_memory_loggers: + if ( + isinstance(callback, OpenTelemetry) + and callback.callback_name == "langtrace" + ): + return callback + + elif logging_integration == "mlflow": + for callback in _in_memory_loggers: + if isinstance(callback, MlflowLogger): + return callback + elif logging_integration == "pagerduty": + for callback in _in_memory_loggers: + if isinstance(callback, PagerDutyAlerting): + return callback + elif logging_integration == "anthropic_cache_control_hook": + for callback in _in_memory_loggers: + if isinstance(callback, AnthropicCacheControlHook): + return callback + elif logging_integration == "vector_store_pre_call_hook": + from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( + VectorStorePreCallHook, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, VectorStorePreCallHook): + return callback + elif logging_integration == "gcs_pubsub": + for callback in _in_memory_loggers: + if isinstance(callback, GcsPubSubLogger): + return callback + elif logging_integration == "generic_api": + for callback in _in_memory_loggers: + if isinstance(callback, GenericAPILogger): + return callback + elif logging_integration == "resend_email": + for callback in _in_memory_loggers: + if isinstance(callback, ResendEmailLogger): + return callback + elif logging_integration == "sendgrid_email": + for callback in _in_memory_loggers: + if isinstance(callback, SendGridEmailLogger): + return callback + elif logging_integration == "smtp_email": + for callback in _in_memory_loggers: + if isinstance(callback, SMTPEmailLogger): + return callback + return None + + except Exception as e: + verbose_logger.exception( + f"[Non-Blocking Error] Error getting custom logger: {e}" + ) + return None + + +def _get_custom_logger_settings_from_proxy_server(callback_name: str) -> Dict: + """ + Get the settings for a custom logger from the proxy server config.yaml + + Proxy server config.yaml defines callback_settings as: + + callback_settings: + otel: + message_logging: False + """ + if litellm.callback_settings: + return dict(litellm.callback_settings.get(callback_name, {})) + return {} + + +def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool: + """ + Check if the model uses custom pricing + + Returns True if any of `SPECIAL_MODEL_INFO_PARAMS` are present in `litellm_params` or `model_info` + """ + if litellm_params is None: + return False + + # Check litellm_params using set intersection (only check keys that exist in both) + matching_keys = _CUSTOM_PRICING_KEYS & litellm_params.keys() + for key in matching_keys: + if litellm_params.get(key) is not None: + return True + + # Check model_info + metadata: dict = litellm_params.get("metadata", {}) or {} + model_info: dict = metadata.get("model_info", {}) or {} + + if model_info: + matching_keys = _CUSTOM_PRICING_KEYS & model_info.keys() + for key in matching_keys: + if model_info.get(key) is not None: + return True + + return False + + +def is_valid_sha256_hash(value: str) -> bool: + # Check if the value is a valid SHA-256 hash (64 hexadecimal characters) + return bool(re.fullmatch(r"[a-fA-F0-9]{64}", value)) + + +class StandardLoggingPayloadSetup: + @staticmethod + def cleanup_timestamps( + start_time: Union[dt_object, float], + end_time: Union[dt_object, float], + completion_start_time: Union[dt_object, float], + ) -> Tuple[float, float, float]: + """ + Convert datetime objects to floats + + Args: + start_time: Union[dt_object, float] + end_time: Union[dt_object, float] + completion_start_time: Union[dt_object, float] + + Returns: + Tuple[float, float, float]: A tuple containing the start time, end time, and completion start time as floats. + """ + + if isinstance(start_time, datetime.datetime): + start_time_float = start_time.timestamp() + elif isinstance(start_time, float): + start_time_float = start_time + else: + raise ValueError( + f"start_time is required, got={start_time} of type {type(start_time)}" + ) + + if isinstance(end_time, datetime.datetime): + end_time_float = end_time.timestamp() + elif isinstance(end_time, float): + end_time_float = end_time + else: + raise ValueError( + f"end_time is required, got={end_time} of type {type(end_time)}" + ) + + if isinstance(completion_start_time, datetime.datetime): + completion_start_time_float = completion_start_time.timestamp() + elif isinstance(completion_start_time, float): + completion_start_time_float = completion_start_time + else: + completion_start_time_float = end_time_float + + return start_time_float, end_time_float, completion_start_time_float + + @staticmethod + def append_system_prompt_messages( + kwargs: Optional[Dict] = None, messages: Optional[Any] = None + ): + """ + Append system prompt messages to the messages + """ + if kwargs is not None: + if kwargs.get("system") is not None and isinstance( + kwargs.get("system"), str + ): + if messages is None: + return [{"role": "system", "content": kwargs.get("system")}] + elif isinstance(messages, list): + if len(messages) == 0: + return [{"role": "system", "content": kwargs.get("system")}] + # check for duplicates + if messages[0].get("role") == "system" and messages[0].get( + "content" + ) == kwargs.get("system"): + return messages + messages = [ + {"role": "system", "content": kwargs.get("system")} + ] + messages + elif isinstance(messages, str): + messages = [ + {"role": "system", "content": kwargs.get("system")}, + {"role": "user", "content": messages}, + ] + return messages + + return messages + + @staticmethod + def merge_litellm_metadata(litellm_params: dict) -> dict: + """ + Merge both litellm_metadata and metadata from litellm_params. + + litellm_metadata contains model-related fields, metadata contains user API key fields. + We need both for complete standard logging payload. + + Args: + litellm_params: Dictionary containing metadata and litellm_metadata + + Returns: + dict: Merged metadata with user API key fields taking precedence + """ + merged_metadata: dict = {} + + # Start with metadata (user API key fields) - but skip non-serializable objects + if litellm_params.get("metadata") and isinstance( + litellm_params.get("metadata"), dict + ): + for key, value in litellm_params["metadata"].items(): + # Skip non-serializable objects like UserAPIKeyAuth + if key == "user_api_key_auth": + continue + merged_metadata[key] = value + + # Then merge litellm_metadata (model-related fields) - this will NOT overwrite existing keys + if litellm_params.get("litellm_metadata") and isinstance( + litellm_params.get("litellm_metadata"), dict + ): + for key, value in litellm_params["litellm_metadata"].items(): + if ( + key not in merged_metadata + ): # Don't overwrite existing keys from metadata + merged_metadata[key] = value + + return merged_metadata + + @staticmethod + def get_standard_logging_metadata( + metadata: Optional[Dict[str, Any]], + litellm_params: Optional[dict] = None, + prompt_integration: Optional[str] = None, + applied_guardrails: Optional[List[str]] = None, + mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None, + vector_store_request_metadata: Optional[ + List[StandardLoggingVectorStoreRequest] + ] = None, + usage_object: Optional[dict] = None, + proxy_server_request: Optional[dict] = None, + start_time: Optional[dt_object] = None, + response_id: Optional[str] = None, + ) -> StandardLoggingMetadata: + """ + Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. + + Args: + metadata (Optional[Dict[str, Any]]): The original metadata dictionary. + + Returns: + StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata. + + Note: + - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned. + - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'. + """ + + prompt_management_metadata: Optional[ + StandardLoggingPromptManagementMetadata + ] = None + if litellm_params is not None: + prompt_id = cast(Optional[str], litellm_params.get("prompt_id", None)) + prompt_variables = cast( + Optional[dict], litellm_params.get("prompt_variables", None) + ) + + if prompt_id is not None and prompt_integration is not None: + prompt_management_metadata = StandardLoggingPromptManagementMetadata( + prompt_id=prompt_id, + prompt_variables=prompt_variables, + prompt_integration=prompt_integration, + ) + + # Initialize with default values + clean_metadata = StandardLoggingMetadata( + user_api_key_hash=None, + user_api_key_alias=None, + user_api_key_spend=None, + user_api_key_max_budget=None, + user_api_key_budget_reset_at=None, + user_api_key_team_id=None, + user_api_key_org_id=None, + user_api_key_project_id=None, + user_api_key_user_id=None, + user_api_key_team_alias=None, + user_api_key_user_email=None, + user_api_key_end_user_id=None, + user_api_key_request_route=None, + spend_logs_metadata=None, + requester_ip_address=None, + user_agent=None, + requester_metadata=None, + prompt_management_metadata=prompt_management_metadata, + applied_guardrails=applied_guardrails, + mcp_tool_call_metadata=mcp_tool_call_metadata, + vector_store_request_metadata=vector_store_request_metadata, + usage_object=usage_object, + requester_custom_headers=None, + cold_storage_object_key=None, + user_api_key_auth_metadata=None, + team_alias=None, + team_id=None, + ) + if isinstance(metadata, dict): + for key in metadata.keys() & _STANDARD_LOGGING_METADATA_KEYS: + clean_metadata[key] = metadata[key] # type: ignore + + user_api_key = metadata.get("user_api_key") + if ( + user_api_key + and isinstance(user_api_key, str) + and is_valid_sha256_hash(user_api_key) + ): + clean_metadata["user_api_key_hash"] = user_api_key + _potential_requester_metadata = metadata.get( + "metadata", None + ) # check if user passed metadata in the sdk request - e.g. metadata for langsmith logging - https://docs.litellm.ai/docs/observability/langsmith_integration#set-langsmith-fields + if ( + clean_metadata["requester_metadata"] is None + and _potential_requester_metadata is not None + and isinstance(_potential_requester_metadata, dict) + ): + clean_metadata["requester_metadata"] = _potential_requester_metadata + + if ( + EnterpriseStandardLoggingPayloadSetupVAR + and proxy_server_request is not None + ): + clean_metadata = EnterpriseStandardLoggingPayloadSetupVAR.apply_enterprise_specific_metadata( + standard_logging_metadata=clean_metadata, + proxy_server_request=proxy_server_request, + ) + + # Generate cold storage object key if cold storage is configured + if start_time is not None and response_id is not None: + cold_storage_object_key = ( + StandardLoggingPayloadSetup._generate_cold_storage_object_key( + start_time=start_time, + response_id=response_id, + team_alias=clean_metadata.get("user_api_key_team_alias"), + ) + ) + if cold_storage_object_key: + clean_metadata["cold_storage_object_key"] = cold_storage_object_key + + return clean_metadata + + @staticmethod + def get_usage_from_response_obj( + response_obj: Optional[dict], combined_usage_object: Optional[Usage] = None + ) -> Usage: + ## BASE CASE ## + if combined_usage_object is not None: + return combined_usage_object + if response_obj is None: + return Usage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + + usage = response_obj.get("usage", None) or {} + if usage is None or ( + not isinstance(usage, dict) and not isinstance(usage, Usage) + ): + return Usage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + elif isinstance(usage, Usage): + return usage + elif isinstance(usage, ResponseAPIUsage): + return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( + usage + ) + elif isinstance(usage, dict): + if ResponseAPILoggingUtils._is_response_api_usage(usage): + return ( + ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( + usage + ) + ) + return Usage(**usage) + + raise ValueError(f"usage is required, got={usage} of type {type(usage)}") + + @staticmethod + def get_model_cost_information( + base_model: Optional[str], + custom_pricing: Optional[bool], + custom_llm_provider: Optional[str], + init_response_obj: Union[Any, BaseModel, dict], + ) -> StandardLoggingModelInformation: + model_cost_name = _select_model_name_for_cost_calc( + model=None, + completion_response=init_response_obj, # type: ignore + base_model=base_model, + custom_pricing=custom_pricing, + ) + if model_cost_name is None: + model_cost_information = StandardLoggingModelInformation( + model_map_key="", model_map_value=None + ) + else: + try: + _model_cost_information = litellm.get_model_info( + model=model_cost_name, custom_llm_provider=custom_llm_provider + ) + model_cost_information = StandardLoggingModelInformation( + model_map_key=model_cost_name, + model_map_value=_model_cost_information, + ) + except Exception: + verbose_logger.debug( # keep in debug otherwise it will trigger on every call + "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format( + model_cost_name + ) + ) + model_cost_information = StandardLoggingModelInformation( + model_map_key=model_cost_name, model_map_value=None + ) + return model_cost_information + + @staticmethod + def get_final_response_obj( + response_obj: dict, init_response_obj: Union[Any, BaseModel, dict], kwargs: dict + ) -> Optional[Union[dict, str, list]]: + """ + Get final response object after redacting the message input/output from logging + """ + if response_obj: + final_response_obj: Optional[Union[dict, str, list]] = response_obj + elif isinstance(init_response_obj, list) or isinstance(init_response_obj, str): + final_response_obj = init_response_obj + else: + final_response_obj = {} + + modified_final_response_obj = redact_message_input_output_from_logging( + model_call_details=kwargs, + result=final_response_obj, + ) + + if modified_final_response_obj is not None and isinstance( + modified_final_response_obj, BaseModel + ): + final_response_obj = modified_final_response_obj.model_dump() + else: + final_response_obj = modified_final_response_obj + + return final_response_obj + + @staticmethod + def get_additional_headers( + additiona_headers: Optional[dict], + ) -> Optional[StandardLoggingAdditionalHeaders]: + if additiona_headers is None: + return None + + additional_logging_headers: StandardLoggingAdditionalHeaders = {} + + for key in StandardLoggingAdditionalHeaders.__annotations__.keys(): + _key = key.lower() + _key = _key.replace("_", "-") + if _key in additiona_headers: + try: + additional_logging_headers[key] = int(additiona_headers[_key]) # type: ignore + except (ValueError, TypeError): + verbose_logger.debug( + f"Could not convert {additiona_headers[_key]} to int for key {key}." + ) + return additional_logging_headers + + @staticmethod + def get_hidden_params( + hidden_params: Optional[dict], + ) -> StandardLoggingHiddenParams: + clean_hidden_params = StandardLoggingHiddenParams( + model_id=None, + cache_key=None, + api_base=None, + response_cost=None, + additional_headers=None, + litellm_overhead_time_ms=None, + batch_models=None, + litellm_model_name=None, + usage_object=None, + ) + if hidden_params is not None: + for key in StandardLoggingHiddenParams.__annotations__.keys(): + if key in hidden_params: + if key == "additional_headers": + clean_hidden_params[ + "additional_headers" + ] = StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] + ) + else: + clean_hidden_params[key] = hidden_params[key] # type: ignore + return clean_hidden_params + + @staticmethod + def strip_trailing_slash(api_base: Optional[str]) -> Optional[str]: + if api_base: + if api_base.endswith("//"): + return api_base.rstrip("/") + if api_base[-1] == "/": + return api_base[:-1] + return api_base + + @staticmethod + def _generate_cold_storage_object_key( + start_time: dt_object, + response_id: str, + team_alias: Optional[str] = None, + ) -> Optional[str]: + """ + Generate cold storage object key in the same format as S3Logger. + + Args: + start_time: The start time of the request + response_id: The response ID + team_alias: Optional team alias for team-based prefixing + + Returns: + Optional[str]: The generated object key or None if cold storage not configured + """ + # Generate object key in same format as S3Logger + from litellm.integrations.s3 import get_s3_object_key + + # Only generate object key if cold storage is configured + cold_storage_custom_logger = litellm.cold_storage_custom_logger + if cold_storage_custom_logger is None: + return None + + try: + # Generate file name in same format as litellm.utils.get_logging_id + s3_file_name = f"time-{start_time.strftime('%H-%M-%S-%f')}_{response_id}" + + # Get the actual s3_path from the configured cold storage logger instance + s3_path = "" # default value + + # Try to get the actual logger instance from the logger name + try: + custom_logger = litellm.logging_callback_manager.get_active_custom_logger_for_callback_name( + cold_storage_custom_logger + ) + if ( + custom_logger + and hasattr(custom_logger, "s3_path") + and getattr(custom_logger, "s3_path") + ): + s3_path = getattr(custom_logger, "s3_path") + except Exception: + # If any error occurs in getting the logger instance, use default empty s3_path + pass + + s3_object_key = get_s3_object_key( + s3_path=s3_path, # Use actual s3_path from logger configuration + prefix="", # Don't split by team alias for cold storage + start_time=start_time, + s3_file_name=s3_file_name, + ) + + return s3_object_key + except Exception: + # If any error occurs in generating the key, return None + return None + + @staticmethod + def get_error_information( + original_exception: Optional[Exception], + traceback_str: Optional[str] = None, + ) -> StandardLoggingPayloadErrorInformation: + from litellm.constants import MAXIMUM_TRACEBACK_LINES_TO_LOG + + # Check for 'code' first (used by ProxyException), then fall back to 'status_code' (used by LiteLLM exceptions) + # Ensure error_code is always a string for Prisma Python JSON field compatibility + error_code_attr = getattr(original_exception, "code", None) + if error_code_attr is not None and str(error_code_attr) not in ("", "None"): + error_status: str = str(error_code_attr) + else: + status_code_attr = getattr(original_exception, "status_code", None) + error_status = str(status_code_attr) if status_code_attr is not None else "" + error_class: str = ( + str(original_exception.__class__.__name__) if original_exception else "" + ) + _llm_provider_in_exception = getattr(original_exception, "llm_provider", "") + + # Get traceback information (first 100 lines) + traceback_info = traceback_str or "" + if original_exception: + tb = getattr(original_exception, "__traceback__", None) + if tb: + tb_lines = traceback.format_tb(tb) + traceback_info += "".join( + tb_lines[:MAXIMUM_TRACEBACK_LINES_TO_LOG] + ) # Limit to first 100 lines + + # Get additional error details + error_message = str(original_exception) + + return StandardLoggingPayloadErrorInformation( + error_code=error_status, + error_class=error_class, + llm_provider=_llm_provider_in_exception, + traceback=traceback_info, + error_message=error_message if original_exception else "", + ) + + @staticmethod + def get_response_time( + start_time_float: float, + end_time_float: float, + completion_start_time_float: float, + stream: bool, + ) -> float: + """ + Get the response time for the LLM response + + Args: + start_time_float: float - start time of the LLM call + end_time_float: float - end time of the LLM call + completion_start_time_float: float - time to first token of the LLM response (for streaming responses) + stream: bool - True when a stream response is returned + + Returns: + float: The response time for the LLM response + """ + if stream is True: + return completion_start_time_float - start_time_float + else: + return end_time_float - start_time_float + + @staticmethod + def _get_standard_logging_payload_trace_id( + logging_obj: Logging, + litellm_params: dict, + ) -> str: + """ + Returns the `litellm_trace_id` for this request + + This helps link sessions when multiple requests are made in a single session + """ + dynamic_litellm_session_id = litellm_params.get("litellm_session_id") + dynamic_litellm_trace_id = litellm_params.get("litellm_trace_id") + + # Note: we recommend using `litellm_session_id` for session tracking + # `litellm_trace_id` is an internal litellm param + if dynamic_litellm_session_id: + return str(dynamic_litellm_session_id) + elif dynamic_litellm_trace_id: + return str(dynamic_litellm_trace_id) + else: + return logging_obj.litellm_trace_id + + @staticmethod + def _get_user_agent_tags(proxy_server_request: dict) -> Optional[List[str]]: + """ + Return the user agent tags from the proxy server request for spend tracking + """ + if litellm.disable_add_user_agent_to_request_tags is True: + return None + user_agent_tags: Optional[List[str]] = None + headers = proxy_server_request.get("headers", {}) + if headers is not None and isinstance(headers, dict): + if "user-agent" in headers: + user_agent = headers["user-agent"] + if user_agent is not None: + if user_agent_tags is None: + user_agent_tags = [] + user_agent_part: Optional[str] = None + if "/" in user_agent: + user_agent_part = user_agent.split("/")[0] + if user_agent_part is not None: + user_agent_tags.append("User-Agent: " + user_agent_part) + if user_agent is not None: + user_agent_tags.append("User-Agent: " + user_agent) + return user_agent_tags + + @staticmethod + def _get_extra_header_tags(proxy_server_request: dict) -> Optional[List[str]]: + """ + Extract additional header tags for spend tracking based on config. + """ + extra_headers: List[str] = ( + getattr(litellm, "extra_spend_tag_headers", None) or [] + ) + if not extra_headers: + return None + + headers = proxy_server_request.get("headers", {}) + if not isinstance(headers, dict): + return None + + header_tags = [] + for header_name in extra_headers: + header_value = headers.get(header_name) + if header_value: + header_tags.append(f"{header_name}: {header_value}") + + return header_tags if header_tags else None + + @staticmethod + def _get_request_tags( + litellm_params: dict, proxy_server_request: dict + ) -> List[str]: + # check for 'tags' in both 'metadata' and 'litellm_metadata' + metadata = litellm_params.get("metadata") or {} + litellm_metadata = litellm_params.get("litellm_metadata") or {} + if metadata.get("tags", []): + request_tags = metadata.get("tags", []).copy() + elif litellm_metadata.get("tags", []): + request_tags = litellm_metadata.get("tags", []).copy() + else: + request_tags = [] + user_agent_tags = StandardLoggingPayloadSetup._get_user_agent_tags( + proxy_server_request + ) + additional_header_tags = StandardLoggingPayloadSetup._get_extra_header_tags( + proxy_server_request + ) + if user_agent_tags is not None: + request_tags.extend(user_agent_tags) + if additional_header_tags is not None: + request_tags.extend(additional_header_tags) + return request_tags + + +def _get_status_fields( + status: StandardLoggingPayloadStatus, + guardrail_information: Optional[List[dict]], + error_str: Optional[str], +) -> "StandardLoggingPayloadStatusFields": + """ + Determine status fields based on request status and guardrail information. + + Args: + status: Overall request status ("success" or "failure") + guardrail_information: Guardrail information from metadata + error_str: Error string if any + + Returns: + StandardLoggingPayloadStatusFields with llm_api_status and guardrail_status + """ + # Mapping for legacy guardrail status values to new GuardrailStatus values + GUARDRAIL_STATUS_MAP: Dict[str, GuardrailStatus] = { + "success": "success", + "blocked": "guardrail_intervened", # legacy + "guardrail_intervened": "guardrail_intervened", # direct + "failure": "guardrail_failed_to_respond", # legacy + "guardrail_failed_to_respond": "guardrail_failed_to_respond", # direct + "not_run": "not_run", + } + + # Set LLM API status + llm_api_status: StandardLoggingPayloadStatus = status + + ######################################################### + # Map - guardrail_information.guardrail_status to guardrail_status + ######################################################### + guardrail_status: GuardrailStatus = "not_run" + if guardrail_information and isinstance(guardrail_information, list): + for information in guardrail_information: + if isinstance(information, dict): + raw_status = information.get("guardrail_status", "not_run") + if raw_status != "not_run": + guardrail_status = GUARDRAIL_STATUS_MAP.get(raw_status, "not_run") + break + + return StandardLoggingPayloadStatusFields( + llm_api_status=llm_api_status, guardrail_status=guardrail_status + ) + + +def _extract_response_obj_and_hidden_params( + init_response_obj: Union[Any, BaseModel, dict], + original_exception: Optional[Exception], +) -> Tuple[dict, Optional[dict]]: + """Extract response_obj and hidden_params from init_response_obj.""" + hidden_params: Optional[dict] = None + if init_response_obj is None: + response_obj = {} + elif isinstance(init_response_obj, BaseModel): + response_obj = init_response_obj.model_dump() + hidden_params = getattr(init_response_obj, "_hidden_params", None) + elif isinstance(init_response_obj, dict): + response_obj = init_response_obj + else: + response_obj = {} + + if original_exception is not None and hidden_params is None: + response_headers = _get_response_headers(original_exception) + if response_headers is not None: + hidden_params = dict( + StandardLoggingHiddenParams( + additional_headers=StandardLoggingPayloadSetup.get_additional_headers( + dict(response_headers) + ), + model_id=None, + cache_key=None, + api_base=None, + response_cost=None, + litellm_overhead_time_ms=None, + batch_models=None, + litellm_model_name=None, + usage_object=None, + ) + ) + + return response_obj, hidden_params + + +def get_standard_logging_object_payload( + kwargs: Optional[dict], + init_response_obj: Union[Any, BaseModel, dict], + start_time: dt_object, + end_time: dt_object, + logging_obj: Logging, + status: StandardLoggingPayloadStatus, + error_str: Optional[str] = None, + original_exception: Optional[Exception] = None, + standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None, +) -> Optional[StandardLoggingPayload]: + try: + kwargs = kwargs or {} + + response_obj, hidden_params = _extract_response_obj_and_hidden_params( + init_response_obj, original_exception + ) + + # standardize this function to be used across, s3, dynamoDB, langfuse logging + litellm_params = kwargs.get("litellm_params", {}) or {} + proxy_server_request = litellm_params.get("proxy_server_request") or {} + + # Merge both litellm_metadata and metadata to get complete metadata + metadata: dict = StandardLoggingPayloadSetup.merge_litellm_metadata( + litellm_params + ) + + completion_start_time = kwargs.get("completion_start_time", end_time) + call_type = kwargs.get("call_type") + cache_hit = kwargs.get("cache_hit", False) + usage = StandardLoggingPayloadSetup.get_usage_from_response_obj( + response_obj=response_obj, + combined_usage_object=cast( + Optional[Usage], kwargs.get("combined_usage_object") + ), + ) + + id = response_obj.get("id", kwargs.get("litellm_call_id")) + + _model_id = metadata.get("model_info", {}).get("id", "") + _model_group = metadata.get("model_group", "") + + request_tags = StandardLoggingPayloadSetup._get_request_tags( + litellm_params=litellm_params, proxy_server_request=proxy_server_request + ) + + # cleanup timestamps + ( + start_time_float, + end_time_float, + completion_start_time_float, + ) = StandardLoggingPayloadSetup.cleanup_timestamps( + start_time=start_time, + end_time=end_time, + completion_start_time=completion_start_time, + ) + response_time = StandardLoggingPayloadSetup.get_response_time( + start_time_float=start_time_float, + end_time_float=end_time_float, + completion_start_time_float=completion_start_time_float, + stream=kwargs.get("stream", False), + ) + # clean up litellm hidden params + clean_hidden_params = StandardLoggingPayloadSetup.get_hidden_params( + hidden_params + ) + + # clean up litellm metadata + clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata( + metadata=metadata, + litellm_params=litellm_params, + prompt_integration=kwargs.get("prompt_integration", None), + applied_guardrails=kwargs.get("applied_guardrails", None), + mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None), + vector_store_request_metadata=kwargs.get( + "vector_store_request_metadata", None + ), + usage_object=usage.model_dump(), + proxy_server_request=proxy_server_request, + start_time=start_time, + response_id=id, + ) + _request_body = proxy_server_request.get("body", {}) + end_user_id = clean_metadata["user_api_key_end_user_id"] or _request_body.get( + "user", None + ) # maintain backwards compatibility with old request body check + + saved_cache_cost: float = 0.0 + if cache_hit is True: + id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id + saved_cache_cost = ( + logging_obj._response_cost_calculator( + result=init_response_obj, cache_hit=False # type: ignore + ) + or 0.0 + ) + + ## Get model cost information ## + base_model = _get_base_model_from_metadata(model_call_details=kwargs) + custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params) + + model_cost_information = StandardLoggingPayloadSetup.get_model_cost_information( + base_model=base_model, + custom_pricing=custom_pricing, + custom_llm_provider=kwargs.get("custom_llm_provider"), + init_response_obj=init_response_obj, + ) + response_cost: float = kwargs.get("response_cost", 0) or 0.0 + + error_information = StandardLoggingPayloadSetup.get_error_information( + original_exception=original_exception, + ) + + ## get final response object ## + final_response_obj = StandardLoggingPayloadSetup.get_final_response_obj( + response_obj=response_obj, + init_response_obj=init_response_obj, + kwargs=kwargs, + ) + + stream: Optional[bool] = None + if ( + kwargs.get("complete_streaming_response") is not None + or kwargs.get("async_complete_streaming_response") is not None + ) and kwargs.get("stream") is True: + stream = True + + # Reconstruct full model name with provider prefix for logging + # This ensures Bedrock models like "us.anthropic.claude-3-5-sonnet-20240620-v1:0" + # are logged as "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0" + custom_llm_provider = cast(Optional[str], kwargs.get("custom_llm_provider")) + model_name = reconstruct_model_name( + kwargs.get("model", "") or "", custom_llm_provider, metadata + ) + + payload: StandardLoggingPayload = StandardLoggingPayload( + id=str(id), + trace_id=StandardLoggingPayloadSetup._get_standard_logging_payload_trace_id( + logging_obj=logging_obj, + litellm_params=litellm_params, + ), + call_type=call_type or "", + cache_hit=cache_hit, + stream=stream, + status=status, + status_fields=_get_status_fields( + status=status, + guardrail_information=metadata.get( + "standard_logging_guardrail_information", None + ), + error_str=error_str, + ), + custom_llm_provider=custom_llm_provider, + saved_cache_cost=saved_cache_cost, + startTime=start_time_float, + endTime=end_time_float, + completionStartTime=completion_start_time_float, + response_time=response_time, + model=model_name, + metadata=clean_metadata, + cache_key=clean_hidden_params["cache_key"], + response_cost=response_cost, + cost_breakdown=logging_obj.cost_breakdown, + total_tokens=usage.total_tokens, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + request_tags=request_tags, + end_user=end_user_id or "", + api_base=StandardLoggingPayloadSetup.strip_trailing_slash( + litellm_params.get("api_base", "") + ) + or "", + model_group=_model_group, + model_id=_model_id, + requester_ip_address=clean_metadata.get("requester_ip_address", None), + user_agent=clean_metadata.get("user_agent", None), + messages=StandardLoggingPayloadSetup.append_system_prompt_messages( + kwargs=kwargs, messages=kwargs.get("messages") + ), + response=final_response_obj, + model_parameters=ModelParamHelper.get_standard_logging_model_parameters( + kwargs.get("optional_params", None) or {} + ), + hidden_params=clean_hidden_params, + model_map_information=model_cost_information, + error_str=error_str, + error_information=error_information, + response_cost_failure_debug_info=kwargs.get( + "response_cost_failure_debug_information" + ), + guardrail_information=metadata.get( + "standard_logging_guardrail_information", None + ), + standard_built_in_tools_params=standard_built_in_tools_params, + ) + + # emit_standard_logging_payload(payload) - Moved to success_handler to prevent double emitting + + return payload + except Exception as e: + verbose_logger.exception( + "Error creating standard logging object - {}".format(str(e)) + ) + return None + + +def emit_standard_logging_payload(payload: StandardLoggingPayload): + if os.getenv("LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD"): + print(json.dumps(payload, indent=4)) # noqa + + +def get_standard_logging_metadata( + metadata: Optional[Dict[str, Any]], +) -> StandardLoggingMetadata: + """ + Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. + + Args: + metadata (Optional[Dict[str, Any]]): The original metadata dictionary. + + Returns: + StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata. + + Note: + - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned. + - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'. + """ + # Initialize with default values + clean_metadata = StandardLoggingMetadata( + user_api_key_hash=None, + user_api_key_alias=None, + user_api_key_spend=None, + user_api_key_max_budget=None, + user_api_key_budget_reset_at=None, + user_api_key_team_id=None, + user_api_key_org_id=None, + user_api_key_project_id=None, + user_api_key_user_id=None, + user_api_key_user_email=None, + user_api_key_team_alias=None, + spend_logs_metadata=None, + requester_ip_address=None, + user_agent=None, + requester_metadata=None, + user_api_key_end_user_id=None, + prompt_management_metadata=None, + applied_guardrails=None, + mcp_tool_call_metadata=None, + vector_store_request_metadata=None, + usage_object=None, + requester_custom_headers=None, + user_api_key_request_route=None, + cold_storage_object_key=None, + user_api_key_auth_metadata=None, + team_alias=None, + team_id=None, + ) + if isinstance(metadata, dict): + # Update the clean_metadata with values from input metadata that match StandardLoggingMetadata fields + for key in StandardLoggingMetadata.__annotations__.keys(): + if key in metadata: + clean_metadata[key] = metadata[key] # type: ignore + + if metadata.get("user_api_key") is not None: + if is_valid_sha256_hash(str(metadata.get("user_api_key"))): + clean_metadata["user_api_key_hash"] = metadata.get( + "user_api_key" + ) # this is the hash + return clean_metadata + + +def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): + if litellm_params is None: + litellm_params = {} + + metadata = litellm_params.get("metadata", {}) or {} + + ## Extract provider-specific callable values (like langfuse_masking_function) + ## Store them separately so only the intended logger can access them + ## This prevents callables from leaking to other logging integrations + if "langfuse_masking_function" in metadata: + masking_fn = metadata.pop("langfuse_masking_function", None) + if callable(masking_fn): + litellm_params["_langfuse_masking_function"] = masking_fn + litellm_params["metadata"] = metadata + + ## check user_api_key_metadata for sensitive logging keys + cleaned_user_api_key_metadata = {} + if "user_api_key_metadata" in metadata and isinstance( + metadata["user_api_key_metadata"], dict + ): + for k, v in metadata["user_api_key_metadata"].items(): + if k == "logging": # prevent logging user logging keys + cleaned_user_api_key_metadata[ + k + ] = "scrubbed_by_litellm_for_sensitive_keys" + else: + cleaned_user_api_key_metadata[k] = v + + metadata["user_api_key_metadata"] = cleaned_user_api_key_metadata + litellm_params["metadata"] = metadata + + return litellm_params + + +# integration helper function +def modify_integration(integration_name, integration_params): + global supabaseClient + if integration_name == "supabase": + if "table_name" in integration_params: + Supabase.supabase_table_name = integration_params["table_name"] + + +@lru_cache(maxsize=16) +def _get_traceback_str_for_error(error_str: str) -> str: + """ + function wrapped with lru_cache to limit the number of times `traceback.format_exc()` is called + """ + return traceback.format_exc() + + +from decimal import Decimal + +# used for unit testing +from typing import Any, Dict, List, Optional, Union + + +def create_dummy_standard_logging_payload() -> StandardLoggingPayload: + # First create the nested objects with proper typing + model_info = StandardLoggingModelInformation( + model_map_key="gpt-3.5-turbo", model_map_value=None + ) + + metadata = StandardLoggingMetadata( # type: ignore + user_api_key_hash=str("test_hash"), + user_api_key_alias=str("test_alias"), + user_api_key_team_id=str("test_team"), + user_api_key_user_id=str("test_user"), + user_api_key_team_alias=str("test_team_alias"), + user_api_key_org_id=None, + spend_logs_metadata=None, + requester_ip_address=str("127.0.0.1"), + requester_metadata=None, + user_api_key_end_user_id=str("test_end_user"), + ) + + hidden_params = StandardLoggingHiddenParams( + model_id=None, + cache_key=None, + api_base=None, + response_cost=None, + additional_headers=None, + litellm_overhead_time_ms=None, + batch_models=None, + litellm_model_name=None, + usage_object=None, + ) + + # Convert numeric values to appropriate types + response_cost = Decimal("0.1") + start_time = Decimal("1234567890.0") + end_time = Decimal("1234567891.0") + completion_start_time = Decimal("1234567890.5") + saved_cache_cost = Decimal("0.0") + + # Create messages and response with proper typing + messages: List[Dict[str, str]] = [{"role": "user", "content": "Hello, world!"}] + response: Dict[str, List[Dict[str, Dict[str, str]]]] = { + "choices": [{"message": {"content": "Hi there!"}}] + } + + # Main payload initialization + return StandardLoggingPayload( # type: ignore + id=str("test_id"), + call_type=str("completion"), + stream=bool(False), + response_cost=response_cost, + response_cost_failure_debug_info=None, + status=str("success"), + total_tokens=int( + DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT + + DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT + ), + prompt_tokens=int(DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT), + completion_tokens=int(DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT), + startTime=start_time, + endTime=end_time, + completionStartTime=completion_start_time, + model_map_information=model_info, + model=str("gpt-3.5-turbo"), + model_id=str("model-123"), + model_group=str("openai-gpt"), + custom_llm_provider=str("openai"), + api_base=str("https://api.openai.com"), + metadata=metadata, + cache_hit=bool(False), + cache_key=None, + saved_cache_cost=saved_cache_cost, + request_tags=[], + end_user=None, + requester_ip_address=str("127.0.0.1"), + messages=messages, + response=response, + error_str=None, + model_parameters={"stream": True}, + hidden_params=hidden_params, + ) diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index d903ce0d9d..815d64f22a 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -1,287 +1,295 @@ -import asyncio -import traceback -from datetime import datetime -from typing import Any, List, Optional, Union, cast - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.integrations.custom_logger import CustomLogger -from litellm.litellm_core_utils.core_helpers import ( - _get_parent_otel_span_from_kwargs, - get_litellm_metadata_from_kwargs, -) -from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.auth.auth_checks import log_db_metrics -from litellm.proxy.auth.route_checks import RouteChecks -from litellm.proxy.utils import ProxyUpdateSpend -from litellm.types.utils import ( - StandardLoggingPayload, - StandardLoggingUserAPIKeyMetadata, -) -from litellm.utils import get_end_user_id_for_cost_tracking - - -class _ProxyDBLogger(CustomLogger): - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - await self._PROXY_track_cost_callback( - kwargs, response_obj, start_time, end_time - ) - - async def async_post_call_failure_hook( - self, - request_data: dict, - original_exception: Exception, - user_api_key_dict: UserAPIKeyAuth, - traceback_str: Optional[str] = None, - ): - request_route = user_api_key_dict.request_route - if _ProxyDBLogger._should_track_errors_in_db() is False: - return - elif request_route is not None and not RouteChecks.is_llm_api_route( - route=request_route - ): - return - - from litellm.proxy.proxy_server import proxy_logging_obj - - _metadata = dict( - StandardLoggingUserAPIKeyMetadata( - user_api_key_hash=user_api_key_dict.api_key, - user_api_key_alias=user_api_key_dict.key_alias, - user_api_key_spend=user_api_key_dict.spend, - user_api_key_max_budget=user_api_key_dict.max_budget, - user_api_key_budget_reset_at=( - user_api_key_dict.budget_reset_at.isoformat() - if user_api_key_dict.budget_reset_at - else None - ), - user_api_key_user_email=user_api_key_dict.user_email, - user_api_key_user_id=user_api_key_dict.user_id, - user_api_key_team_id=user_api_key_dict.team_id, - user_api_key_org_id=user_api_key_dict.org_id, - user_api_key_team_alias=user_api_key_dict.team_alias, - user_api_key_end_user_id=user_api_key_dict.end_user_id, - user_api_key_request_route=user_api_key_dict.request_route, - user_api_key_auth_metadata=user_api_key_dict.metadata, - ) - ) - _metadata["user_api_key"] = user_api_key_dict.api_key - _metadata["status"] = "failure" - _metadata["error_information"] = ( - StandardLoggingPayloadSetup.get_error_information( - original_exception=original_exception, - traceback_str=traceback_str, - ) - ) - - existing_metadata: dict = request_data.get("metadata", None) or {} - existing_metadata.update(_metadata) - - if "litellm_params" not in request_data: - request_data["litellm_params"] = {} - - existing_litellm_params = request_data.get("litellm_params", {}) - existing_litellm_metadata = existing_litellm_params.get("metadata", {}) or {} - - # Preserve tags from existing metadata - if existing_litellm_metadata.get("tags"): - existing_metadata["tags"] = existing_litellm_metadata.get("tags") - - request_data["litellm_params"]["proxy_server_request"] = ( - request_data.get("proxy_server_request") or existing_litellm_params.get("proxy_server_request") or {} - ) - request_data["litellm_params"]["metadata"] = existing_metadata - - # Preserve model name and custom_llm_provider - if "model" not in request_data: - request_data["model"] = existing_litellm_params.get("model") or request_data.get("model", "") - if "custom_llm_provider" not in request_data: - request_data["custom_llm_provider"] = existing_litellm_params.get("custom_llm_provider") or request_data.get("custom_llm_provider", "") - - await proxy_logging_obj.db_spend_update_writer.update_database( - token=user_api_key_dict.api_key, - response_cost=0.0, - user_id=user_api_key_dict.user_id, - end_user_id=user_api_key_dict.end_user_id, - team_id=user_api_key_dict.team_id, - kwargs=request_data, - completion_response=original_exception, - start_time=datetime.now(), - end_time=datetime.now(), - org_id=user_api_key_dict.org_id, - ) - - @log_db_metrics - async def _PROXY_track_cost_callback( - self, - kwargs, # kwargs to completion - completion_response: Optional[ - Union[litellm.ModelResponse, Any] - ], # response from completion - start_time=None, - end_time=None, # start/end time for completion - ): - from litellm.proxy.proxy_server import proxy_logging_obj, update_cache - - verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback") - try: - verbose_proxy_logger.debug( - f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" - ) - parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) - litellm_params = kwargs.get("litellm_params", {}) or {} - end_user_id = get_end_user_id_for_cost_tracking(litellm_params) - metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) - user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None)) - team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None)) - org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None)) - key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None)) - end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) - sl_object: Optional[StandardLoggingPayload] = kwargs.get( - "standard_logging_object", None - ) - response_cost = ( - sl_object.get("response_cost", None) - if sl_object is not None - else kwargs.get("response_cost", None) - ) - tags: Optional[List[str]] = ( - sl_object.get("request_tags", None) if sl_object is not None else None - ) - - if response_cost is not None: - user_api_key = metadata.get("user_api_key", None) - if kwargs.get("cache_hit", False) is True: - response_cost = 0.0 - verbose_proxy_logger.debug( - f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" - ) - - verbose_proxy_logger.debug( - f"user_api_key {user_api_key}, user_id {user_id}, team_id {team_id}, end_user_id {end_user_id}" - ) - if _should_track_cost_callback( - user_api_key=user_api_key, - user_id=user_id, - team_id=team_id, - end_user_id=end_user_id, - ): - ## UPDATE DATABASE - await proxy_logging_obj.db_spend_update_writer.update_database( - token=user_api_key, - response_cost=response_cost, - user_id=user_id, - end_user_id=end_user_id, - team_id=team_id, - kwargs=kwargs, - completion_response=completion_response, - start_time=start_time, - end_time=end_time, - org_id=org_id, - ) - - # update cache - asyncio.create_task( - update_cache( - token=user_api_key, - user_id=user_id, - end_user_id=end_user_id, - response_cost=response_cost, - team_id=team_id, - parent_otel_span=parent_otel_span, - tags=tags, - ) - ) - - await proxy_logging_obj.slack_alerting_instance.customer_spend_alert( - token=user_api_key, - key_alias=key_alias, - end_user_id=end_user_id, - response_cost=response_cost, - max_budget=end_user_max_budget, - ) - else: - # Non-model call types (health checks, afile_delete) have no model or standard_logging_object. - # Use .get() for "stream" to avoid KeyError on health checks. - if sl_object is None and not kwargs.get("model"): - verbose_proxy_logger.warning( - "Cost tracking - skipping, no standard_logging_object and no model for call_type=%s", - kwargs.get("call_type", "unknown"), - ) - return - if kwargs.get("stream") is not True or ( - kwargs.get("stream") is True and "complete_streaming_response" in kwargs - ): - if sl_object is not None: - cost_tracking_failure_debug_info: Union[dict, str] = ( - sl_object["response_cost_failure_debug_info"] # type: ignore - or "response_cost_failure_debug_info is None in standard_logging_object" - ) - else: - cost_tracking_failure_debug_info = ( - "standard_logging_object not found" - ) - model = kwargs.get("model") - raise Exception( - f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" - ) - except Exception as e: - error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}" - model = kwargs.get("model", "") - metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) - litellm_metadata = kwargs.get("litellm_params", {}).get( - "litellm_metadata", {} - ) - old_metadata = kwargs.get("litellm_params", {}).get("metadata", {}) - call_type = kwargs.get("call_type", "") - error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n" - asyncio.create_task( - proxy_logging_obj.failed_tracking_alert( - error_message=error_msg, - failing_model=model, - ) - ) - - verbose_proxy_logger.exception( - "Error in tracking cost callback - %s", str(e) - ) - - @staticmethod - def _should_track_errors_in_db(): - """ - Returns True if errors should be tracked in the database - - By default, errors are tracked in the database - - If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings - """ - from litellm.proxy.proxy_server import general_settings - - if general_settings.get("disable_error_logs") is True: - return False - return - - -def _should_track_cost_callback( - user_api_key: Optional[str], - user_id: Optional[str], - team_id: Optional[str], - end_user_id: Optional[str], -) -> bool: - """ - Determine if the cost callback should be tracked based on the kwargs - """ - - # don't run track cost callback if user opted into disabling spend - if ProxyUpdateSpend.disable_spend_updates() is True: - return False - - if ( - user_api_key is not None - or user_id is not None - or team_id is not None - or end_user_id is not None - ): - return True - return False +import asyncio +import traceback +from datetime import datetime +from typing import Any, List, Optional, Union, cast + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import ( + _get_parent_otel_span_from_kwargs, + get_litellm_metadata_from_kwargs, +) +from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.auth_checks import log_db_metrics +from litellm.proxy.auth.route_checks import RouteChecks +from litellm.proxy.utils import ProxyUpdateSpend +from litellm.types.utils import ( + StandardLoggingPayload, + StandardLoggingUserAPIKeyMetadata, +) +from litellm.utils import get_end_user_id_for_cost_tracking + + +class _ProxyDBLogger(CustomLogger): + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + await self._PROXY_track_cost_callback( + kwargs, response_obj, start_time, end_time + ) + + async def async_post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + traceback_str: Optional[str] = None, + ): + request_route = user_api_key_dict.request_route + if _ProxyDBLogger._should_track_errors_in_db() is False: + return + elif request_route is not None and not RouteChecks.is_llm_api_route( + route=request_route + ): + return + + from litellm.proxy.proxy_server import proxy_logging_obj + + _metadata = dict( + StandardLoggingUserAPIKeyMetadata( + user_api_key_hash=user_api_key_dict.api_key, + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_spend=user_api_key_dict.spend, + user_api_key_max_budget=user_api_key_dict.max_budget, + user_api_key_budget_reset_at=( + user_api_key_dict.budget_reset_at.isoformat() + if user_api_key_dict.budget_reset_at + else None + ), + user_api_key_user_email=user_api_key_dict.user_email, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_project_id=user_api_key_dict.project_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + user_api_key_request_route=user_api_key_dict.request_route, + user_api_key_auth_metadata=user_api_key_dict.metadata, + ) + ) + _metadata["user_api_key"] = user_api_key_dict.api_key + _metadata["status"] = "failure" + _metadata[ + "error_information" + ] = StandardLoggingPayloadSetup.get_error_information( + original_exception=original_exception, + traceback_str=traceback_str, + ) + + existing_metadata: dict = request_data.get("metadata", None) or {} + existing_metadata.update(_metadata) + + if "litellm_params" not in request_data: + request_data["litellm_params"] = {} + + existing_litellm_params = request_data.get("litellm_params", {}) + existing_litellm_metadata = existing_litellm_params.get("metadata", {}) or {} + + # Preserve tags from existing metadata + if existing_litellm_metadata.get("tags"): + existing_metadata["tags"] = existing_litellm_metadata.get("tags") + + request_data["litellm_params"]["proxy_server_request"] = ( + request_data.get("proxy_server_request") + or existing_litellm_params.get("proxy_server_request") + or {} + ) + request_data["litellm_params"]["metadata"] = existing_metadata + + # Preserve model name and custom_llm_provider + if "model" not in request_data: + request_data["model"] = existing_litellm_params.get( + "model" + ) or request_data.get("model", "") + if "custom_llm_provider" not in request_data: + request_data["custom_llm_provider"] = existing_litellm_params.get( + "custom_llm_provider" + ) or request_data.get("custom_llm_provider", "") + + await proxy_logging_obj.db_spend_update_writer.update_database( + token=user_api_key_dict.api_key, + response_cost=0.0, + user_id=user_api_key_dict.user_id, + end_user_id=user_api_key_dict.end_user_id, + team_id=user_api_key_dict.team_id, + kwargs=request_data, + completion_response=original_exception, + start_time=datetime.now(), + end_time=datetime.now(), + org_id=user_api_key_dict.org_id, + ) + + @log_db_metrics + async def _PROXY_track_cost_callback( + self, + kwargs, # kwargs to completion + completion_response: Optional[ + Union[litellm.ModelResponse, Any] + ], # response from completion + start_time=None, + end_time=None, # start/end time for completion + ): + from litellm.proxy.proxy_server import proxy_logging_obj, update_cache + + verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback") + try: + verbose_proxy_logger.debug( + f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) + litellm_params = kwargs.get("litellm_params", {}) or {} + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) + metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) + user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None)) + team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None)) + org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None)) + key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None)) + end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) + sl_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + response_cost = ( + sl_object.get("response_cost", None) + if sl_object is not None + else kwargs.get("response_cost", None) + ) + tags: Optional[List[str]] = ( + sl_object.get("request_tags", None) if sl_object is not None else None + ) + + if response_cost is not None: + user_api_key = metadata.get("user_api_key", None) + if kwargs.get("cache_hit", False) is True: + response_cost = 0.0 + verbose_proxy_logger.debug( + f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" + ) + + verbose_proxy_logger.debug( + f"user_api_key {user_api_key}, user_id {user_id}, team_id {team_id}, end_user_id {end_user_id}" + ) + if _should_track_cost_callback( + user_api_key=user_api_key, + user_id=user_id, + team_id=team_id, + end_user_id=end_user_id, + ): + ## UPDATE DATABASE + await proxy_logging_obj.db_spend_update_writer.update_database( + token=user_api_key, + response_cost=response_cost, + user_id=user_id, + end_user_id=end_user_id, + team_id=team_id, + kwargs=kwargs, + completion_response=completion_response, + start_time=start_time, + end_time=end_time, + org_id=org_id, + ) + + # update cache + asyncio.create_task( + update_cache( + token=user_api_key, + user_id=user_id, + end_user_id=end_user_id, + response_cost=response_cost, + team_id=team_id, + parent_otel_span=parent_otel_span, + tags=tags, + ) + ) + + await proxy_logging_obj.slack_alerting_instance.customer_spend_alert( + token=user_api_key, + key_alias=key_alias, + end_user_id=end_user_id, + response_cost=response_cost, + max_budget=end_user_max_budget, + ) + else: + # Non-model call types (health checks, afile_delete) have no model or standard_logging_object. + # Use .get() for "stream" to avoid KeyError on health checks. + if sl_object is None and not kwargs.get("model"): + verbose_proxy_logger.warning( + "Cost tracking - skipping, no standard_logging_object and no model for call_type=%s", + kwargs.get("call_type", "unknown"), + ) + return + if kwargs.get("stream") is not True or ( + kwargs.get("stream") is True + and "complete_streaming_response" in kwargs + ): + if sl_object is not None: + cost_tracking_failure_debug_info: Union[dict, str] = ( + sl_object["response_cost_failure_debug_info"] # type: ignore + or "response_cost_failure_debug_info is None in standard_logging_object" + ) + else: + cost_tracking_failure_debug_info = ( + "standard_logging_object not found" + ) + model = kwargs.get("model") + raise Exception( + f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" + ) + except Exception as e: + error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}" + model = kwargs.get("model", "") + metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) + litellm_metadata = kwargs.get("litellm_params", {}).get( + "litellm_metadata", {} + ) + old_metadata = kwargs.get("litellm_params", {}).get("metadata", {}) + call_type = kwargs.get("call_type", "") + error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n" + asyncio.create_task( + proxy_logging_obj.failed_tracking_alert( + error_message=error_msg, + failing_model=model, + ) + ) + + verbose_proxy_logger.exception( + "Error in tracking cost callback - %s", str(e) + ) + + @staticmethod + def _should_track_errors_in_db(): + """ + Returns True if errors should be tracked in the database + + By default, errors are tracked in the database + + If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings + """ + from litellm.proxy.proxy_server import general_settings + + if general_settings.get("disable_error_logs") is True: + return False + return + + +def _should_track_cost_callback( + user_api_key: Optional[str], + user_id: Optional[str], + team_id: Optional[str], + end_user_id: Optional[str], +) -> bool: + """ + Determine if the cost callback should be tracked based on the kwargs + """ + + # don't run track cost callback if user opted into disabling spend + if ProxyUpdateSpend.disable_spend_updates() is True: + return False + + if ( + user_api_key is not None + or user_id is not None + or team_id is not None + or end_user_id is not None + ): + return True + return False diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 56b513554a..2104a60604 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -1,2744 +1,2745 @@ -import ast -import asyncio -import copy -import json -import traceback -from base64 import b64encode -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Union, cast -from urllib.parse import urlencode, urlparse - -import httpx -from fastapi import ( - APIRouter, - Depends, - FastAPI, - HTTPException, - Request, - Response, - UploadFile, - WebSocket, - status, -) -from fastapi.responses import StreamingResponse -from starlette.datastructures import UploadFile as StarletteUploadFile -from starlette.websockets import WebSocketState -from websockets.asyncio.client import connect -from websockets.exceptions import ( - ConnectionClosedError, - ConnectionClosedOK, - InvalidStatus, -) - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm._uuid import uuid -from litellm.constants import MAXIMUM_TRACEBACK_LINES_TO_LOG -from litellm.integrations.custom_logger import CustomLogger -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.litellm_core_utils.safe_json_dumps import safe_dumps -from litellm.llms.custom_httpx.http_handler import get_async_httpx_client -from litellm.passthrough import BasePassthroughUtils -from litellm.proxy._types import ( - ConfigFieldInfo, - ConfigFieldUpdate, - LiteLLMRoutes, - PassThroughEndpointResponse, - PassThroughGenericEndpoint, - ProxyException, - UserAPIKeyAuth, -) -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing -from litellm.proxy.common_utils.http_parsing_utils import _read_request_body -from litellm.proxy.utils import get_server_root_path -from litellm.secret_managers.main import get_secret_str -from litellm.types.llms.custom_http import httpxSpecialProvider -from litellm.types.passthrough_endpoints.pass_through_endpoints import ( - EndpointType, - PassthroughStandardLoggingPayload, -) -from litellm.types.utils import StandardLoggingUserAPIKeyMetadata - -from .streaming_handler import PassThroughStreamingHandler -from .success_handler import PassThroughEndpointLogging - -router = APIRouter() - -pass_through_endpoint_logging = PassThroughEndpointLogging() - -# Global registry to track registered pass-through routes and prevent memory leaks -_registered_pass_through_routes: Dict[str, Dict[str, Union[str, Dict[str, Any]]]] = {} - - -def get_response_body(response: httpx.Response) -> Optional[dict]: - try: - return response.json() - except Exception: - return None - - -async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]: - """ - checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc - - only runs for headers defined on config.yaml - - example header can be - - {"Authorization": "Bearer os.environ/COHERE_API_KEY"} - """ - if custom_headers is None: - return None - headers = {} - for key, value in custom_headers.items(): - # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys - # we can then get the b64 encoded keys here - if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY": - # langfuse requires b64 encoded headers - we construct that here - _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"] - _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"] - if isinstance( - _langfuse_public_key, str - ) and _langfuse_public_key.startswith("os.environ/"): - _langfuse_public_key = get_secret_str(_langfuse_public_key) - if isinstance( - _langfuse_secret_key, str - ) and _langfuse_secret_key.startswith("os.environ/"): - _langfuse_secret_key = get_secret_str(_langfuse_secret_key) - headers["Authorization"] = "Basic " + b64encode( - f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") - ).decode("ascii") - else: - # for all other headers - headers[key] = value - if isinstance(value, str) and "os.environ/" in value: - verbose_proxy_logger.debug( - "pass through endpoint - looking up 'os.environ/' variable" - ) - # get string section that is os.environ/ - start_index = value.find("os.environ/") - _variable_name = value[start_index:] - - verbose_proxy_logger.debug( - "pass through endpoint - getting secret for variable name: %s", - _variable_name, - ) - _secret_value = get_secret_str(_variable_name) - if _secret_value is not None: - new_value = value.replace(_variable_name, _secret_value) - headers[key] = new_value - return headers - - -async def chat_completion_pass_through_endpoint( # noqa: PLR0915 - fastapi_response: Response, - request: Request, - adapter_id: str, - user_api_key_dict: UserAPIKeyAuth, -): - from litellm.proxy.proxy_server import ( - add_litellm_data_to_request, - general_settings, - llm_router, - proxy_config, - proxy_logging_obj, - user_api_base, - user_max_tokens, - user_model, - user_request_timeout, - user_temperature, - version, - ) - - data = {} - try: - body = await request.body() - body_str = body.decode() - try: - data = ast.literal_eval(body_str) - except Exception: - data = json.loads(body_str) - - data["adapter_id"] = adapter_id - - verbose_proxy_logger.debug( - "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), - ) - data["model"] = ( - general_settings.get("completion_model", None) # server default - or user_model # model name passed via cli args - or data.get("model", None) # default passed in http request - ) - if user_model: - data["model"] = user_model - - data = await add_litellm_data_to_request( - data=data, # type: ignore - request=request, - general_settings=general_settings, - user_api_key_dict=user_api_key_dict, - version=version, - proxy_config=proxy_config, - ) - - # override with user settings, these are params passed via cli - if user_temperature: - data["temperature"] = user_temperature - if user_request_timeout: - data["request_timeout"] = user_request_timeout - if user_max_tokens: - data["max_tokens"] = user_max_tokens - if user_api_base: - data["api_base"] = user_api_base - - ### MODEL ALIAS MAPPING ### - # check if model name in model alias map - # get the actual model name - if data["model"] in litellm.model_alias_map: - data["model"] = litellm.model_alias_map[data["model"]] - - # Check key-specific aliases - if ( - isinstance(data["model"], str) - and user_api_key_dict.aliases - and isinstance(user_api_key_dict.aliases, dict) - and data["model"] in user_api_key_dict.aliases - ): - data["model"] = user_api_key_dict.aliases[data["model"]] - - ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook( # type: ignore - user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" - ) - - ### ROUTE THE REQUESTs ### - router_model_names = llm_router.model_names if llm_router is not None else [] - # skip router if user passed their key - if "api_key" in data: - llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) - elif llm_router is not None and llm_router.has_model_id( - data["model"] - ): # model in router model list - llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.pattern_router.patterns) > 0 - ) - ): # check for wildcard routes or default deployment before checking deployment_names - llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router (lowest priority) - llm_response = asyncio.create_task( - llm_router.aadapter_completion(**data, specific_deployment=True) - ) - elif user_model is not None: # `litellm --model ` - llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "completion: Invalid model name passed in model=" - + data.get("model", "") - }, - ) - - # Await the llm_response task - response = await llm_response - - hidden_params = getattr(response, "_hidden_params", {}) or {} - model_id = hidden_params.get("model_id", None) or "" - cache_key = hidden_params.get("cache_key", None) or "" - api_base = hidden_params.get("api_base", None) or "" - response_cost = hidden_params.get("response_cost", None) or "" - - ### ALERTING ### - asyncio.create_task( - proxy_logging_obj.update_request_status( - litellm_call_id=data.get("litellm_call_id", ""), status="success" - ) - ) - - verbose_proxy_logger.debug("final response: %s", response) - - fastapi_response.headers.update( - ProxyBaseLLMRequestProcessing.get_custom_headers( - user_api_key_dict=user_api_key_dict, - model_id=model_id, - cache_key=cache_key, - api_base=api_base, - version=version, - response_cost=response_cost, - ) - ) - - verbose_proxy_logger.debug("\nResponse from Litellm:\n{}".format(response)) - return response - except Exception as e: - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data - ) - verbose_proxy_logger.exception( - "litellm.proxy.proxy_server.completion(): Exception occured - {}".format( - str(e) - ) - ) - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) - - -class HttpPassThroughEndpointHelpers(BasePassthroughUtils): - @staticmethod - def get_response_headers( - headers: httpx.Headers, - litellm_call_id: Optional[str] = None, - custom_headers: Optional[dict] = None, - ) -> dict: - excluded_headers = {"transfer-encoding", "content-encoding"} - - return_headers = { - key: value - for key, value in headers.items() - if key.lower() not in excluded_headers - } - if litellm_call_id: - return_headers["x-litellm-call-id"] = litellm_call_id - if custom_headers: - return_headers.update(custom_headers) - - return return_headers - - @staticmethod - def get_endpoint_type(url: str) -> EndpointType: - parsed_url = urlparse(url) - if ( - ("generateContent") in url - or ("streamGenerateContent") in url - or ("rawPredict") in url - or ("streamRawPredict") in url - ): - return EndpointType.VERTEX_AI - elif parsed_url.hostname == "api.anthropic.com": - return EndpointType.ANTHROPIC - elif ( - parsed_url.hostname == "api.openai.com" - or parsed_url.hostname == "openai.azure.com" - or (parsed_url.hostname and "openai.com" in parsed_url.hostname) - ): - return EndpointType.OPENAI - return EndpointType.GENERIC - - @staticmethod - async def _make_non_streaming_http_request( - request: Request, - async_client: httpx.AsyncClient, - url: str, - headers: dict, - requested_query_params: Optional[dict] = None, - custom_body: Optional[dict] = None, - ) -> httpx.Response: - """ - Make a non-streaming HTTP request - - If request is GET, don't include a JSON body - """ - if request.method == "GET": - response = await async_client.request( - method=request.method, - url=url, - headers=headers, - params=requested_query_params, - ) - else: - response = await async_client.request( - method=request.method, - url=url, - headers=headers, - params=requested_query_params, - json=custom_body, - ) - return response - - @staticmethod - async def non_streaming_http_request_handler( - request: Request, - async_client: httpx.AsyncClient, - url: httpx.URL, - headers: dict, - requested_query_params: Optional[dict] = None, - _parsed_body: Optional[dict] = None, - ) -> httpx.Response: - """ - Handle non-streaming HTTP requests - - Handles special cases when GET requests, multipart/form-data requests, and generic httpx requests - """ - if request.method == "GET": - response = await async_client.request( - method=request.method, - url=url, - headers=headers, - params=requested_query_params, - ) - elif HttpPassThroughEndpointHelpers.is_multipart(request) is True: - return await HttpPassThroughEndpointHelpers.make_multipart_http_request( - request=request, - async_client=async_client, - url=url, - headers=headers, - requested_query_params=requested_query_params, - ) - else: - # Generic httpx method - response = await async_client.request( - method=request.method, - url=url, - headers=headers, - params=requested_query_params, - json=_parsed_body, - ) - return response - - @staticmethod - def is_multipart(request: Request) -> bool: - """Check if the request is a multipart/form-data request""" - return "multipart/form-data" in request.headers.get("content-type", "") - - @staticmethod - async def _build_request_files_from_upload_file( - upload_file: Union[UploadFile, StarletteUploadFile], - ) -> Tuple[Optional[str], bytes, Optional[str]]: - """Build a request files dict from an UploadFile object""" - file_content = await upload_file.read() - return (upload_file.filename, file_content, upload_file.content_type) - - @staticmethod - async def make_multipart_http_request( - request: Request, - async_client: httpx.AsyncClient, - url: httpx.URL, - headers: dict, - requested_query_params: Optional[dict] = None, - ) -> httpx.Response: - """Process multipart/form-data requests, handling both files and form fields""" - form_data = await request.form() - files = {} - form_data_dict = {} - - for field_name, field_value in form_data.items(): - if isinstance(field_value, (StarletteUploadFile, UploadFile)): - files[ - field_name - ] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( - upload_file=field_value - ) - else: - form_data_dict[field_name] = field_value - - # Remove content-type header - httpx will set it correctly with the new boundary - # when it creates the multipart body from files/data parameters - headers_copy = headers.copy() - headers_copy.pop("content-type", None) - - response = await async_client.request( - method=request.method, - url=url, - headers=headers_copy, - params=requested_query_params, - files=files, - data=form_data_dict, - ) - return response - - @staticmethod - def _init_kwargs_for_pass_through_endpoint( - request: Request, - user_api_key_dict: UserAPIKeyAuth, - passthrough_logging_payload: PassthroughStandardLoggingPayload, - logging_obj: LiteLLMLoggingObj, - _parsed_body: Optional[dict] = None, - litellm_call_id: Optional[str] = None, - ) -> dict: - """ - Filter out litellm params from the request body - """ - from litellm.types.utils import all_litellm_params - - _parsed_body = _parsed_body or {} - - litellm_params_in_body = {} - for k in all_litellm_params: - if k in _parsed_body: - litellm_params_in_body[k] = _parsed_body.pop(k, None) - - _metadata = dict( - StandardLoggingUserAPIKeyMetadata( - user_api_key_hash=user_api_key_dict.api_key, - user_api_key_alias=user_api_key_dict.key_alias, - user_api_key_user_email=user_api_key_dict.user_email, - user_api_key_user_id=user_api_key_dict.user_id, - user_api_key_team_id=user_api_key_dict.team_id, - user_api_key_org_id=user_api_key_dict.org_id, - user_api_key_team_alias=user_api_key_dict.team_alias, - user_api_key_end_user_id=user_api_key_dict.end_user_id, - user_api_key_request_route=user_api_key_dict.request_route, - user_api_key_spend=user_api_key_dict.spend, - user_api_key_max_budget=user_api_key_dict.max_budget, - user_api_key_budget_reset_at=( - user_api_key_dict.budget_reset_at.isoformat() - if user_api_key_dict.budget_reset_at - else None - ), - user_api_key_auth_metadata=user_api_key_dict.metadata, - ) - ) - - _metadata["user_api_key"] = user_api_key_dict.api_key - - litellm_metadata = litellm_params_in_body.pop("litellm_metadata", None) - metadata = litellm_params_in_body.pop("metadata", None) - if litellm_metadata: - _metadata.update(litellm_metadata) - if metadata: - _metadata.update(metadata) - - _metadata = _update_metadata_with_tags_in_header( - request=request, - metadata=_metadata, - ) - - kwargs = { - "litellm_params": { - **litellm_params_in_body, # type: ignore - "metadata": _metadata, - "proxy_server_request": { - "url": str(request.url), - "method": request.method, - "body": copy.copy(_parsed_body), # use copy instead of deepcopy - "headers": request.headers, - }, - }, - "call_type": "pass_through_endpoint", - "litellm_call_id": litellm_call_id, - "passthrough_logging_payload": passthrough_logging_payload, - } - - logging_obj.model_call_details[ - "passthrough_logging_payload" - ] = passthrough_logging_payload - - return kwargs - - @staticmethod - def construct_target_url_with_subpath( - base_target: str, subpath: str, include_subpath: Optional[bool] - ) -> str: - """ - Helper function to construct the full target URL with subpath handling. - - Args: - base_target: The base target URL - subpath: The captured subpath from the request - include_subpath: Whether to include the subpath in the target URL - - Returns: - The constructed full target URL - """ - if not include_subpath: - return base_target - - if not subpath: - return base_target - - # Ensure base_target ends with / and subpath doesn't start with / - if not base_target.endswith("/"): - base_target = base_target + "/" - if subpath.startswith("/"): - subpath = subpath[1:] - - return base_target + subpath - - @staticmethod - def _update_stream_param_based_on_request_body( - parsed_body: dict, - stream: Optional[bool] = None, - ) -> Optional[bool]: - """ - If stream is provided in the request body, use it. - Otherwise, use the stream parameter passed to the `pass_through_request` function - """ - if "stream" in parsed_body: - return parsed_body.get("stream", stream) - return stream - - -async def pass_through_request( # noqa: PLR0915 - request: Request, - target: str, - custom_headers: dict, - user_api_key_dict: UserAPIKeyAuth, - custom_body: Optional[dict] = None, - forward_headers: Optional[bool] = False, - merge_query_params: Optional[bool] = False, - query_params: Optional[dict] = None, - stream: Optional[bool] = None, - cost_per_request: Optional[float] = None, - custom_llm_provider: Optional[str] = None, - guardrails_config: Optional[dict] = None, -): - """ - Pass through endpoint handler, makes the httpx request for pass-through endpoints and ensures logging hooks are called - - Args: - request: The incoming request - target: The target URL - custom_headers: The custom headers - user_api_key_dict: The user API key dictionary - custom_body: The custom body - forward_headers: Whether to forward headers - merge_query_params: Whether to merge query params - query_params: The query params - stream: Whether to stream the response - cost_per_request: Optional field - cost per request to the target endpoint - custom_llm_provider: Optional field - custom LLM provider for the endpoint - guardrails_config: Optional field - guardrails configuration for passthrough endpoint - """ - from litellm.litellm_core_utils.litellm_logging import Logging - from litellm.proxy.pass_through_endpoints.passthrough_guardrails import ( - PassthroughGuardrailHandler, - ) - from litellm.proxy.proxy_server import proxy_logging_obj - - ######################################################### - # Initialize variables - ######################################################### - litellm_call_id = str(uuid.uuid4()) - url: Optional[httpx.URL] = None - - # parsed request body - _parsed_body: Optional[dict] = None - # kwargs for pass through endpoint, contains metadata, litellm_params, call_type, litellm_call_id, passthrough_logging_payload - kwargs: Optional[dict] = None - - ######################################################### - try: - url = httpx.URL(target) - headers = custom_headers - headers = HttpPassThroughEndpointHelpers.forward_headers_from_request( - request_headers=dict(request.headers), - headers=headers, - forward_headers=forward_headers, - ) - - if merge_query_params: - # Create a new URL with the merged query params - url = url.copy_with( - query=urlencode( - HttpPassThroughEndpointHelpers.get_merged_query_parameters( - existing_url=url, - request_query_params=dict(request.query_params), - ) - ).encode("ascii") - ) - - endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type( - str(url) - ) - - if custom_body: - _parsed_body = custom_body - else: - _parsed_body = await _read_request_body(request) - verbose_proxy_logger.debug( - "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( - url, headers, _parsed_body - ) - ) - - ### COLLECT GUARDRAILS FOR PASSTHROUGH ENDPOINT ### - # Passthrough endpoints are opt-in only for guardrails - # When enabled, collect guardrails from org/team/key levels + passthrough-specific - guardrails_to_run = PassthroughGuardrailHandler.collect_guardrails( - user_api_key_dict=user_api_key_dict, - passthrough_guardrails_config=guardrails_config, - ) - - # Add guardrails to metadata if any should run - if guardrails_to_run and len(guardrails_to_run) > 0: - if _parsed_body is None: - _parsed_body = {} - if "metadata" not in _parsed_body: - _parsed_body["metadata"] = {} - _parsed_body["metadata"]["guardrails"] = guardrails_to_run - verbose_proxy_logger.debug( - f"Added guardrails to passthrough request metadata: {guardrails_to_run}" - ) - - ## LOGGING OBJECT ## - initialize before pre_call_hook so guardrails can access it - start_time = datetime.now() - logging_obj = Logging( - model="unknown", - messages=[{"role": "user", "content": safe_dumps(_parsed_body)}], - stream=False, - call_type="pass_through_endpoint", - start_time=start_time, - litellm_call_id=litellm_call_id, - function_id="1245", - ) - - # Store passthrough guardrails config on logging_obj for field targeting - logging_obj.passthrough_guardrails_config = guardrails_config - - # Store logging_obj in data so guardrails can access it - if _parsed_body is None: - _parsed_body = {} - _parsed_body["litellm_logging_obj"] = logging_obj - - ### CALL HOOKS ### - modify incoming data / reject request before calling the model - _parsed_body = await proxy_logging_obj.pre_call_hook( - user_api_key_dict=user_api_key_dict, - data=_parsed_body, - call_type="pass_through_endpoint", - ) - async_client_obj = get_async_httpx_client( - llm_provider=httpxSpecialProvider.PassThroughEndpoint, - params={"timeout": 600}, - ) - async_client = async_client_obj.client - passthrough_logging_payload = PassthroughStandardLoggingPayload( - url=str(url), - request_body=_parsed_body, - request_method=getattr(request, "method", None), - cost_per_request=cost_per_request, - ) - kwargs = HttpPassThroughEndpointHelpers._init_kwargs_for_pass_through_endpoint( - user_api_key_dict=user_api_key_dict, - _parsed_body=_parsed_body, - passthrough_logging_payload=passthrough_logging_payload, - litellm_call_id=litellm_call_id, - request=request, - logging_obj=logging_obj, - ) - - # Store custom_llm_provider in kwargs and logging object if provided - if custom_llm_provider: - logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider - logging_obj.model_call_details["litellm_params"] = kwargs.get( - "litellm_params", {} - ) - - # done for supporting 'parallel_request_limiter.py' with pass-through endpoints - logging_obj.update_environment_variables( - model="unknown", - user="unknown", - optional_params={}, - litellm_params=kwargs["litellm_params"], - call_type="pass_through_endpoint", - ) - logging_obj.model_call_details["litellm_call_id"] = litellm_call_id - - # combine url with query params for logging - requested_query_params: Optional[dict] = query_params or dict( - request.query_params - ) - - requested_query_params_str = None - if requested_query_params: - requested_query_params_str = "&".join( - f"{k}={v}" for k, v in requested_query_params.items() - ) - - logging_url = str(url) - if requested_query_params_str: - if "?" in str(url): - logging_url = str(url) + "&" + requested_query_params_str - else: - logging_url = str(url) + "?" + requested_query_params_str - - logging_obj.pre_call( - input=[{"role": "user", "content": safe_dumps(_parsed_body)}], - api_key="", - additional_args={ - "complete_input_dict": _parsed_body, - "api_base": str(logging_url), - "headers": headers, - }, - ) - stream = ( - HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body( - parsed_body=_parsed_body, - stream=stream, - ) - ) - - if stream: - req = async_client.build_request( - "POST", - url, - json=_parsed_body, - params=requested_query_params, - headers=headers, - ) - - response = await async_client.send(req, stream=stream) - - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - raise HTTPException( - status_code=e.response.status_code, detail=await e.response.aread() - ) - - return StreamingResponse( - PassThroughStreamingHandler.chunk_processor( - response=response, - request_body=_parsed_body, - litellm_logging_obj=logging_obj, - endpoint_type=endpoint_type, - start_time=start_time, - passthrough_success_handler_obj=pass_through_endpoint_logging, - url_route=str(url), - ), - headers=HttpPassThroughEndpointHelpers.get_response_headers( - headers=response.headers, - litellm_call_id=litellm_call_id, - ), - status_code=response.status_code, - ) - - response = ( - await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler( - request=request, - async_client=async_client, - url=url, - headers=headers, - requested_query_params=requested_query_params, - _parsed_body=_parsed_body, - ) - ) - verbose_proxy_logger.debug("response.headers= %s", response.headers) - - if _is_streaming_response(response) is True: - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - raise HTTPException( - status_code=e.response.status_code, detail=await e.response.aread() - ) - - return StreamingResponse( - PassThroughStreamingHandler.chunk_processor( - response=response, - request_body=_parsed_body, - litellm_logging_obj=logging_obj, - endpoint_type=endpoint_type, - start_time=start_time, - passthrough_success_handler_obj=pass_through_endpoint_logging, - url_route=str(url), - ), - headers=HttpPassThroughEndpointHelpers.get_response_headers( - headers=response.headers, - litellm_call_id=litellm_call_id, - ), - status_code=response.status_code, - ) - - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - raise HTTPException( - status_code=e.response.status_code, detail=e.response.text - ) - - if response.status_code >= 300: - raise HTTPException(status_code=response.status_code, detail=response.text) - - content = await response.aread() - - ## LOG SUCCESS - response_body: Optional[dict] = get_response_body(response) - passthrough_logging_payload["response_body"] = response_body - end_time = datetime.now() - asyncio.create_task( - pass_through_endpoint_logging.pass_through_async_success_handler( - httpx_response=response, - response_body=response_body, - url_route=str(url), - result="", - start_time=start_time, - end_time=end_time, - logging_obj=logging_obj, - cache_hit=False, - request_body=_parsed_body, - custom_llm_provider=custom_llm_provider, - **kwargs, - ) - ) - - ## CUSTOM HEADERS - `x-litellm-*` - custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( - user_api_key_dict=user_api_key_dict, - call_id=litellm_call_id, - model_id=None, - cache_key=None, - api_base=str(url._uri_reference), - ) - - return Response( - content=content, - status_code=response.status_code, - headers=HttpPassThroughEndpointHelpers.get_response_headers( - headers=response.headers, - custom_headers=custom_headers, - ), - ) - except Exception as e: - custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( - user_api_key_dict=user_api_key_dict, - call_id=litellm_call_id, - model_id=None, - cache_key=None, - api_base=str(url._uri_reference) if url else None, - ) - verbose_proxy_logger.exception( - "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format( - str(e) - ) - ) - - ######################################################### - # Monitoring: Trigger post_call_failure_hook - # for pass through endpoint failure - ######################################################### - request_payload: dict = _parsed_body or {} - # add user_api_key_dict, litellm_call_id, passthrough_logging_payloa for logging - if kwargs: - for key, value in kwargs.items(): - request_payload[key] = value - - if ( - "model" not in request_payload - and _parsed_body - and isinstance(_parsed_body, dict) - ): - request_payload["model"] = _parsed_body.get("model", "") - if "custom_llm_provider" not in request_payload and custom_llm_provider: - request_payload["custom_llm_provider"] = custom_llm_provider - - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=request_payload, - traceback_str=traceback.format_exc( - limit=MAXIMUM_TRACEBACK_LINES_TO_LOG, - ), - ) - - ######################################################### - - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - headers=custom_headers, - ) - else: - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - headers=custom_headers, - ) - - -def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict: - """ - If tags are in the request headers, add them to the metadata - - Used for google and vertex JS SDKs, and Azure passthrough - Checks both 'tags' and 'x-litellm-tags' headers - """ - tags_to_add = [] - - # Check for 'tags' header first - _tags = request.headers.get("tags") - if _tags: - tags_to_add.extend([tag.strip() for tag in _tags.split(",")]) - - _tags = request.headers.get("x-litellm-tags") - if _tags: - tags_to_add.extend([tag.strip() for tag in _tags.split(",")]) - - # Only add tags key if there are tags to add - if tags_to_add: - if "tags" not in metadata: - metadata["tags"] = [] - metadata["tags"].extend(tags_to_add) - - return metadata - - -async def _parse_request_data_by_content_type( - request: Request, -) -> Tuple[Optional[Any], Optional[Any], Optional[Any], Optional[Any]]: - """ - Parse request data based on content type. - - Handles JSON, multipart/form-data, and URL-encoded form data. - - Returns: - Tuple of (query_params_data, custom_body_data, file_data, stream) - """ - content_type = request.headers.get("content-type", "") - - query_params_data = None - custom_body_data = None - file_data = None - stream = None - - if "application/json" in content_type: - # ✅ Handle JSON - try: - body = await request.json() - query_params_data = body.get("query_params") - custom_body_data = body.get("custom_body") - stream = body.get("stream") - except json.JSONDecodeError: - # Handle requests with no body (e.g., DELETE requests) - pass - elif "multipart/form-data" in content_type: - # ✅ Handle multipart form-data - form = await request.form() - if "query_params" in form: - form_value = form["query_params"] - if isinstance(form_value, str): - try: - query_params_data = json.loads(form_value) - except Exception: - query_params_data = form_value - else: - query_params_data = form_value - - if "custom_body" in form: - form_value = form["custom_body"] - if isinstance(form_value, str): - try: - custom_body_data = json.loads(form_value) - except Exception: - custom_body_data = form_value - else: - custom_body_data = form_value - - if "file" in form: - file_data = form["file"] # this is a Starlette UploadFile object - - elif "application/x-www-form-urlencoded" in content_type: - # ✅ Handle URL-encoded form data - form = await request.form() - query_params_data = form.get("query_params") - custom_body_data = form.get("custom_body") - - else: - # ✅ Fallback: maybe no body, just query params - query_params_data = dict(request.query_params) or None - - return query_params_data, custom_body_data, file_data, stream - - -def create_pass_through_route( - endpoint, - target: str, - custom_headers: Optional[dict] = None, - _forward_headers: Optional[bool] = False, - _merge_query_params: Optional[bool] = False, - dependencies: Optional[List] = None, - include_subpath: Optional[bool] = False, - cost_per_request: Optional[float] = None, - custom_llm_provider: Optional[str] = None, - is_streaming_request: Optional[bool] = False, - query_params: Optional[dict] = None, - guardrails: Optional[Dict[str, Any]] = None, -): - # check if target is an adapter.py or a url - from litellm._uuid import uuid - from litellm.proxy.types_utils.utils import get_instance_fn - - try: - if isinstance(target, CustomLogger): - adapter = target - else: - adapter = get_instance_fn(value=target) - adapter_id = str(uuid.uuid4()) - litellm.adapters = [{"id": adapter_id, "adapter": adapter}] - - async def endpoint_func( # type: ignore - request: Request, - fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - subpath: str = "", # captures sub-paths when include_subpath=True - custom_body: Optional[dict] = None, # accepted for signature compatibility with URL-based path; not forwarded because chat_completion_pass_through_endpoint does not support it - ): - return await chat_completion_pass_through_endpoint( - fastapi_response=fastapi_response, - request=request, - adapter_id=adapter_id, - user_api_key_dict=user_api_key_dict, - ) - - except Exception: - verbose_proxy_logger.debug("Defaulting to target being a url.") - - async def endpoint_func( # type: ignore - request: Request, - fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - subpath: str = "", # captures sub-paths when include_subpath=True - custom_body: Optional[dict] = None, - ): - from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( - InitPassThroughEndpointHelpers, - ) - - path = request.url.path - - # Parse request data based on content type - ( - query_params_data, - custom_body_data, - file_data, - stream, - ) = await _parse_request_data_by_content_type(request) - - if not InitPassThroughEndpointHelpers.is_registered_pass_through_route( - route=path - ): - raise HTTPException( - status_code=404, - detail=f"Pass-through endpoint {endpoint} not found. This could have been deleted or not yet added to the proxy.", - ) - - passthrough_params = ( - InitPassThroughEndpointHelpers.get_registered_pass_through_route( - route=path - ) - ) - target_params = { - "target": target, - "custom_headers": custom_headers, - "forward_headers": _forward_headers, - "merge_query_params": _merge_query_params, - "cost_per_request": cost_per_request, - "guardrails": None, - } - - if passthrough_params is not None: - target_params.update(passthrough_params.get("passthrough_params", {})) - - # Extract and cast parameters with proper types - param_target = target_params.get("target") or target - param_custom_headers = target_params.get("custom_headers", custom_headers) - param_forward_headers = target_params.get( - "forward_headers", _forward_headers - ) - param_merge_query_params = target_params.get( - "merge_query_params", _merge_query_params - ) - param_cost_per_request = target_params.get( - "cost_per_request", cost_per_request - ) - param_guardrails = target_params.get("guardrails", None) - - # Construct the full target URL with subpath if needed - full_target = ( - HttpPassThroughEndpointHelpers.construct_target_url_with_subpath( - base_target=cast(str, param_target), - subpath=subpath, - include_subpath=include_subpath, - ) - ) - - # Ensure custom_headers is a dict - headers_dict = ( - param_custom_headers if isinstance(param_custom_headers, dict) else {} - ) - - # Ensure query_params and custom_body are dicts or None - final_query_params = ( - query_params_data if isinstance(query_params_data, dict) else {} - ) - if query_params: - final_query_params.update(query_params) - # When a caller (e.g. bedrock_proxy_route) supplies a pre-built - # body, use it instead of the body parsed from the raw request. - final_custom_body: Optional[dict] = None - if custom_body is not None: - final_custom_body = custom_body - elif isinstance(custom_body_data, dict): - final_custom_body = custom_body_data - - return await pass_through_request( # type: ignore - request=request, - target=full_target, - custom_headers=headers_dict, - user_api_key_dict=user_api_key_dict, - forward_headers=cast(Optional[bool], param_forward_headers), - merge_query_params=cast(Optional[bool], param_merge_query_params), - query_params=final_query_params, - stream=is_streaming_request or stream, - custom_body=final_custom_body, - cost_per_request=cast(Optional[float], param_cost_per_request), - custom_llm_provider=custom_llm_provider, - guardrails_config=cast(Optional[dict], param_guardrails), - ) - - return endpoint_func - - -def create_websocket_passthrough_route( - endpoint: str, - target: str, - custom_headers: Optional[dict] = None, - _forward_headers: Optional[bool] = False, - dependencies: Optional[List] = None, - cost_per_request: Optional[float] = None, -): - """ - Create a WebSocket passthrough route function. - - Args: - endpoint: The endpoint path (for logging purposes) - target: The target WebSocket URL (e.g., "wss://api.example.com/ws") - custom_headers: Custom headers to include in the WebSocket connection - _forward_headers: Whether to forward incoming headers - dependencies: FastAPI dependencies to inject - - Returns: - A WebSocket passthrough function that can be registered with app.websocket() - """ - from litellm.proxy.auth.user_api_key_auth import user_api_key_auth_websocket - - async def websocket_endpoint_func( - websocket: WebSocket, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth_websocket), - **kwargs, # For additional query parameters - ): - """ - WebSocket passthrough endpoint function. - - This function handles the WebSocket connection by: - 1. Accepting the incoming WebSocket connection - 2. Establishing a connection to the target WebSocket - 3. Forwarding messages bidirectionally - 4. Handling connection cleanup - """ - return await websocket_passthrough_request( - websocket=websocket, - target=target, - custom_headers=custom_headers or {}, - user_api_key_dict=user_api_key_dict, - forward_headers=_forward_headers, - endpoint=endpoint, - cost_per_request=cost_per_request, - accept_websocket=True, # Generic usage should accept the WebSocket - ) - - return websocket_endpoint_func - - -async def websocket_passthrough_request( # noqa: PLR0915 - websocket: WebSocket, - target: str, - custom_headers: dict, - user_api_key_dict: UserAPIKeyAuth, - forward_headers: Optional[bool] = False, - endpoint: Optional[str] = None, - cost_per_request: Optional[float] = None, - accept_websocket: bool = True, -): - """ - WebSocket passthrough request handler. - - Args: - websocket: The incoming WebSocket connection - target: The target WebSocket URL - custom_headers: Custom headers to include in the connection - user_api_key_dict: The user API key dictionary - forward_headers: Whether to forward incoming headers - endpoint: The endpoint path (for logging purposes) - cost_per_request: Optional field - cost per request to the target endpoint - """ - from litellm.litellm_core_utils.litellm_logging import Logging - from litellm.proxy.proxy_server import proxy_logging_obj - from litellm.types.passthrough_endpoints.pass_through_endpoints import ( - PassthroughStandardLoggingPayload, - ) - - # Initialize tracking variables - start_time = datetime.now() - websocket_messages: list[dict[str, Any]] = [] - litellm_call_id = str(uuid.uuid4()) - - verbose_proxy_logger.info( - f"WebSocket passthrough ({endpoint}): Starting WebSocket connection to {target}" - ) - - # Only accept the WebSocket if requested (for generic usage) - if accept_websocket: - await websocket.accept() - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): WebSocket connection accepted" - ) - - # Prepare headers for the upstream connection - upstream_headers = custom_headers.copy() - - if forward_headers: - # Forward relevant headers from the incoming request - incoming_headers = dict(websocket.headers) - for header_name, header_value in incoming_headers.items(): - # Only forward certain headers to avoid conflicts - if header_name.lower() in [ - "authorization", - "x-api-key", - "x-goog-user-project", - ]: - upstream_headers[header_name] = header_value - - # Initialize logging object similar to HTTP passthrough - logging_obj = Logging( - model="unknown", - messages=[{"role": "user", "content": "WebSocket connection"}], - stream=True, # WebSockets are inherently streaming - call_type="pass_through_endpoint", - start_time=start_time, - litellm_call_id=litellm_call_id, - function_id="websocket_passthrough", - ) - - # Create passthrough logging payload - passthrough_logging_payload = PassthroughStandardLoggingPayload( - url=target, - request_body={}, # WebSocket doesn't have a traditional request body - request_method="WEBSOCKET", - cost_per_request=cost_per_request, - ) - - # Create a dummy request object for WebSocket connections to maintain compatibility - # with the existing _init_kwargs_for_pass_through_endpoint function - class DummyRequest: - def __init__( - self, url: str, method: str = "WEBSOCKET", headers: Optional[dict] = None - ): - self.url = url - self.method = method - self.headers = headers or {} - - def __str__(self): - return f"DummyRequest(url={self.url}, method={self.method})" - - dummy_request = DummyRequest( - url=target, - method="WEBSOCKET", - headers=dict(websocket.headers) if hasattr(websocket, "headers") else {}, - ) - - # Initialize kwargs for logging using the same pattern as HTTP passthrough - kwargs = HttpPassThroughEndpointHelpers._init_kwargs_for_pass_through_endpoint( - user_api_key_dict=user_api_key_dict, - _parsed_body={}, # WebSocket doesn't have a traditional request body - passthrough_logging_payload=passthrough_logging_payload, - litellm_call_id=litellm_call_id, - request=dummy_request, # type: ignore - logging_obj=logging_obj, - ) - - # Update logging environment variables - logging_obj.update_environment_variables( - model="unknown", - user="unknown", - optional_params={}, - litellm_params=dict(kwargs.get("litellm_params", {})), - call_type="pass_through_endpoint", - ) - logging_obj.model_call_details["litellm_call_id"] = litellm_call_id - - # Pre-call logging - logging_obj.pre_call( - input=[{"role": "user", "content": "WebSocket connection"}], - api_key="", - additional_args={ - "complete_input_dict": {}, - "api_base": target, - "headers": upstream_headers, - }, - ) - - ### CALL HOOKS ### - modify incoming data / reject request before calling the model - websocket_data: dict[str, Any] = {} - websocket_data = await proxy_logging_obj.pre_call_hook( - user_api_key_dict=user_api_key_dict, - data=websocket_data, - call_type="pass_through_endpoint", - ) - - try: - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Establishing upstream connection to {target}" - ) - async with connect( - target, - additional_headers=upstream_headers, - ) as upstream_ws: - verbose_proxy_logger.info( - f"WebSocket passthrough ({endpoint}): Upstream connection established successfully" - ) - - async def forward_client_to_upstream() -> None: - """Forward messages from client to upstream WebSocket""" - try: - while True: - message = await websocket.receive() - message_type = message.get("type") - if message_type == "websocket.disconnect": - await upstream_ws.close() - break - - text_data = message.get("text") - bytes_data = message.get("bytes") - - if text_data is not None: - # Try to extract model from client setup message for Vertex AI Live - if endpoint and "/vertex_ai/live" in endpoint: - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Processing client message for model extraction" - ) - try: - client_message = json.loads(text_data) - if ( - isinstance(client_message, dict) - and "setup" in client_message - ): - setup_data = client_message["setup"] - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Found setup data in client message: {setup_data}" - ) - if ( - isinstance(setup_data, dict) - and "model" in setup_data - ): - extracted_model = ( - _extract_model_from_vertex_ai_setup( - setup_data - ) - ) - if extracted_model: - kwargs["model"] = extracted_model - kwargs[ - "custom_llm_provider" - ] = "vertex_ai-language-models" - # Update logging object with correct model - logging_obj.model = extracted_model - logging_obj.model_call_details[ - "model" - ] = extracted_model - logging_obj.model_call_details[ - "custom_llm_provider" - ] = "vertex_ai" - verbose_proxy_logger.info( - f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from client setup message" - ) - else: - verbose_proxy_logger.warning( - f"WebSocket passthrough ({endpoint}): Failed to extract model from client setup data: {setup_data}" - ) - else: - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Setup data does not contain model field: {setup_data}" - ) - else: - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Client message does not contain setup data" - ) - except (json.JSONDecodeError, KeyError, TypeError) as e: - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Client message is not a valid setup message: {e}" - ) - pass # Not a JSON message or doesn't contain setup data - - await upstream_ws.send(text_data) - elif bytes_data is not None: - await upstream_ws.send(bytes_data) - except asyncio.CancelledError: - raise - except Exception: - verbose_proxy_logger.exception( - f"WebSocket passthrough ({endpoint}): error forwarding client message" - ) - await upstream_ws.close() - - async def forward_upstream_to_client() -> None: - """Forward messages from upstream to client WebSocket""" - try: - # Wait for the first response from upstream - raw_response = await upstream_ws.recv(decode=False) - # Ensure raw_response is bytes before decoding - if isinstance(raw_response, str): - raw_response = raw_response.encode("ascii") - setup_response = json.loads(raw_response.decode("ascii")) - verbose_proxy_logger.debug(f"Setup response: {setup_response}") - - # Extract model and provider from setup response for Vertex AI Live - if endpoint and "/vertex_ai/live" in endpoint: - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Processing server setup response for model extraction" - ) - extracted_model = _extract_model_from_vertex_ai_setup( - setup_response - ) - if extracted_model: - kwargs["model"] = extracted_model - kwargs["custom_llm_provider"] = "vertex_ai_language_models" - # Update logging object with correct model - logging_obj.model = extracted_model - logging_obj.model_call_details["model"] = extracted_model - logging_obj.model_call_details[ - "custom_llm_provider" - ] = "vertex_ai_language_models" - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from server setup response" - ) - else: - verbose_proxy_logger.warning( - f"WebSocket passthrough ({endpoint}): Failed to extract model from server setup response: {setup_response}" - ) - else: - verbose_proxy_logger.debug( - f"WebSocket passthrough ({endpoint}): Not a Vertex AI Live endpoint, skipping model extraction" - ) - - # Send the setup response to the client - await websocket.send_text(json.dumps(setup_response)) - - # Now continuously forward messages from upstream to client - async for upstream_message in upstream_ws: - if isinstance(upstream_message, bytes): - await websocket.send_bytes(upstream_message) - # Parse and collect for cost tracking - try: - message_data = json.loads(upstream_message.decode()) - websocket_messages.append(message_data) - except (json.JSONDecodeError, UnicodeDecodeError): - pass - else: - await websocket.send_text(upstream_message) - # Parse and collect for cost tracking - try: - message_data = json.loads(upstream_message) - websocket_messages.append(message_data) - except json.JSONDecodeError: - pass - - except (ConnectionClosedOK, ConnectionClosedError) as e: - verbose_proxy_logger.debug( - f"Upstream WebSocket connection closed: {e}" - ) - pass - except asyncio.CancelledError: - verbose_proxy_logger.debug( - "asyncio.CancelledError in forward_upstream_to_client" - ) - raise - except Exception as e: - verbose_proxy_logger.debug( - f"Exception in forward_upstream_to_client: {e}" - ) - verbose_proxy_logger.exception( - f"WebSocket passthrough ({endpoint}): error forwarding upstream message" - ) - raise - - # Create tasks for bidirectional message forwarding - tasks = [ - asyncio.create_task(forward_client_to_upstream()), - asyncio.create_task(forward_upstream_to_client()), - ] - - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - - # Cancel remaining tasks - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Check for exceptions in completed tasks - for task in done: - exception = task.exception() - if exception is not None: - raise exception - - end_time = datetime.now() - - # Update passthrough logging payload with response data - passthrough_logging_payload["response_body"] = websocket_messages # type: ignore - passthrough_logging_payload["end_time"] = end_time # type: ignore - - # Remove logging_obj from kwargs to avoid duplicate keyword argument - success_kwargs = kwargs.copy() - success_kwargs.pop("logging_obj", None) - - # # Add user authentication context for database logging - # if user_api_key_dict: - # success_kwargs.setdefault('litellm_params', {}) - # success_kwargs['litellm_params'].update({ - # 'proxy_server_request': { - # 'body': { - # 'user': user_api_key_dict.user_id, - # 'team_id': user_api_key_dict.team_id, - # 'end_user_id': user_api_key_dict.end_user_id, - # } - # } - # }) - # # Also add the user_api_key for direct access - # success_kwargs['user_api_key'] = user_api_key_dict.api_key - - # Create a dummy httpx.Response for WebSocket connections - class MockWebSocketResponse: - def __init__(self, target_url: str): - self.status_code = 200 - self.text = "WebSocket connection successful" - self.headers: dict[str, str] = {} - self.request = MockWebSocketRequest(target_url) - - class MockWebSocketRequest: - def __init__(self, target_url: str): - self.method = "WEBSOCKET" - self.url = target_url - - mock_response = MockWebSocketResponse(target) - - # Use the same success handler as HTTP passthrough endpoints - asyncio.create_task( - pass_through_endpoint_logging.pass_through_async_success_handler( - httpx_response=mock_response, # type: ignore - response_body=websocket_messages, # type: ignore - url_route=endpoint or "", - result="websocket_connection_successful", - start_time=start_time, - end_time=end_time, - logging_obj=logging_obj, - cache_hit=False, - request_body={}, - **success_kwargs, - ) - ) - - # Call the proxy logging success hook - if proxy_logging_obj: - await proxy_logging_obj.post_call_success_hook( - data={}, - user_api_key_dict=user_api_key_dict, - response={"status": "websocket_connection_successful"}, # type: ignore - ) - - except InvalidStatus as exc: - verbose_proxy_logger.exception( - f"WebSocket passthrough ({endpoint}): upstream rejected WebSocket connection" - ) - - # Prepare request payload for logging - request_payload = {} - if kwargs: - for key, value in kwargs.items(): - request_payload[key] = value - - # Log the connection failure using the same pattern as HTTP - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, - original_exception=exc, - request_data=request_payload, - traceback_str=traceback.format_exc( - limit=MAXIMUM_TRACEBACK_LINES_TO_LOG, - ), - ) - - if websocket.client_state != WebSocketState.DISCONNECTED: - await websocket.close( - code=getattr(exc, "status_code", 1011), - reason="Upstream connection rejected", - ) - except Exception as e: - verbose_proxy_logger.exception( - f"WebSocket passthrough ({endpoint}): unexpected error while proxying WebSocket" - ) - - # Prepare request payload for logging - request_payload = {} - if kwargs: - for key, value in kwargs.items(): - request_payload[key] = value - - # Log the unexpected error using the same pattern as HTTP - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=request_payload, - traceback_str=traceback.format_exc( - limit=MAXIMUM_TRACEBACK_LINES_TO_LOG, - ), - ) - - if websocket.client_state != WebSocketState.DISCONNECTED: - await websocket.close(code=1011, reason="WebSocket passthrough error") - finally: - if websocket.client_state != WebSocketState.DISCONNECTED: - await websocket.close() - - -def _is_streaming_response(response: httpx.Response) -> bool: - _content_type = response.headers.get("content-type") - if _content_type is not None and "text/event-stream" in _content_type: - return True - return False - - -def _extract_model_from_vertex_ai_setup(setup_response: dict) -> Optional[str]: - """ - Extract the model name from Vertex AI Live setup response. - - The setup response can contain a model field in two formats: - 1. Direct: {"model": "projects/.../models/gemini-2.0-flash-live-preview-04-09"} - 2. Nested: {"setup": {"model": "projects/.../models/gemini-2.0-flash-live-preview-04-09"}} - - We extract just the model name: "gemini-2.0-flash-live-preview-04-09" - """ - try: - # Handle both direct model field and nested setup.model field - model_path = None - if isinstance(setup_response, dict): - if "model" in setup_response: - model_path = setup_response["model"] - elif ( - "setup" in setup_response - and isinstance(setup_response["setup"], dict) - and "model" in setup_response["setup"] - ): - model_path = setup_response["setup"]["model"] - - if isinstance(model_path, str) and "/models/" in model_path: - # Extract the model name after the last "/models/" - model_name = model_path.split("/models/")[-1] - return model_name - except Exception as e: - verbose_proxy_logger.debug(f"Error extracting model from setup response: {e}") - return None - - -class SafeRouteAdder: - """ - Wrapper class for adding routes to FastAPI app. - Only adds routes if they don't already exist on the app. - """ - - @staticmethod - def _is_path_registered(app: FastAPI, path: str, methods: List[str]) -> bool: - """ - Check if a path with any of the specified methods is already registered on the app. - - Args: - app: The FastAPI application instance - path: The path to check (e.g., "/v1/chat/completions") - methods: List of HTTP methods to check (e.g., ["GET", "POST"]) - - Returns: - True if the path is already registered with any of the methods, False otherwise - """ - for route in app.routes: - # Use getattr to safely access route attributes - route_path = getattr(route, "path", None) - route_methods = getattr(route, "methods", None) - - if route_path == path and route_methods is not None: - # Check if any of the methods overlap - if any(method in route_methods for method in methods): - return True - return False - - @staticmethod - def add_api_route_if_not_exists( - app: FastAPI, - path: str, - endpoint: Any, - methods: List[str], - dependencies: Optional[List] = None, - ) -> bool: - """ - Add an API route to the app only if it doesn't already exist. - - Args: - app: The FastAPI application instance - path: The path for the route - endpoint: The endpoint function/callable - methods: List of HTTP methods - dependencies: Optional list of dependencies - - Returns: - True if route was added, False if it already existed - """ - if SafeRouteAdder._is_path_registered(app=app, path=path, methods=methods): - verbose_proxy_logger.debug( - "Skipping route registration - path %s with methods %s already registered on app", - path, - methods, - ) - return False - - app.add_api_route( - path=path, - endpoint=endpoint, - methods=methods, - dependencies=dependencies, - ) - verbose_proxy_logger.debug( - "Successfully added route: %s with methods %s", - path, - methods, - ) - return True - - -class InitPassThroughEndpointHelpers: - @staticmethod - def add_exact_path_route( - app: FastAPI, - path: str, - target: str, - custom_headers: Optional[dict], - forward_headers: Optional[bool], - merge_query_params: Optional[bool], - dependencies: Optional[List], - cost_per_request: Optional[float], - endpoint_id: str, - guardrails: Optional[dict] = None, - ): - """Add exact path route for pass-through endpoint""" - route_key = f"{endpoint_id}:exact:{path}" - - # Check if this exact route is already registered - if route_key in _registered_pass_through_routes: - verbose_proxy_logger.debug( - "Updating duplicate exact pass through endpoint: %s (already registered)", - path, - ) - - verbose_proxy_logger.debug( - "adding exact pass through endpoint: %s, dependencies: %s", - path, - dependencies, - ) - - # Use SafeRouteAdder to only add route if it doesn't exist on the app - SafeRouteAdder.add_api_route_if_not_exists( - app=app, - path=path, - endpoint=create_pass_through_route( # type: ignore - path, - target, - custom_headers, - forward_headers, - merge_query_params, - dependencies, - cost_per_request=cost_per_request, - guardrails=guardrails, - ), - methods=["GET", "POST", "PUT", "DELETE", "PATCH"], - dependencies=dependencies, - ) - - # Always register/update the route metadata (headers, target) even if FastAPI route exists - _registered_pass_through_routes[route_key] = { - "endpoint_id": endpoint_id, - "path": path, - "type": "exact", - "passthrough_params": { - "target": target, - "custom_headers": custom_headers, - "forward_headers": forward_headers, - "merge_query_params": merge_query_params, - "dependencies": dependencies, - "cost_per_request": cost_per_request, - "guardrails": guardrails, - }, - } - - @staticmethod - def add_subpath_route( - app: FastAPI, - path: str, - target: str, - custom_headers: Optional[dict], - forward_headers: Optional[bool], - merge_query_params: Optional[bool], - dependencies: Optional[List], - cost_per_request: Optional[float], - endpoint_id: str, - guardrails: Optional[dict] = None, - ): - """Add wildcard route for sub-paths""" - wildcard_path = f"{path}/{{subpath:path}}" - route_key = f"{endpoint_id}:subpath:{path}" - - # Check if this subpath route is already registered - if route_key in _registered_pass_through_routes: - verbose_proxy_logger.debug( - "Updating duplicate wildcard pass through endpoint: %s (already registered)", - wildcard_path, - ) - - verbose_proxy_logger.debug( - "adding wildcard pass through endpoint: %s, dependencies: %s", - wildcard_path, - dependencies, - ) - - # Use SafeRouteAdder to only add route if it doesn't exist on the app - SafeRouteAdder.add_api_route_if_not_exists( - app=app, - path=wildcard_path, - endpoint=create_pass_through_route( # type: ignore - path, - target, - custom_headers, - forward_headers, - merge_query_params, - dependencies, - include_subpath=True, - cost_per_request=cost_per_request, - guardrails=guardrails, - ), - methods=["GET", "POST", "PUT", "DELETE", "PATCH"], - dependencies=dependencies, - ) - - # Register the route to prevent duplicates only if it was added - _registered_pass_through_routes[route_key] = { - "endpoint_id": endpoint_id, - "path": path, - "type": "subpath", - "passthrough_params": { - "target": target, - "custom_headers": custom_headers, - "forward_headers": forward_headers, - "merge_query_params": merge_query_params, - "dependencies": dependencies, - "cost_per_request": cost_per_request, - "guardrails": guardrails, - }, - } - - @staticmethod - def remove_endpoint_routes(endpoint_id: str): - """Remove all routes for a specific endpoint ID from the registry""" - keys_to_remove = [ - key - for key, value in _registered_pass_through_routes.items() - if value["endpoint_id"] == endpoint_id - ] - for key in keys_to_remove: - del _registered_pass_through_routes[key] - verbose_proxy_logger.debug( - "Removed pass-through route from registry: %s", key - ) - - @staticmethod - def clear_all_pass_through_routes(): - """Clear all pass-through routes from the registry""" - _registered_pass_through_routes.clear() - - @staticmethod - def get_all_registered_pass_through_routes() -> List[str]: - """Get all registered pass-through endpoints from the registry""" - return list(_registered_pass_through_routes.keys()) - - @staticmethod - def _build_full_path_with_root(path: str) -> str: - """ - Build full path by prepending server root path if needed. - - Args: - path: The relative path to build - - Returns: - Full path with server root prepended (if root is not "/") - """ - root_path = get_server_root_path() - if root_path == "/": - return path - return f"{root_path}{path}" - - @staticmethod - def is_registered_pass_through_route(route: str) -> bool: - """ - Check if route is a registered pass-through endpoint from DB - - Uses the in-memory registry to avoid additional DB queries - Optimized for minimal latency - - Args: - route: The route to check - - Returns: - bool: True if route is a registered pass-through endpoint, False otherwise - """ - ## CHECK IF MAPPED PASS THROUGH ENDPOINT - for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: - if route.startswith(mapped_route): - return True - - # Fast path: check if any registered route key contains this path - # Keys are in format: "{endpoint_id}:exact:{path}" or "{endpoint_id}:subpath:{path}" - # Extract unique paths from keys for quick checking - for key in _registered_pass_through_routes.keys(): - parts = key.split(":", 2) # Split into [endpoint_id, type, path] - if len(parts) == 3: - route_type = parts[1] - registered_path = InitPassThroughEndpointHelpers._build_full_path_with_root( - parts[2] - ) - if route_type == "exact" and route == registered_path: - return True - elif route_type == "subpath": - if route == registered_path or route.startswith( - registered_path + "/" - ): - return True - - return False - - @staticmethod - def get_registered_pass_through_route(route: str) -> Optional[Dict[str, Any]]: - """Get passthrough params for a given route""" - for key in _registered_pass_through_routes.keys(): - parts = key.split(":", 2) # Split into [endpoint_id, type, path] - if len(parts) == 3: - route_type = parts[1] - registered_path = InitPassThroughEndpointHelpers._build_full_path_with_root( - parts[2] - ) - - if route_type == "exact" and route == registered_path: - return _registered_pass_through_routes[key] - elif route_type == "subpath": - if route == registered_path or route.startswith( - registered_path + "/" - ): - return _registered_pass_through_routes[key] - - return None - - -def _get_combined_pass_through_endpoints( - pass_through_endpoints: Union[List[Dict], List[PassThroughGenericEndpoint]], - config_pass_through_endpoints: List[Dict], -): - """Get combined pass-through endpoints from db + config""" - return pass_through_endpoints + config_pass_through_endpoints - - -async def initialize_pass_through_endpoints( - pass_through_endpoints: Union[List[Dict], List[PassThroughGenericEndpoint]], -): - """ - 1. Create a global list of pass-through endpoints (db + config) - 2. Clear all existing pass-through endpoints from the FastAPI app routes - 3. Add new endpoints to the in-memory registry - - Initialize a list of pass-through endpoints by adding them to the FastAPI app routes - - Args: - pass_through_endpoints: List of pass-through endpoints to initialize - - Returns: - None - """ - from litellm._uuid import uuid - - verbose_proxy_logger.debug("initializing pass through endpoints") - from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes - from litellm.proxy.proxy_server import ( - app, - config_passthrough_endpoints, - premium_user, - ) - - ## get combined pass-through endpoints from db + config - combined_pass_through_endpoints: List[Union[Dict, PassThroughGenericEndpoint]] - - if config_passthrough_endpoints is not None: - combined_pass_through_endpoints = _get_combined_pass_through_endpoints( # type: ignore - pass_through_endpoints, config_passthrough_endpoints - ) - else: - combined_pass_through_endpoints = pass_through_endpoints # type: ignore - - ## clear all existing pass-through endpoints from the FastAPI app routes - # InitPassThroughEndpointHelpers.clear_all_pass_through_routes() - - # get a list of all registered pass-through endpoints - # mark the ones that are visited in the list - # remove the ones that are not visited from the list - registered_pass_through_endpoints = ( - InitPassThroughEndpointHelpers.get_all_registered_pass_through_routes() - ) - - visited_endpoints = set() - - for endpoint in combined_pass_through_endpoints: - if isinstance(endpoint, PassThroughGenericEndpoint): - endpoint = endpoint.model_dump() - - # Auto-generate ID for backwards compatibility if not present - if endpoint.get("id") is None: - endpoint["id"] = str(uuid.uuid4()) - - # Get the endpoint_id as a string (guaranteed to be set at this point) - endpoint_id: str = endpoint["id"] - - _target = endpoint.get("target", None) - _path: Optional[str] = endpoint.get("path", None) - if _path is None: - raise ValueError("Path is required for pass-through endpoint") - _custom_headers = endpoint.get("headers", None) - _custom_headers = await set_env_variables_in_header( - custom_headers=_custom_headers - ) - _forward_headers = endpoint.get("forward_headers", None) - _merge_query_params = endpoint.get("merge_query_params", None) - _auth = endpoint.get("auth", None) - _dependencies = None - if _auth is not None and str(_auth).lower() == "true": - if premium_user is not True: - raise ValueError( - "Error Setting Authentication on Pass Through Endpoint: {}".format( - CommonProxyErrors.not_premium_user.value - ) - ) - _dependencies = [Depends(user_api_key_auth)] - LiteLLMRoutes.openai_routes.value.append(_path) - - if _target is None: - continue - - # Get guardrails config if present - _guardrails = endpoint.get("guardrails", None) - - # Add exact path route - verbose_proxy_logger.debug( - "Initializing pass through endpoint: %s (ID: %s)", _path, endpoint_id - ) - InitPassThroughEndpointHelpers.add_exact_path_route( - app=app, - path=_path, - target=_target, - custom_headers=_custom_headers, - forward_headers=_forward_headers, - merge_query_params=_merge_query_params, - dependencies=_dependencies, - cost_per_request=endpoint.get("cost_per_request", None), - endpoint_id=endpoint_id, - guardrails=_guardrails, - ) - - visited_endpoints.add(f"{endpoint_id}:exact:{_path}") - - # Add wildcard route for sub-paths - if endpoint.get("include_subpath", False) is True: - InitPassThroughEndpointHelpers.add_subpath_route( - app=app, - path=_path, - target=_target, - custom_headers=_custom_headers, - forward_headers=_forward_headers, - merge_query_params=_merge_query_params, - dependencies=_dependencies, - cost_per_request=endpoint.get("cost_per_request", None), - endpoint_id=endpoint_id, - guardrails=_guardrails, - ) - - visited_endpoints.add(f"{endpoint_id}:subpath:{_path}") - - verbose_proxy_logger.debug( - "Added new pass through endpoint: %s (ID: %s)", _path, endpoint_id - ) - - # remove the ones that are not visited from the list - for endpoint_key in registered_pass_through_endpoints: - if endpoint_key not in visited_endpoints: - InitPassThroughEndpointHelpers.remove_endpoint_routes(endpoint_key) - - -def _get_pass_through_endpoints_from_config() -> List[PassThroughGenericEndpoint]: - """ - Get pass-through endpoints defined in the config file. - These are read-only and cannot be edited via the UI. - Malformed endpoints are logged and skipped; they do not crash the function. - """ - from pydantic import ValidationError - - from litellm.proxy.proxy_server import config_passthrough_endpoints - - if config_passthrough_endpoints is None or len(config_passthrough_endpoints) == 0: - return [] - - returned_endpoints: List[PassThroughGenericEndpoint] = [] - for endpoint in config_passthrough_endpoints: - try: - if isinstance(endpoint, dict): - endpoint_dict = dict(endpoint) - endpoint_dict["is_from_config"] = True - returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) - elif isinstance(endpoint, PassThroughGenericEndpoint): - # Create a copy with is_from_config=True - endpoint_dict = endpoint.model_dump() - endpoint_dict["is_from_config"] = True - returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) - except ValidationError as e: - verbose_proxy_logger.warning( - "Skipping malformed pass-through endpoint from config: %s", - e, - exc_info=False, - ) - - return returned_endpoints - - -async def _get_pass_through_endpoints_from_db( - endpoint_id: Optional[str] = None, - user_api_key_dict: Optional[UserAPIKeyAuth] = None, -) -> List[PassThroughGenericEndpoint]: - from litellm.proxy._types import LitellmUserRoles - from litellm.proxy.proxy_server import get_config_general_settings - - try: - if user_api_key_dict is None: - user_api_key_dict = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) - response: ConfigFieldInfo = await get_config_general_settings( - field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict - ) - except Exception: - return [] - - pass_through_endpoint_data: Optional[List] = response.field_value - if pass_through_endpoint_data is None: - return [] - - returned_endpoints: List[PassThroughGenericEndpoint] = [] - if endpoint_id is None: - # Return all endpoints from DB, mark as not from config - for endpoint in pass_through_endpoint_data: - if isinstance(endpoint, dict): - endpoint_dict = dict(endpoint) - endpoint_dict["is_from_config"] = False - returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) - elif isinstance(endpoint, PassThroughGenericEndpoint): - endpoint_dict = endpoint.model_dump() - endpoint_dict["is_from_config"] = False - returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) - else: - # Find specific endpoint by ID - found_endpoint = _find_endpoint_by_id(pass_through_endpoint_data, endpoint_id) - if found_endpoint is not None: - endpoint_dict = ( - found_endpoint.model_dump() - if isinstance(found_endpoint, PassThroughGenericEndpoint) - else dict(found_endpoint) - ) - endpoint_dict["is_from_config"] = False - returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) - - return returned_endpoints - - -async def _filter_endpoints_by_team_allowed_routes( - team_id: str, - pass_through_endpoints: List[PassThroughGenericEndpoint], - prisma_client, -) -> List[PassThroughGenericEndpoint]: - """ - Filter pass-through endpoints based on team's allowed_passthrough_routes metadata. - - Args: - team_id: The team ID to check permissions for - pass_through_endpoints: List of endpoints to filter - prisma_client: Database client - - Returns: - Filtered list of endpoints based on team permissions - - Raises: - HTTPException: If team is not found - """ - # retrieve team from db - team = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id}, - ) - if team is None: - raise HTTPException( - status_code=404, - detail={"error": "Team not found"}, - ) - - # retrieve team metadata - team_metadata = team.metadata - if ( - team_metadata is not None - and team_metadata.get("allowed_passthrough_routes") is not None - ): - ## FILTER pass_through_endpoints by allowed_passthrough_routes - pass_through_endpoints = [ - endpoint - for endpoint in pass_through_endpoints - if endpoint.path in team_metadata.get("allowed_passthrough_routes") - ] - - return pass_through_endpoints - - -@router.get( - "/config/pass_through_endpoint", - dependencies=[Depends(user_api_key_auth)], - response_model=PassThroughEndpointResponse, -) -@router.get( - "/config/pass_through_endpoint/team/{team_id}", - dependencies=[Depends(user_api_key_auth)], - response_model=PassThroughEndpointResponse, -) -async def get_pass_through_endpoints( - endpoint_id: Optional[str] = None, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - team_id: Optional[str] = None, -): - """ - GET configured pass through endpoint. - - If no endpoint_id given, return all configured endpoints. - """ ## Get existing pass-through endpoint field value - from litellm.proxy._types import CommonProxyErrors - from litellm.proxy.proxy_server import prisma_client - - if prisma_client is None: - raise HTTPException( - status_code=500, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - # Get endpoints from DB (editable via UI) - db_endpoints = await _get_pass_through_endpoints_from_db( - endpoint_id=endpoint_id, user_api_key_dict=user_api_key_dict - ) - - # Get endpoints from config file (read-only, not editable via UI) - config_endpoints = _get_pass_through_endpoints_from_config() - - # Merge: config endpoints not in DB + all DB endpoints (DB overrides config for same path) - db_paths = {ep.path for ep in db_endpoints} - config_only_endpoints = [ - ep for ep in config_endpoints if ep.path not in db_paths - ] - if endpoint_id is not None: - # When filtering by endpoint_id, only return if found in DB (config endpoints use generated IDs) - pass_through_endpoints = db_endpoints - else: - pass_through_endpoints = config_only_endpoints + db_endpoints - - if team_id is not None: - pass_through_endpoints = await _filter_endpoints_by_team_allowed_routes( - team_id=team_id, - pass_through_endpoints=pass_through_endpoints, - prisma_client=prisma_client, - ) - - return PassThroughEndpointResponse(endpoints=pass_through_endpoints) - - -@router.post( - "/config/pass_through_endpoint/{endpoint_id}", - dependencies=[Depends(user_api_key_auth)], -) -async def update_pass_through_endpoints( - endpoint_id: str, - data: PassThroughGenericEndpoint, - request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Update a pass-through endpoint by ID. - """ - from litellm.proxy.proxy_server import ( - get_config_general_settings, - update_config_general_settings, - ) - - ## Get existing pass-through endpoint field value - try: - response: ConfigFieldInfo = await get_config_general_settings( - field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict - ) - except Exception: - raise HTTPException( - status_code=404, - detail={"error": "No pass-through endpoints found"}, - ) - - pass_through_endpoint_data: Optional[List] = response.field_value - if pass_through_endpoint_data is None: - raise HTTPException( - status_code=404, - detail={"error": "No pass-through endpoints found"}, - ) - - # Find the endpoint to update - found_endpoint = _find_endpoint_by_id(pass_through_endpoint_data, endpoint_id) - - if found_endpoint is None: - raise HTTPException( - status_code=404, - detail={"error": f"Endpoint with ID '{endpoint_id}' not found"}, - ) - - # Find the index for updating the list - endpoint_index = None - for idx, endpoint in enumerate(pass_through_endpoint_data): - _endpoint = ( - PassThroughGenericEndpoint(**endpoint) - if isinstance(endpoint, dict) - else endpoint - ) - if _endpoint.id == endpoint_id: - endpoint_index = idx - break - - if endpoint_index is None: - raise HTTPException( - status_code=404, - detail={ - "error": f"Could not find index for endpoint with ID '{endpoint_id}'" - }, - ) - - # Get the update data as dict, excluding None values for partial updates - # Exclude is_from_config as it's a response-only field (computed at read time) - update_data = data.model_dump(exclude_none=True, exclude={"is_from_config"}) - - # Start with existing endpoint data - endpoint_dict = found_endpoint.model_dump() - - # Update with new data (only non-None values) - endpoint_dict.update(update_data) - - # Preserve existing ID if not provided in update and endpoint has ID - if "id" not in update_data and found_endpoint.id is not None: - endpoint_dict["id"] = found_endpoint.id - - # Remove is_from_config before saving - it's a response-only field (computed at read time) - endpoint_dict.pop("is_from_config", None) - - # Create updated endpoint object - updated_endpoint = PassThroughGenericEndpoint(**endpoint_dict) - - # Update the list - pass_through_endpoint_data[endpoint_index] = endpoint_dict - - # Remove old routes from registry before they get re-registered - InitPassThroughEndpointHelpers.remove_endpoint_routes(endpoint_id) - - ## Update db - updated_data = ConfigFieldUpdate( - field_name="pass_through_endpoints", - field_value=pass_through_endpoint_data, - config_type="general_settings", - ) - - await update_config_general_settings( - data=updated_data, user_api_key_dict=user_api_key_dict - ) - - # Re-register the route with updated headers - _custom_headers: Optional[dict] = updated_endpoint.headers or {} - _custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers) - - if updated_endpoint.include_subpath: - InitPassThroughEndpointHelpers.add_subpath_route( - app=request.app, - path=updated_endpoint.path, - target=updated_endpoint.target, - custom_headers=_custom_headers, - forward_headers=None, # Defaults not available in model? assuming None logic handles it - merge_query_params=None, - dependencies=None, - cost_per_request=updated_endpoint.cost_per_request, - endpoint_id=updated_endpoint.id or endpoint_id or "", - guardrails=getattr(updated_endpoint, "guardrails", None), - ) - else: - InitPassThroughEndpointHelpers.add_exact_path_route( - app=request.app, - path=updated_endpoint.path, - target=updated_endpoint.target, - custom_headers=_custom_headers, - forward_headers=None, - merge_query_params=None, - dependencies=None, - cost_per_request=updated_endpoint.cost_per_request, - endpoint_id=updated_endpoint.id or endpoint_id or "", - guardrails=getattr(updated_endpoint, "guardrails", None), - ) - - return PassThroughEndpointResponse( - endpoints=[updated_endpoint] if updated_endpoint else [] - ) - - -@router.post( - "/config/pass_through_endpoint", - dependencies=[Depends(user_api_key_auth)], -) -async def create_pass_through_endpoints( - data: PassThroughGenericEndpoint, - request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Create new pass-through endpoint - """ - from litellm._uuid import uuid - from litellm.proxy.proxy_server import ( - get_config_general_settings, - update_config_general_settings, - ) - - ## Get existing pass-through endpoint field value - - try: - response: ConfigFieldInfo = await get_config_general_settings( - field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict - ) - except Exception: - response = ConfigFieldInfo( - field_name="pass_through_endpoints", field_value=None - ) - - ## Auto-generate ID if not provided - # Exclude is_from_config as it's a response-only field (computed at read time) - data_dict = data.model_dump(exclude={"is_from_config"}) - if data_dict.get("id") is None: - data_dict["id"] = str(uuid.uuid4()) - - if response.field_value is None: - response.field_value = [data_dict] - elif isinstance(response.field_value, List): - response.field_value.append(data_dict) - - ## Update db - updated_data = ConfigFieldUpdate( - field_name="pass_through_endpoints", - field_value=response.field_value, - config_type="general_settings", - ) - await update_config_general_settings( - data=updated_data, user_api_key_dict=user_api_key_dict - ) - - # Return the created endpoint with the generated ID - created_endpoint = PassThroughGenericEndpoint(**data_dict) - - # Register the new route - _custom_headers: Optional[dict] = created_endpoint.headers or {} - _custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers) - - if created_endpoint.include_subpath: - InitPassThroughEndpointHelpers.add_subpath_route( - app=request.app, - path=created_endpoint.path, - target=created_endpoint.target, - custom_headers=_custom_headers, - forward_headers=None, - merge_query_params=None, - dependencies=None, - cost_per_request=created_endpoint.cost_per_request, - endpoint_id=created_endpoint.id or "", - guardrails=getattr(created_endpoint, "guardrails", None), - ) - else: - InitPassThroughEndpointHelpers.add_exact_path_route( - app=request.app, - path=created_endpoint.path, - target=created_endpoint.target, - custom_headers=_custom_headers, - forward_headers=None, - merge_query_params=None, - dependencies=None, - cost_per_request=created_endpoint.cost_per_request, - endpoint_id=created_endpoint.id or "", - guardrails=getattr(created_endpoint, "guardrails", None), - ) - - return PassThroughEndpointResponse(endpoints=[created_endpoint]) - - -@router.delete( - "/config/pass_through_endpoint", - dependencies=[Depends(user_api_key_auth)], - response_model=PassThroughEndpointResponse, -) -async def delete_pass_through_endpoints( - endpoint_id: str, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Delete a pass-through endpoint by ID. - - Returns - the deleted endpoint - """ - from litellm.proxy.proxy_server import ( - get_config_general_settings, - update_config_general_settings, - ) - - ## Get existing pass-through endpoint field value - - try: - response: ConfigFieldInfo = await get_config_general_settings( - field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict - ) - except Exception: - response = ConfigFieldInfo( - field_name="pass_through_endpoints", field_value=None - ) - - ## Update field by removing endpoint - pass_through_endpoint_data: Optional[List] = response.field_value - if response.field_value is None or pass_through_endpoint_data is None: - raise HTTPException( - status_code=400, - detail={"error": "There are no pass-through endpoints setup."}, - ) - - # Find the endpoint to delete - found_endpoint = _find_endpoint_by_id(pass_through_endpoint_data, endpoint_id) - - if found_endpoint is None: - raise HTTPException( - status_code=400, - detail={ - "error": "Endpoint with ID '{}' was not found in pass-through endpoint list.".format( - endpoint_id - ) - }, - ) - - # Find the index for deleting from the list - endpoint_index = None - for idx, endpoint in enumerate(pass_through_endpoint_data): - _endpoint = ( - PassThroughGenericEndpoint(**endpoint) - if isinstance(endpoint, dict) - else endpoint - ) - if _endpoint.id == endpoint_id: - endpoint_index = idx - break - - if endpoint_index is None: - raise HTTPException( - status_code=400, - detail={ - "error": f"Could not find index for endpoint with ID '{endpoint_id}'" - }, - ) - - # Remove the endpoint - pass_through_endpoint_data.pop(endpoint_index) - response_obj = found_endpoint - - # Remove routes from registry - InitPassThroughEndpointHelpers.remove_endpoint_routes(endpoint_id) - - ## Update db - updated_data = ConfigFieldUpdate( - field_name="pass_through_endpoints", - field_value=pass_through_endpoint_data, - config_type="general_settings", - ) - await update_config_general_settings( - data=updated_data, user_api_key_dict=user_api_key_dict - ) - - return PassThroughEndpointResponse(endpoints=[response_obj]) - - -def _find_endpoint_by_id( - endpoints_data: List, - endpoint_id: str, -) -> Optional[PassThroughGenericEndpoint]: - """ - Find an endpoint by ID. - - Args: - endpoints_data: List of endpoint data (dicts or PassThroughGenericEndpoint objects) - endpoint_id: ID to search for - - Returns: - Found endpoint or None if not found - """ - for endpoint in endpoints_data: - _endpoint: Optional[PassThroughGenericEndpoint] = None - if isinstance(endpoint, dict): - _endpoint = PassThroughGenericEndpoint(**endpoint) - elif isinstance(endpoint, PassThroughGenericEndpoint): - _endpoint = endpoint - - # Only compare IDs to IDs - if _endpoint is not None and _endpoint.id == endpoint_id: - return _endpoint - - return None - - -async def initialize_pass_through_endpoints_in_db(): - """ - Gets all pass-through endpoints from db and initializes them in the proxy server. - """ - pass_through_endpoints = await _get_pass_through_endpoints_from_db() - await initialize_pass_through_endpoints( - pass_through_endpoints=pass_through_endpoints - ) +import ast +import asyncio +import copy +import json +import traceback +from base64 import b64encode +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, Union, cast +from urllib.parse import urlencode, urlparse + +import httpx +from fastapi import ( + APIRouter, + Depends, + FastAPI, + HTTPException, + Request, + Response, + UploadFile, + WebSocket, + status, +) +from fastapi.responses import StreamingResponse +from starlette.datastructures import UploadFile as StarletteUploadFile +from starlette.websockets import WebSocketState +from websockets.asyncio.client import connect +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, +) + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid +from litellm.constants import MAXIMUM_TRACEBACK_LINES_TO_LOG +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client +from litellm.passthrough import BasePassthroughUtils +from litellm.proxy._types import ( + ConfigFieldInfo, + ConfigFieldUpdate, + LiteLLMRoutes, + PassThroughEndpointResponse, + PassThroughGenericEndpoint, + ProxyException, + UserAPIKeyAuth, +) +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.utils import get_server_root_path +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.custom_http import httpxSpecialProvider +from litellm.types.passthrough_endpoints.pass_through_endpoints import ( + EndpointType, + PassthroughStandardLoggingPayload, +) +from litellm.types.utils import StandardLoggingUserAPIKeyMetadata + +from .streaming_handler import PassThroughStreamingHandler +from .success_handler import PassThroughEndpointLogging + +router = APIRouter() + +pass_through_endpoint_logging = PassThroughEndpointLogging() + +# Global registry to track registered pass-through routes and prevent memory leaks +_registered_pass_through_routes: Dict[str, Dict[str, Union[str, Dict[str, Any]]]] = {} + + +def get_response_body(response: httpx.Response) -> Optional[dict]: + try: + return response.json() + except Exception: + return None + + +async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]: + """ + checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc + + only runs for headers defined on config.yaml + + example header can be + + {"Authorization": "Bearer os.environ/COHERE_API_KEY"} + """ + if custom_headers is None: + return None + headers = {} + for key, value in custom_headers.items(): + # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys + # we can then get the b64 encoded keys here + if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY": + # langfuse requires b64 encoded headers - we construct that here + _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"] + _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"] + if isinstance( + _langfuse_public_key, str + ) and _langfuse_public_key.startswith("os.environ/"): + _langfuse_public_key = get_secret_str(_langfuse_public_key) + if isinstance( + _langfuse_secret_key, str + ) and _langfuse_secret_key.startswith("os.environ/"): + _langfuse_secret_key = get_secret_str(_langfuse_secret_key) + headers["Authorization"] = "Basic " + b64encode( + f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") + ).decode("ascii") + else: + # for all other headers + headers[key] = value + if isinstance(value, str) and "os.environ/" in value: + verbose_proxy_logger.debug( + "pass through endpoint - looking up 'os.environ/' variable" + ) + # get string section that is os.environ/ + start_index = value.find("os.environ/") + _variable_name = value[start_index:] + + verbose_proxy_logger.debug( + "pass through endpoint - getting secret for variable name: %s", + _variable_name, + ) + _secret_value = get_secret_str(_variable_name) + if _secret_value is not None: + new_value = value.replace(_variable_name, _secret_value) + headers[key] = new_value + return headers + + +async def chat_completion_pass_through_endpoint( # noqa: PLR0915 + fastapi_response: Response, + request: Request, + adapter_id: str, + user_api_key_dict: UserAPIKeyAuth, +): + from litellm.proxy.proxy_server import ( + add_litellm_data_to_request, + general_settings, + llm_router, + proxy_config, + proxy_logging_obj, + user_api_base, + user_max_tokens, + user_model, + user_request_timeout, + user_temperature, + version, + ) + + data = {} + try: + body = await request.body() + body_str = body.decode() + try: + data = ast.literal_eval(body_str) + except Exception: + data = json.loads(body_str) + + data["adapter_id"] = adapter_id + + verbose_proxy_logger.debug( + "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), + ) + data["model"] = ( + general_settings.get("completion_model", None) # server default + or user_model # model name passed via cli args + or data.get("model", None) # default passed in http request + ) + if user_model: + data["model"] = user_model + + data = await add_litellm_data_to_request( + data=data, # type: ignore + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) + + # override with user settings, these are params passed via cli + if user_temperature: + data["temperature"] = user_temperature + if user_request_timeout: + data["request_timeout"] = user_request_timeout + if user_max_tokens: + data["max_tokens"] = user_max_tokens + if user_api_base: + data["api_base"] = user_api_base + + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + + # Check key-specific aliases + if ( + isinstance(data["model"], str) + and user_api_key_dict.aliases + and isinstance(user_api_key_dict.aliases, dict) + and data["model"] in user_api_key_dict.aliases + ): + data["model"] = user_api_key_dict.aliases[data["model"]] + + ### CALL HOOKS ### - modify incoming data before calling the model + data = await proxy_logging_obj.pre_call_hook( # type: ignore + user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" + ) + + ### ROUTE THE REQUESTs ### + router_model_names = llm_router.model_names if llm_router is not None else [] + # skip router if user passed their key + if "api_key" in data: + llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) + elif ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif llm_router is not None and llm_router.has_model_id( + data["model"] + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None + and data["model"] not in router_model_names + and ( + llm_router.default_deployment is not None + or len(llm_router.pattern_router.patterns) > 0 + ) + ): # check for wildcard routes or default deployment before checking deployment_names + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router (lowest priority) + llm_response = asyncio.create_task( + llm_router.aadapter_completion(**data, specific_deployment=True) + ) + elif user_model is not None: # `litellm --model ` + llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "completion: Invalid model name passed in model=" + + data.get("model", "") + }, + ) + + # Await the llm_response task + response = await llm_response + + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" + + ### ALERTING ### + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) + + verbose_proxy_logger.debug("final response: %s", response) + + fastapi_response.headers.update( + ProxyBaseLLMRequestProcessing.get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + response_cost=response_cost, + ) + ) + + verbose_proxy_logger.debug("\nResponse from Litellm:\n{}".format(response)) + return response + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.completion(): Exception occured - {}".format( + str(e) + ) + ) + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +class HttpPassThroughEndpointHelpers(BasePassthroughUtils): + @staticmethod + def get_response_headers( + headers: httpx.Headers, + litellm_call_id: Optional[str] = None, + custom_headers: Optional[dict] = None, + ) -> dict: + excluded_headers = {"transfer-encoding", "content-encoding"} + + return_headers = { + key: value + for key, value in headers.items() + if key.lower() not in excluded_headers + } + if litellm_call_id: + return_headers["x-litellm-call-id"] = litellm_call_id + if custom_headers: + return_headers.update(custom_headers) + + return return_headers + + @staticmethod + def get_endpoint_type(url: str) -> EndpointType: + parsed_url = urlparse(url) + if ( + ("generateContent") in url + or ("streamGenerateContent") in url + or ("rawPredict") in url + or ("streamRawPredict") in url + ): + return EndpointType.VERTEX_AI + elif parsed_url.hostname == "api.anthropic.com": + return EndpointType.ANTHROPIC + elif ( + parsed_url.hostname == "api.openai.com" + or parsed_url.hostname == "openai.azure.com" + or (parsed_url.hostname and "openai.com" in parsed_url.hostname) + ): + return EndpointType.OPENAI + return EndpointType.GENERIC + + @staticmethod + async def _make_non_streaming_http_request( + request: Request, + async_client: httpx.AsyncClient, + url: str, + headers: dict, + requested_query_params: Optional[dict] = None, + custom_body: Optional[dict] = None, + ) -> httpx.Response: + """ + Make a non-streaming HTTP request + + If request is GET, don't include a JSON body + """ + if request.method == "GET": + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + ) + else: + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + json=custom_body, + ) + return response + + @staticmethod + async def non_streaming_http_request_handler( + request: Request, + async_client: httpx.AsyncClient, + url: httpx.URL, + headers: dict, + requested_query_params: Optional[dict] = None, + _parsed_body: Optional[dict] = None, + ) -> httpx.Response: + """ + Handle non-streaming HTTP requests + + Handles special cases when GET requests, multipart/form-data requests, and generic httpx requests + """ + if request.method == "GET": + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + ) + elif HttpPassThroughEndpointHelpers.is_multipart(request) is True: + return await HttpPassThroughEndpointHelpers.make_multipart_http_request( + request=request, + async_client=async_client, + url=url, + headers=headers, + requested_query_params=requested_query_params, + ) + else: + # Generic httpx method + response = await async_client.request( + method=request.method, + url=url, + headers=headers, + params=requested_query_params, + json=_parsed_body, + ) + return response + + @staticmethod + def is_multipart(request: Request) -> bool: + """Check if the request is a multipart/form-data request""" + return "multipart/form-data" in request.headers.get("content-type", "") + + @staticmethod + async def _build_request_files_from_upload_file( + upload_file: Union[UploadFile, StarletteUploadFile], + ) -> Tuple[Optional[str], bytes, Optional[str]]: + """Build a request files dict from an UploadFile object""" + file_content = await upload_file.read() + return (upload_file.filename, file_content, upload_file.content_type) + + @staticmethod + async def make_multipart_http_request( + request: Request, + async_client: httpx.AsyncClient, + url: httpx.URL, + headers: dict, + requested_query_params: Optional[dict] = None, + ) -> httpx.Response: + """Process multipart/form-data requests, handling both files and form fields""" + form_data = await request.form() + files = {} + form_data_dict = {} + + for field_name, field_value in form_data.items(): + if isinstance(field_value, (StarletteUploadFile, UploadFile)): + files[ + field_name + ] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file=field_value + ) + else: + form_data_dict[field_name] = field_value + + # Remove content-type header - httpx will set it correctly with the new boundary + # when it creates the multipart body from files/data parameters + headers_copy = headers.copy() + headers_copy.pop("content-type", None) + + response = await async_client.request( + method=request.method, + url=url, + headers=headers_copy, + params=requested_query_params, + files=files, + data=form_data_dict, + ) + return response + + @staticmethod + def _init_kwargs_for_pass_through_endpoint( + request: Request, + user_api_key_dict: UserAPIKeyAuth, + passthrough_logging_payload: PassthroughStandardLoggingPayload, + logging_obj: LiteLLMLoggingObj, + _parsed_body: Optional[dict] = None, + litellm_call_id: Optional[str] = None, + ) -> dict: + """ + Filter out litellm params from the request body + """ + from litellm.types.utils import all_litellm_params + + _parsed_body = _parsed_body or {} + + litellm_params_in_body = {} + for k in all_litellm_params: + if k in _parsed_body: + litellm_params_in_body[k] = _parsed_body.pop(k, None) + + _metadata = dict( + StandardLoggingUserAPIKeyMetadata( + user_api_key_hash=user_api_key_dict.api_key, + user_api_key_alias=user_api_key_dict.key_alias, + user_api_key_user_email=user_api_key_dict.user_email, + user_api_key_user_id=user_api_key_dict.user_id, + user_api_key_team_id=user_api_key_dict.team_id, + user_api_key_org_id=user_api_key_dict.org_id, + user_api_key_project_id=user_api_key_dict.project_id, + user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, + user_api_key_request_route=user_api_key_dict.request_route, + user_api_key_spend=user_api_key_dict.spend, + user_api_key_max_budget=user_api_key_dict.max_budget, + user_api_key_budget_reset_at=( + user_api_key_dict.budget_reset_at.isoformat() + if user_api_key_dict.budget_reset_at + else None + ), + user_api_key_auth_metadata=user_api_key_dict.metadata, + ) + ) + + _metadata["user_api_key"] = user_api_key_dict.api_key + + litellm_metadata = litellm_params_in_body.pop("litellm_metadata", None) + metadata = litellm_params_in_body.pop("metadata", None) + if litellm_metadata: + _metadata.update(litellm_metadata) + if metadata: + _metadata.update(metadata) + + _metadata = _update_metadata_with_tags_in_header( + request=request, + metadata=_metadata, + ) + + kwargs = { + "litellm_params": { + **litellm_params_in_body, # type: ignore + "metadata": _metadata, + "proxy_server_request": { + "url": str(request.url), + "method": request.method, + "body": copy.copy(_parsed_body), # use copy instead of deepcopy + "headers": request.headers, + }, + }, + "call_type": "pass_through_endpoint", + "litellm_call_id": litellm_call_id, + "passthrough_logging_payload": passthrough_logging_payload, + } + + logging_obj.model_call_details[ + "passthrough_logging_payload" + ] = passthrough_logging_payload + + return kwargs + + @staticmethod + def construct_target_url_with_subpath( + base_target: str, subpath: str, include_subpath: Optional[bool] + ) -> str: + """ + Helper function to construct the full target URL with subpath handling. + + Args: + base_target: The base target URL + subpath: The captured subpath from the request + include_subpath: Whether to include the subpath in the target URL + + Returns: + The constructed full target URL + """ + if not include_subpath: + return base_target + + if not subpath: + return base_target + + # Ensure base_target ends with / and subpath doesn't start with / + if not base_target.endswith("/"): + base_target = base_target + "/" + if subpath.startswith("/"): + subpath = subpath[1:] + + return base_target + subpath + + @staticmethod + def _update_stream_param_based_on_request_body( + parsed_body: dict, + stream: Optional[bool] = None, + ) -> Optional[bool]: + """ + If stream is provided in the request body, use it. + Otherwise, use the stream parameter passed to the `pass_through_request` function + """ + if "stream" in parsed_body: + return parsed_body.get("stream", stream) + return stream + + +async def pass_through_request( # noqa: PLR0915 + request: Request, + target: str, + custom_headers: dict, + user_api_key_dict: UserAPIKeyAuth, + custom_body: Optional[dict] = None, + forward_headers: Optional[bool] = False, + merge_query_params: Optional[bool] = False, + query_params: Optional[dict] = None, + stream: Optional[bool] = None, + cost_per_request: Optional[float] = None, + custom_llm_provider: Optional[str] = None, + guardrails_config: Optional[dict] = None, +): + """ + Pass through endpoint handler, makes the httpx request for pass-through endpoints and ensures logging hooks are called + + Args: + request: The incoming request + target: The target URL + custom_headers: The custom headers + user_api_key_dict: The user API key dictionary + custom_body: The custom body + forward_headers: Whether to forward headers + merge_query_params: Whether to merge query params + query_params: The query params + stream: Whether to stream the response + cost_per_request: Optional field - cost per request to the target endpoint + custom_llm_provider: Optional field - custom LLM provider for the endpoint + guardrails_config: Optional field - guardrails configuration for passthrough endpoint + """ + from litellm.litellm_core_utils.litellm_logging import Logging + from litellm.proxy.pass_through_endpoints.passthrough_guardrails import ( + PassthroughGuardrailHandler, + ) + from litellm.proxy.proxy_server import proxy_logging_obj + + ######################################################### + # Initialize variables + ######################################################### + litellm_call_id = str(uuid.uuid4()) + url: Optional[httpx.URL] = None + + # parsed request body + _parsed_body: Optional[dict] = None + # kwargs for pass through endpoint, contains metadata, litellm_params, call_type, litellm_call_id, passthrough_logging_payload + kwargs: Optional[dict] = None + + ######################################################### + try: + url = httpx.URL(target) + headers = custom_headers + headers = HttpPassThroughEndpointHelpers.forward_headers_from_request( + request_headers=dict(request.headers), + headers=headers, + forward_headers=forward_headers, + ) + + if merge_query_params: + # Create a new URL with the merged query params + url = url.copy_with( + query=urlencode( + HttpPassThroughEndpointHelpers.get_merged_query_parameters( + existing_url=url, + request_query_params=dict(request.query_params), + ) + ).encode("ascii") + ) + + endpoint_type: EndpointType = HttpPassThroughEndpointHelpers.get_endpoint_type( + str(url) + ) + + if custom_body: + _parsed_body = custom_body + else: + _parsed_body = await _read_request_body(request) + verbose_proxy_logger.debug( + "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( + url, headers, _parsed_body + ) + ) + + ### COLLECT GUARDRAILS FOR PASSTHROUGH ENDPOINT ### + # Passthrough endpoints are opt-in only for guardrails + # When enabled, collect guardrails from org/team/key levels + passthrough-specific + guardrails_to_run = PassthroughGuardrailHandler.collect_guardrails( + user_api_key_dict=user_api_key_dict, + passthrough_guardrails_config=guardrails_config, + ) + + # Add guardrails to metadata if any should run + if guardrails_to_run and len(guardrails_to_run) > 0: + if _parsed_body is None: + _parsed_body = {} + if "metadata" not in _parsed_body: + _parsed_body["metadata"] = {} + _parsed_body["metadata"]["guardrails"] = guardrails_to_run + verbose_proxy_logger.debug( + f"Added guardrails to passthrough request metadata: {guardrails_to_run}" + ) + + ## LOGGING OBJECT ## - initialize before pre_call_hook so guardrails can access it + start_time = datetime.now() + logging_obj = Logging( + model="unknown", + messages=[{"role": "user", "content": safe_dumps(_parsed_body)}], + stream=False, + call_type="pass_through_endpoint", + start_time=start_time, + litellm_call_id=litellm_call_id, + function_id="1245", + ) + + # Store passthrough guardrails config on logging_obj for field targeting + logging_obj.passthrough_guardrails_config = guardrails_config + + # Store logging_obj in data so guardrails can access it + if _parsed_body is None: + _parsed_body = {} + _parsed_body["litellm_logging_obj"] = logging_obj + + ### CALL HOOKS ### - modify incoming data / reject request before calling the model + _parsed_body = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, + data=_parsed_body, + call_type="pass_through_endpoint", + ) + async_client_obj = get_async_httpx_client( + llm_provider=httpxSpecialProvider.PassThroughEndpoint, + params={"timeout": 600}, + ) + async_client = async_client_obj.client + passthrough_logging_payload = PassthroughStandardLoggingPayload( + url=str(url), + request_body=_parsed_body, + request_method=getattr(request, "method", None), + cost_per_request=cost_per_request, + ) + kwargs = HttpPassThroughEndpointHelpers._init_kwargs_for_pass_through_endpoint( + user_api_key_dict=user_api_key_dict, + _parsed_body=_parsed_body, + passthrough_logging_payload=passthrough_logging_payload, + litellm_call_id=litellm_call_id, + request=request, + logging_obj=logging_obj, + ) + + # Store custom_llm_provider in kwargs and logging object if provided + if custom_llm_provider: + logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider + logging_obj.model_call_details["litellm_params"] = kwargs.get( + "litellm_params", {} + ) + + # done for supporting 'parallel_request_limiter.py' with pass-through endpoints + logging_obj.update_environment_variables( + model="unknown", + user="unknown", + optional_params={}, + litellm_params=kwargs["litellm_params"], + call_type="pass_through_endpoint", + ) + logging_obj.model_call_details["litellm_call_id"] = litellm_call_id + + # combine url with query params for logging + requested_query_params: Optional[dict] = query_params or dict( + request.query_params + ) + + requested_query_params_str = None + if requested_query_params: + requested_query_params_str = "&".join( + f"{k}={v}" for k, v in requested_query_params.items() + ) + + logging_url = str(url) + if requested_query_params_str: + if "?" in str(url): + logging_url = str(url) + "&" + requested_query_params_str + else: + logging_url = str(url) + "?" + requested_query_params_str + + logging_obj.pre_call( + input=[{"role": "user", "content": safe_dumps(_parsed_body)}], + api_key="", + additional_args={ + "complete_input_dict": _parsed_body, + "api_base": str(logging_url), + "headers": headers, + }, + ) + stream = ( + HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body( + parsed_body=_parsed_body, + stream=stream, + ) + ) + + if stream: + req = async_client.build_request( + "POST", + url, + json=_parsed_body, + params=requested_query_params, + headers=headers, + ) + + response = await async_client.send(req, stream=stream) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, detail=await e.response.aread() + ) + + return StreamingResponse( + PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=_parsed_body, + litellm_logging_obj=logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), + ), + headers=HttpPassThroughEndpointHelpers.get_response_headers( + headers=response.headers, + litellm_call_id=litellm_call_id, + ), + status_code=response.status_code, + ) + + response = ( + await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler( + request=request, + async_client=async_client, + url=url, + headers=headers, + requested_query_params=requested_query_params, + _parsed_body=_parsed_body, + ) + ) + verbose_proxy_logger.debug("response.headers= %s", response.headers) + + if _is_streaming_response(response) is True: + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, detail=await e.response.aread() + ) + + return StreamingResponse( + PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=_parsed_body, + litellm_logging_obj=logging_obj, + endpoint_type=endpoint_type, + start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), + ), + headers=HttpPassThroughEndpointHelpers.get_response_headers( + headers=response.headers, + litellm_call_id=litellm_call_id, + ), + status_code=response.status_code, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise HTTPException( + status_code=e.response.status_code, detail=e.response.text + ) + + if response.status_code >= 300: + raise HTTPException(status_code=response.status_code, detail=response.text) + + content = await response.aread() + + ## LOG SUCCESS + response_body: Optional[dict] = get_response_body(response) + passthrough_logging_payload["response_body"] = response_body + end_time = datetime.now() + asyncio.create_task( + pass_through_endpoint_logging.pass_through_async_success_handler( + httpx_response=response, + response_body=response_body, + url_route=str(url), + result="", + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + cache_hit=False, + request_body=_parsed_body, + custom_llm_provider=custom_llm_provider, + **kwargs, + ) + ) + + ## CUSTOM HEADERS - `x-litellm-*` + custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( + user_api_key_dict=user_api_key_dict, + call_id=litellm_call_id, + model_id=None, + cache_key=None, + api_base=str(url._uri_reference), + ) + + return Response( + content=content, + status_code=response.status_code, + headers=HttpPassThroughEndpointHelpers.get_response_headers( + headers=response.headers, + custom_headers=custom_headers, + ), + ) + except Exception as e: + custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( + user_api_key_dict=user_api_key_dict, + call_id=litellm_call_id, + model_id=None, + cache_key=None, + api_base=str(url._uri_reference) if url else None, + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format( + str(e) + ) + ) + + ######################################################### + # Monitoring: Trigger post_call_failure_hook + # for pass through endpoint failure + ######################################################### + request_payload: dict = _parsed_body or {} + # add user_api_key_dict, litellm_call_id, passthrough_logging_payloa for logging + if kwargs: + for key, value in kwargs.items(): + request_payload[key] = value + + if ( + "model" not in request_payload + and _parsed_body + and isinstance(_parsed_body, dict) + ): + request_payload["model"] = _parsed_body.get("model", "") + if "custom_llm_provider" not in request_payload and custom_llm_provider: + request_payload["custom_llm_provider"] = custom_llm_provider + + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=request_payload, + traceback_str=traceback.format_exc( + limit=MAXIMUM_TRACEBACK_LINES_TO_LOG, + ), + ) + + ######################################################### + + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + headers=custom_headers, + ) + else: + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + headers=custom_headers, + ) + + +def _update_metadata_with_tags_in_header(request: Request, metadata: dict) -> dict: + """ + If tags are in the request headers, add them to the metadata + + Used for google and vertex JS SDKs, and Azure passthrough + Checks both 'tags' and 'x-litellm-tags' headers + """ + tags_to_add = [] + + # Check for 'tags' header first + _tags = request.headers.get("tags") + if _tags: + tags_to_add.extend([tag.strip() for tag in _tags.split(",")]) + + _tags = request.headers.get("x-litellm-tags") + if _tags: + tags_to_add.extend([tag.strip() for tag in _tags.split(",")]) + + # Only add tags key if there are tags to add + if tags_to_add: + if "tags" not in metadata: + metadata["tags"] = [] + metadata["tags"].extend(tags_to_add) + + return metadata + + +async def _parse_request_data_by_content_type( + request: Request, +) -> Tuple[Optional[Any], Optional[Any], Optional[Any], Optional[Any]]: + """ + Parse request data based on content type. + + Handles JSON, multipart/form-data, and URL-encoded form data. + + Returns: + Tuple of (query_params_data, custom_body_data, file_data, stream) + """ + content_type = request.headers.get("content-type", "") + + query_params_data = None + custom_body_data = None + file_data = None + stream = None + + if "application/json" in content_type: + # ✅ Handle JSON + try: + body = await request.json() + query_params_data = body.get("query_params") + custom_body_data = body.get("custom_body") + stream = body.get("stream") + except json.JSONDecodeError: + # Handle requests with no body (e.g., DELETE requests) + pass + elif "multipart/form-data" in content_type: + # ✅ Handle multipart form-data + form = await request.form() + if "query_params" in form: + form_value = form["query_params"] + if isinstance(form_value, str): + try: + query_params_data = json.loads(form_value) + except Exception: + query_params_data = form_value + else: + query_params_data = form_value + + if "custom_body" in form: + form_value = form["custom_body"] + if isinstance(form_value, str): + try: + custom_body_data = json.loads(form_value) + except Exception: + custom_body_data = form_value + else: + custom_body_data = form_value + + if "file" in form: + file_data = form["file"] # this is a Starlette UploadFile object + + elif "application/x-www-form-urlencoded" in content_type: + # ✅ Handle URL-encoded form data + form = await request.form() + query_params_data = form.get("query_params") + custom_body_data = form.get("custom_body") + + else: + # ✅ Fallback: maybe no body, just query params + query_params_data = dict(request.query_params) or None + + return query_params_data, custom_body_data, file_data, stream + + +def create_pass_through_route( + endpoint, + target: str, + custom_headers: Optional[dict] = None, + _forward_headers: Optional[bool] = False, + _merge_query_params: Optional[bool] = False, + dependencies: Optional[List] = None, + include_subpath: Optional[bool] = False, + cost_per_request: Optional[float] = None, + custom_llm_provider: Optional[str] = None, + is_streaming_request: Optional[bool] = False, + query_params: Optional[dict] = None, + guardrails: Optional[Dict[str, Any]] = None, +): + # check if target is an adapter.py or a url + from litellm._uuid import uuid + from litellm.proxy.types_utils.utils import get_instance_fn + + try: + if isinstance(target, CustomLogger): + adapter = target + else: + adapter = get_instance_fn(value=target) + adapter_id = str(uuid.uuid4()) + litellm.adapters = [{"id": adapter_id, "adapter": adapter}] + + async def endpoint_func( # type: ignore + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + subpath: str = "", # captures sub-paths when include_subpath=True + custom_body: Optional[ + dict + ] = None, # accepted for signature compatibility with URL-based path; not forwarded because chat_completion_pass_through_endpoint does not support it + ): + return await chat_completion_pass_through_endpoint( + fastapi_response=fastapi_response, + request=request, + adapter_id=adapter_id, + user_api_key_dict=user_api_key_dict, + ) + + except Exception: + verbose_proxy_logger.debug("Defaulting to target being a url.") + + async def endpoint_func( # type: ignore + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + subpath: str = "", # captures sub-paths when include_subpath=True + custom_body: Optional[dict] = None, + ): + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + InitPassThroughEndpointHelpers, + ) + + path = request.url.path + + # Parse request data based on content type + ( + query_params_data, + custom_body_data, + file_data, + stream, + ) = await _parse_request_data_by_content_type(request) + + if not InitPassThroughEndpointHelpers.is_registered_pass_through_route( + route=path + ): + raise HTTPException( + status_code=404, + detail=f"Pass-through endpoint {endpoint} not found. This could have been deleted or not yet added to the proxy.", + ) + + passthrough_params = ( + InitPassThroughEndpointHelpers.get_registered_pass_through_route( + route=path + ) + ) + target_params = { + "target": target, + "custom_headers": custom_headers, + "forward_headers": _forward_headers, + "merge_query_params": _merge_query_params, + "cost_per_request": cost_per_request, + "guardrails": None, + } + + if passthrough_params is not None: + target_params.update(passthrough_params.get("passthrough_params", {})) + + # Extract and cast parameters with proper types + param_target = target_params.get("target") or target + param_custom_headers = target_params.get("custom_headers", custom_headers) + param_forward_headers = target_params.get( + "forward_headers", _forward_headers + ) + param_merge_query_params = target_params.get( + "merge_query_params", _merge_query_params + ) + param_cost_per_request = target_params.get( + "cost_per_request", cost_per_request + ) + param_guardrails = target_params.get("guardrails", None) + + # Construct the full target URL with subpath if needed + full_target = ( + HttpPassThroughEndpointHelpers.construct_target_url_with_subpath( + base_target=cast(str, param_target), + subpath=subpath, + include_subpath=include_subpath, + ) + ) + + # Ensure custom_headers is a dict + headers_dict = ( + param_custom_headers if isinstance(param_custom_headers, dict) else {} + ) + + # Ensure query_params and custom_body are dicts or None + final_query_params = ( + query_params_data if isinstance(query_params_data, dict) else {} + ) + if query_params: + final_query_params.update(query_params) + # When a caller (e.g. bedrock_proxy_route) supplies a pre-built + # body, use it instead of the body parsed from the raw request. + final_custom_body: Optional[dict] = None + if custom_body is not None: + final_custom_body = custom_body + elif isinstance(custom_body_data, dict): + final_custom_body = custom_body_data + + return await pass_through_request( # type: ignore + request=request, + target=full_target, + custom_headers=headers_dict, + user_api_key_dict=user_api_key_dict, + forward_headers=cast(Optional[bool], param_forward_headers), + merge_query_params=cast(Optional[bool], param_merge_query_params), + query_params=final_query_params, + stream=is_streaming_request or stream, + custom_body=final_custom_body, + cost_per_request=cast(Optional[float], param_cost_per_request), + custom_llm_provider=custom_llm_provider, + guardrails_config=cast(Optional[dict], param_guardrails), + ) + + return endpoint_func + + +def create_websocket_passthrough_route( + endpoint: str, + target: str, + custom_headers: Optional[dict] = None, + _forward_headers: Optional[bool] = False, + dependencies: Optional[List] = None, + cost_per_request: Optional[float] = None, +): + """ + Create a WebSocket passthrough route function. + + Args: + endpoint: The endpoint path (for logging purposes) + target: The target WebSocket URL (e.g., "wss://api.example.com/ws") + custom_headers: Custom headers to include in the WebSocket connection + _forward_headers: Whether to forward incoming headers + dependencies: FastAPI dependencies to inject + + Returns: + A WebSocket passthrough function that can be registered with app.websocket() + """ + from litellm.proxy.auth.user_api_key_auth import user_api_key_auth_websocket + + async def websocket_endpoint_func( + websocket: WebSocket, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth_websocket), + **kwargs, # For additional query parameters + ): + """ + WebSocket passthrough endpoint function. + + This function handles the WebSocket connection by: + 1. Accepting the incoming WebSocket connection + 2. Establishing a connection to the target WebSocket + 3. Forwarding messages bidirectionally + 4. Handling connection cleanup + """ + return await websocket_passthrough_request( + websocket=websocket, + target=target, + custom_headers=custom_headers or {}, + user_api_key_dict=user_api_key_dict, + forward_headers=_forward_headers, + endpoint=endpoint, + cost_per_request=cost_per_request, + accept_websocket=True, # Generic usage should accept the WebSocket + ) + + return websocket_endpoint_func + + +async def websocket_passthrough_request( # noqa: PLR0915 + websocket: WebSocket, + target: str, + custom_headers: dict, + user_api_key_dict: UserAPIKeyAuth, + forward_headers: Optional[bool] = False, + endpoint: Optional[str] = None, + cost_per_request: Optional[float] = None, + accept_websocket: bool = True, +): + """ + WebSocket passthrough request handler. + + Args: + websocket: The incoming WebSocket connection + target: The target WebSocket URL + custom_headers: Custom headers to include in the connection + user_api_key_dict: The user API key dictionary + forward_headers: Whether to forward incoming headers + endpoint: The endpoint path (for logging purposes) + cost_per_request: Optional field - cost per request to the target endpoint + """ + from litellm.litellm_core_utils.litellm_logging import Logging + from litellm.proxy.proxy_server import proxy_logging_obj + from litellm.types.passthrough_endpoints.pass_through_endpoints import ( + PassthroughStandardLoggingPayload, + ) + + # Initialize tracking variables + start_time = datetime.now() + websocket_messages: list[dict[str, Any]] = [] + litellm_call_id = str(uuid.uuid4()) + + verbose_proxy_logger.info( + f"WebSocket passthrough ({endpoint}): Starting WebSocket connection to {target}" + ) + + # Only accept the WebSocket if requested (for generic usage) + if accept_websocket: + await websocket.accept() + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): WebSocket connection accepted" + ) + + # Prepare headers for the upstream connection + upstream_headers = custom_headers.copy() + + if forward_headers: + # Forward relevant headers from the incoming request + incoming_headers = dict(websocket.headers) + for header_name, header_value in incoming_headers.items(): + # Only forward certain headers to avoid conflicts + if header_name.lower() in [ + "authorization", + "x-api-key", + "x-goog-user-project", + ]: + upstream_headers[header_name] = header_value + + # Initialize logging object similar to HTTP passthrough + logging_obj = Logging( + model="unknown", + messages=[{"role": "user", "content": "WebSocket connection"}], + stream=True, # WebSockets are inherently streaming + call_type="pass_through_endpoint", + start_time=start_time, + litellm_call_id=litellm_call_id, + function_id="websocket_passthrough", + ) + + # Create passthrough logging payload + passthrough_logging_payload = PassthroughStandardLoggingPayload( + url=target, + request_body={}, # WebSocket doesn't have a traditional request body + request_method="WEBSOCKET", + cost_per_request=cost_per_request, + ) + + # Create a dummy request object for WebSocket connections to maintain compatibility + # with the existing _init_kwargs_for_pass_through_endpoint function + class DummyRequest: + def __init__( + self, url: str, method: str = "WEBSOCKET", headers: Optional[dict] = None + ): + self.url = url + self.method = method + self.headers = headers or {} + + def __str__(self): + return f"DummyRequest(url={self.url}, method={self.method})" + + dummy_request = DummyRequest( + url=target, + method="WEBSOCKET", + headers=dict(websocket.headers) if hasattr(websocket, "headers") else {}, + ) + + # Initialize kwargs for logging using the same pattern as HTTP passthrough + kwargs = HttpPassThroughEndpointHelpers._init_kwargs_for_pass_through_endpoint( + user_api_key_dict=user_api_key_dict, + _parsed_body={}, # WebSocket doesn't have a traditional request body + passthrough_logging_payload=passthrough_logging_payload, + litellm_call_id=litellm_call_id, + request=dummy_request, # type: ignore + logging_obj=logging_obj, + ) + + # Update logging environment variables + logging_obj.update_environment_variables( + model="unknown", + user="unknown", + optional_params={}, + litellm_params=dict(kwargs.get("litellm_params", {})), + call_type="pass_through_endpoint", + ) + logging_obj.model_call_details["litellm_call_id"] = litellm_call_id + + # Pre-call logging + logging_obj.pre_call( + input=[{"role": "user", "content": "WebSocket connection"}], + api_key="", + additional_args={ + "complete_input_dict": {}, + "api_base": target, + "headers": upstream_headers, + }, + ) + + ### CALL HOOKS ### - modify incoming data / reject request before calling the model + websocket_data: dict[str, Any] = {} + websocket_data = await proxy_logging_obj.pre_call_hook( + user_api_key_dict=user_api_key_dict, + data=websocket_data, + call_type="pass_through_endpoint", + ) + + try: + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Establishing upstream connection to {target}" + ) + async with connect( + target, + additional_headers=upstream_headers, + ) as upstream_ws: + verbose_proxy_logger.info( + f"WebSocket passthrough ({endpoint}): Upstream connection established successfully" + ) + + async def forward_client_to_upstream() -> None: + """Forward messages from client to upstream WebSocket""" + try: + while True: + message = await websocket.receive() + message_type = message.get("type") + if message_type == "websocket.disconnect": + await upstream_ws.close() + break + + text_data = message.get("text") + bytes_data = message.get("bytes") + + if text_data is not None: + # Try to extract model from client setup message for Vertex AI Live + if endpoint and "/vertex_ai/live" in endpoint: + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Processing client message for model extraction" + ) + try: + client_message = json.loads(text_data) + if ( + isinstance(client_message, dict) + and "setup" in client_message + ): + setup_data = client_message["setup"] + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Found setup data in client message: {setup_data}" + ) + if ( + isinstance(setup_data, dict) + and "model" in setup_data + ): + extracted_model = ( + _extract_model_from_vertex_ai_setup( + setup_data + ) + ) + if extracted_model: + kwargs["model"] = extracted_model + kwargs[ + "custom_llm_provider" + ] = "vertex_ai-language-models" + # Update logging object with correct model + logging_obj.model = extracted_model + logging_obj.model_call_details[ + "model" + ] = extracted_model + logging_obj.model_call_details[ + "custom_llm_provider" + ] = "vertex_ai" + verbose_proxy_logger.info( + f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from client setup message" + ) + else: + verbose_proxy_logger.warning( + f"WebSocket passthrough ({endpoint}): Failed to extract model from client setup data: {setup_data}" + ) + else: + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Setup data does not contain model field: {setup_data}" + ) + else: + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Client message does not contain setup data" + ) + except (json.JSONDecodeError, KeyError, TypeError) as e: + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Client message is not a valid setup message: {e}" + ) + pass # Not a JSON message or doesn't contain setup data + + await upstream_ws.send(text_data) + elif bytes_data is not None: + await upstream_ws.send(bytes_data) + except asyncio.CancelledError: + raise + except Exception: + verbose_proxy_logger.exception( + f"WebSocket passthrough ({endpoint}): error forwarding client message" + ) + await upstream_ws.close() + + async def forward_upstream_to_client() -> None: + """Forward messages from upstream to client WebSocket""" + try: + # Wait for the first response from upstream + raw_response = await upstream_ws.recv(decode=False) + # Ensure raw_response is bytes before decoding + if isinstance(raw_response, str): + raw_response = raw_response.encode("ascii") + setup_response = json.loads(raw_response.decode("ascii")) + verbose_proxy_logger.debug(f"Setup response: {setup_response}") + + # Extract model and provider from setup response for Vertex AI Live + if endpoint and "/vertex_ai/live" in endpoint: + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Processing server setup response for model extraction" + ) + extracted_model = _extract_model_from_vertex_ai_setup( + setup_response + ) + if extracted_model: + kwargs["model"] = extracted_model + kwargs["custom_llm_provider"] = "vertex_ai_language_models" + # Update logging object with correct model + logging_obj.model = extracted_model + logging_obj.model_call_details["model"] = extracted_model + logging_obj.model_call_details[ + "custom_llm_provider" + ] = "vertex_ai_language_models" + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from server setup response" + ) + else: + verbose_proxy_logger.warning( + f"WebSocket passthrough ({endpoint}): Failed to extract model from server setup response: {setup_response}" + ) + else: + verbose_proxy_logger.debug( + f"WebSocket passthrough ({endpoint}): Not a Vertex AI Live endpoint, skipping model extraction" + ) + + # Send the setup response to the client + await websocket.send_text(json.dumps(setup_response)) + + # Now continuously forward messages from upstream to client + async for upstream_message in upstream_ws: + if isinstance(upstream_message, bytes): + await websocket.send_bytes(upstream_message) + # Parse and collect for cost tracking + try: + message_data = json.loads(upstream_message.decode()) + websocket_messages.append(message_data) + except (json.JSONDecodeError, UnicodeDecodeError): + pass + else: + await websocket.send_text(upstream_message) + # Parse and collect for cost tracking + try: + message_data = json.loads(upstream_message) + websocket_messages.append(message_data) + except json.JSONDecodeError: + pass + + except (ConnectionClosedOK, ConnectionClosedError) as e: + verbose_proxy_logger.debug( + f"Upstream WebSocket connection closed: {e}" + ) + pass + except asyncio.CancelledError: + verbose_proxy_logger.debug( + "asyncio.CancelledError in forward_upstream_to_client" + ) + raise + except Exception as e: + verbose_proxy_logger.debug( + f"Exception in forward_upstream_to_client: {e}" + ) + verbose_proxy_logger.exception( + f"WebSocket passthrough ({endpoint}): error forwarding upstream message" + ) + raise + + # Create tasks for bidirectional message forwarding + tasks = [ + asyncio.create_task(forward_client_to_upstream()), + asyncio.create_task(forward_upstream_to_client()), + ] + + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + + # Cancel remaining tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Check for exceptions in completed tasks + for task in done: + exception = task.exception() + if exception is not None: + raise exception + + end_time = datetime.now() + + # Update passthrough logging payload with response data + passthrough_logging_payload["response_body"] = websocket_messages # type: ignore + passthrough_logging_payload["end_time"] = end_time # type: ignore + + # Remove logging_obj from kwargs to avoid duplicate keyword argument + success_kwargs = kwargs.copy() + success_kwargs.pop("logging_obj", None) + + # # Add user authentication context for database logging + # if user_api_key_dict: + # success_kwargs.setdefault('litellm_params', {}) + # success_kwargs['litellm_params'].update({ + # 'proxy_server_request': { + # 'body': { + # 'user': user_api_key_dict.user_id, + # 'team_id': user_api_key_dict.team_id, + # 'end_user_id': user_api_key_dict.end_user_id, + # } + # } + # }) + # # Also add the user_api_key for direct access + # success_kwargs['user_api_key'] = user_api_key_dict.api_key + + # Create a dummy httpx.Response for WebSocket connections + class MockWebSocketResponse: + def __init__(self, target_url: str): + self.status_code = 200 + self.text = "WebSocket connection successful" + self.headers: dict[str, str] = {} + self.request = MockWebSocketRequest(target_url) + + class MockWebSocketRequest: + def __init__(self, target_url: str): + self.method = "WEBSOCKET" + self.url = target_url + + mock_response = MockWebSocketResponse(target) + + # Use the same success handler as HTTP passthrough endpoints + asyncio.create_task( + pass_through_endpoint_logging.pass_through_async_success_handler( + httpx_response=mock_response, # type: ignore + response_body=websocket_messages, # type: ignore + url_route=endpoint or "", + result="websocket_connection_successful", + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + cache_hit=False, + request_body={}, + **success_kwargs, + ) + ) + + # Call the proxy logging success hook + if proxy_logging_obj: + await proxy_logging_obj.post_call_success_hook( + data={}, + user_api_key_dict=user_api_key_dict, + response={"status": "websocket_connection_successful"}, # type: ignore + ) + + except InvalidStatus as exc: + verbose_proxy_logger.exception( + f"WebSocket passthrough ({endpoint}): upstream rejected WebSocket connection" + ) + + # Prepare request payload for logging + request_payload = {} + if kwargs: + for key, value in kwargs.items(): + request_payload[key] = value + + # Log the connection failure using the same pattern as HTTP + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=exc, + request_data=request_payload, + traceback_str=traceback.format_exc( + limit=MAXIMUM_TRACEBACK_LINES_TO_LOG, + ), + ) + + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close( + code=getattr(exc, "status_code", 1011), + reason="Upstream connection rejected", + ) + except Exception as e: + verbose_proxy_logger.exception( + f"WebSocket passthrough ({endpoint}): unexpected error while proxying WebSocket" + ) + + # Prepare request payload for logging + request_payload = {} + if kwargs: + for key, value in kwargs.items(): + request_payload[key] = value + + # Log the unexpected error using the same pattern as HTTP + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=request_payload, + traceback_str=traceback.format_exc( + limit=MAXIMUM_TRACEBACK_LINES_TO_LOG, + ), + ) + + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close(code=1011, reason="WebSocket passthrough error") + finally: + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() + + +def _is_streaming_response(response: httpx.Response) -> bool: + _content_type = response.headers.get("content-type") + if _content_type is not None and "text/event-stream" in _content_type: + return True + return False + + +def _extract_model_from_vertex_ai_setup(setup_response: dict) -> Optional[str]: + """ + Extract the model name from Vertex AI Live setup response. + + The setup response can contain a model field in two formats: + 1. Direct: {"model": "projects/.../models/gemini-2.0-flash-live-preview-04-09"} + 2. Nested: {"setup": {"model": "projects/.../models/gemini-2.0-flash-live-preview-04-09"}} + + We extract just the model name: "gemini-2.0-flash-live-preview-04-09" + """ + try: + # Handle both direct model field and nested setup.model field + model_path = None + if isinstance(setup_response, dict): + if "model" in setup_response: + model_path = setup_response["model"] + elif ( + "setup" in setup_response + and isinstance(setup_response["setup"], dict) + and "model" in setup_response["setup"] + ): + model_path = setup_response["setup"]["model"] + + if isinstance(model_path, str) and "/models/" in model_path: + # Extract the model name after the last "/models/" + model_name = model_path.split("/models/")[-1] + return model_name + except Exception as e: + verbose_proxy_logger.debug(f"Error extracting model from setup response: {e}") + return None + + +class SafeRouteAdder: + """ + Wrapper class for adding routes to FastAPI app. + Only adds routes if they don't already exist on the app. + """ + + @staticmethod + def _is_path_registered(app: FastAPI, path: str, methods: List[str]) -> bool: + """ + Check if a path with any of the specified methods is already registered on the app. + + Args: + app: The FastAPI application instance + path: The path to check (e.g., "/v1/chat/completions") + methods: List of HTTP methods to check (e.g., ["GET", "POST"]) + + Returns: + True if the path is already registered with any of the methods, False otherwise + """ + for route in app.routes: + # Use getattr to safely access route attributes + route_path = getattr(route, "path", None) + route_methods = getattr(route, "methods", None) + + if route_path == path and route_methods is not None: + # Check if any of the methods overlap + if any(method in route_methods for method in methods): + return True + return False + + @staticmethod + def add_api_route_if_not_exists( + app: FastAPI, + path: str, + endpoint: Any, + methods: List[str], + dependencies: Optional[List] = None, + ) -> bool: + """ + Add an API route to the app only if it doesn't already exist. + + Args: + app: The FastAPI application instance + path: The path for the route + endpoint: The endpoint function/callable + methods: List of HTTP methods + dependencies: Optional list of dependencies + + Returns: + True if route was added, False if it already existed + """ + if SafeRouteAdder._is_path_registered(app=app, path=path, methods=methods): + verbose_proxy_logger.debug( + "Skipping route registration - path %s with methods %s already registered on app", + path, + methods, + ) + return False + + app.add_api_route( + path=path, + endpoint=endpoint, + methods=methods, + dependencies=dependencies, + ) + verbose_proxy_logger.debug( + "Successfully added route: %s with methods %s", + path, + methods, + ) + return True + + +class InitPassThroughEndpointHelpers: + @staticmethod + def add_exact_path_route( + app: FastAPI, + path: str, + target: str, + custom_headers: Optional[dict], + forward_headers: Optional[bool], + merge_query_params: Optional[bool], + dependencies: Optional[List], + cost_per_request: Optional[float], + endpoint_id: str, + guardrails: Optional[dict] = None, + ): + """Add exact path route for pass-through endpoint""" + route_key = f"{endpoint_id}:exact:{path}" + + # Check if this exact route is already registered + if route_key in _registered_pass_through_routes: + verbose_proxy_logger.debug( + "Updating duplicate exact pass through endpoint: %s (already registered)", + path, + ) + + verbose_proxy_logger.debug( + "adding exact pass through endpoint: %s, dependencies: %s", + path, + dependencies, + ) + + # Use SafeRouteAdder to only add route if it doesn't exist on the app + SafeRouteAdder.add_api_route_if_not_exists( + app=app, + path=path, + endpoint=create_pass_through_route( # type: ignore + path, + target, + custom_headers, + forward_headers, + merge_query_params, + dependencies, + cost_per_request=cost_per_request, + guardrails=guardrails, + ), + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + dependencies=dependencies, + ) + + # Always register/update the route metadata (headers, target) even if FastAPI route exists + _registered_pass_through_routes[route_key] = { + "endpoint_id": endpoint_id, + "path": path, + "type": "exact", + "passthrough_params": { + "target": target, + "custom_headers": custom_headers, + "forward_headers": forward_headers, + "merge_query_params": merge_query_params, + "dependencies": dependencies, + "cost_per_request": cost_per_request, + "guardrails": guardrails, + }, + } + + @staticmethod + def add_subpath_route( + app: FastAPI, + path: str, + target: str, + custom_headers: Optional[dict], + forward_headers: Optional[bool], + merge_query_params: Optional[bool], + dependencies: Optional[List], + cost_per_request: Optional[float], + endpoint_id: str, + guardrails: Optional[dict] = None, + ): + """Add wildcard route for sub-paths""" + wildcard_path = f"{path}/{{subpath:path}}" + route_key = f"{endpoint_id}:subpath:{path}" + + # Check if this subpath route is already registered + if route_key in _registered_pass_through_routes: + verbose_proxy_logger.debug( + "Updating duplicate wildcard pass through endpoint: %s (already registered)", + wildcard_path, + ) + + verbose_proxy_logger.debug( + "adding wildcard pass through endpoint: %s, dependencies: %s", + wildcard_path, + dependencies, + ) + + # Use SafeRouteAdder to only add route if it doesn't exist on the app + SafeRouteAdder.add_api_route_if_not_exists( + app=app, + path=wildcard_path, + endpoint=create_pass_through_route( # type: ignore + path, + target, + custom_headers, + forward_headers, + merge_query_params, + dependencies, + include_subpath=True, + cost_per_request=cost_per_request, + guardrails=guardrails, + ), + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + dependencies=dependencies, + ) + + # Register the route to prevent duplicates only if it was added + _registered_pass_through_routes[route_key] = { + "endpoint_id": endpoint_id, + "path": path, + "type": "subpath", + "passthrough_params": { + "target": target, + "custom_headers": custom_headers, + "forward_headers": forward_headers, + "merge_query_params": merge_query_params, + "dependencies": dependencies, + "cost_per_request": cost_per_request, + "guardrails": guardrails, + }, + } + + @staticmethod + def remove_endpoint_routes(endpoint_id: str): + """Remove all routes for a specific endpoint ID from the registry""" + keys_to_remove = [ + key + for key, value in _registered_pass_through_routes.items() + if value["endpoint_id"] == endpoint_id + ] + for key in keys_to_remove: + del _registered_pass_through_routes[key] + verbose_proxy_logger.debug( + "Removed pass-through route from registry: %s", key + ) + + @staticmethod + def clear_all_pass_through_routes(): + """Clear all pass-through routes from the registry""" + _registered_pass_through_routes.clear() + + @staticmethod + def get_all_registered_pass_through_routes() -> List[str]: + """Get all registered pass-through endpoints from the registry""" + return list(_registered_pass_through_routes.keys()) + + @staticmethod + def _build_full_path_with_root(path: str) -> str: + """ + Build full path by prepending server root path if needed. + + Args: + path: The relative path to build + + Returns: + Full path with server root prepended (if root is not "/") + """ + root_path = get_server_root_path() + if root_path == "/": + return path + return f"{root_path}{path}" + + @staticmethod + def is_registered_pass_through_route(route: str) -> bool: + """ + Check if route is a registered pass-through endpoint from DB + + Uses the in-memory registry to avoid additional DB queries + Optimized for minimal latency + + Args: + route: The route to check + + Returns: + bool: True if route is a registered pass-through endpoint, False otherwise + """ + ## CHECK IF MAPPED PASS THROUGH ENDPOINT + for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: + if route.startswith(mapped_route): + return True + + # Fast path: check if any registered route key contains this path + # Keys are in format: "{endpoint_id}:exact:{path}" or "{endpoint_id}:subpath:{path}" + # Extract unique paths from keys for quick checking + for key in _registered_pass_through_routes.keys(): + parts = key.split(":", 2) # Split into [endpoint_id, type, path] + if len(parts) == 3: + route_type = parts[1] + registered_path = ( + InitPassThroughEndpointHelpers._build_full_path_with_root(parts[2]) + ) + if route_type == "exact" and route == registered_path: + return True + elif route_type == "subpath": + if route == registered_path or route.startswith( + registered_path + "/" + ): + return True + + return False + + @staticmethod + def get_registered_pass_through_route(route: str) -> Optional[Dict[str, Any]]: + """Get passthrough params for a given route""" + for key in _registered_pass_through_routes.keys(): + parts = key.split(":", 2) # Split into [endpoint_id, type, path] + if len(parts) == 3: + route_type = parts[1] + registered_path = ( + InitPassThroughEndpointHelpers._build_full_path_with_root(parts[2]) + ) + + if route_type == "exact" and route == registered_path: + return _registered_pass_through_routes[key] + elif route_type == "subpath": + if route == registered_path or route.startswith( + registered_path + "/" + ): + return _registered_pass_through_routes[key] + + return None + + +def _get_combined_pass_through_endpoints( + pass_through_endpoints: Union[List[Dict], List[PassThroughGenericEndpoint]], + config_pass_through_endpoints: List[Dict], +): + """Get combined pass-through endpoints from db + config""" + return pass_through_endpoints + config_pass_through_endpoints + + +async def initialize_pass_through_endpoints( + pass_through_endpoints: Union[List[Dict], List[PassThroughGenericEndpoint]], +): + """ + 1. Create a global list of pass-through endpoints (db + config) + 2. Clear all existing pass-through endpoints from the FastAPI app routes + 3. Add new endpoints to the in-memory registry + + Initialize a list of pass-through endpoints by adding them to the FastAPI app routes + + Args: + pass_through_endpoints: List of pass-through endpoints to initialize + + Returns: + None + """ + from litellm._uuid import uuid + + verbose_proxy_logger.debug("initializing pass through endpoints") + from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes + from litellm.proxy.proxy_server import ( + app, + config_passthrough_endpoints, + premium_user, + ) + + ## get combined pass-through endpoints from db + config + combined_pass_through_endpoints: List[Union[Dict, PassThroughGenericEndpoint]] + + if config_passthrough_endpoints is not None: + combined_pass_through_endpoints = _get_combined_pass_through_endpoints( # type: ignore + pass_through_endpoints, config_passthrough_endpoints + ) + else: + combined_pass_through_endpoints = pass_through_endpoints # type: ignore + + ## clear all existing pass-through endpoints from the FastAPI app routes + # InitPassThroughEndpointHelpers.clear_all_pass_through_routes() + + # get a list of all registered pass-through endpoints + # mark the ones that are visited in the list + # remove the ones that are not visited from the list + registered_pass_through_endpoints = ( + InitPassThroughEndpointHelpers.get_all_registered_pass_through_routes() + ) + + visited_endpoints = set() + + for endpoint in combined_pass_through_endpoints: + if isinstance(endpoint, PassThroughGenericEndpoint): + endpoint = endpoint.model_dump() + + # Auto-generate ID for backwards compatibility if not present + if endpoint.get("id") is None: + endpoint["id"] = str(uuid.uuid4()) + + # Get the endpoint_id as a string (guaranteed to be set at this point) + endpoint_id: str = endpoint["id"] + + _target = endpoint.get("target", None) + _path: Optional[str] = endpoint.get("path", None) + if _path is None: + raise ValueError("Path is required for pass-through endpoint") + _custom_headers = endpoint.get("headers", None) + _custom_headers = await set_env_variables_in_header( + custom_headers=_custom_headers + ) + _forward_headers = endpoint.get("forward_headers", None) + _merge_query_params = endpoint.get("merge_query_params", None) + _auth = endpoint.get("auth", None) + _dependencies = None + if _auth is not None and str(_auth).lower() == "true": + if premium_user is not True: + raise ValueError( + "Error Setting Authentication on Pass Through Endpoint: {}".format( + CommonProxyErrors.not_premium_user.value + ) + ) + _dependencies = [Depends(user_api_key_auth)] + LiteLLMRoutes.openai_routes.value.append(_path) + + if _target is None: + continue + + # Get guardrails config if present + _guardrails = endpoint.get("guardrails", None) + + # Add exact path route + verbose_proxy_logger.debug( + "Initializing pass through endpoint: %s (ID: %s)", _path, endpoint_id + ) + InitPassThroughEndpointHelpers.add_exact_path_route( + app=app, + path=_path, + target=_target, + custom_headers=_custom_headers, + forward_headers=_forward_headers, + merge_query_params=_merge_query_params, + dependencies=_dependencies, + cost_per_request=endpoint.get("cost_per_request", None), + endpoint_id=endpoint_id, + guardrails=_guardrails, + ) + + visited_endpoints.add(f"{endpoint_id}:exact:{_path}") + + # Add wildcard route for sub-paths + if endpoint.get("include_subpath", False) is True: + InitPassThroughEndpointHelpers.add_subpath_route( + app=app, + path=_path, + target=_target, + custom_headers=_custom_headers, + forward_headers=_forward_headers, + merge_query_params=_merge_query_params, + dependencies=_dependencies, + cost_per_request=endpoint.get("cost_per_request", None), + endpoint_id=endpoint_id, + guardrails=_guardrails, + ) + + visited_endpoints.add(f"{endpoint_id}:subpath:{_path}") + + verbose_proxy_logger.debug( + "Added new pass through endpoint: %s (ID: %s)", _path, endpoint_id + ) + + # remove the ones that are not visited from the list + for endpoint_key in registered_pass_through_endpoints: + if endpoint_key not in visited_endpoints: + InitPassThroughEndpointHelpers.remove_endpoint_routes(endpoint_key) + + +def _get_pass_through_endpoints_from_config() -> List[PassThroughGenericEndpoint]: + """ + Get pass-through endpoints defined in the config file. + These are read-only and cannot be edited via the UI. + Malformed endpoints are logged and skipped; they do not crash the function. + """ + from pydantic import ValidationError + + from litellm.proxy.proxy_server import config_passthrough_endpoints + + if config_passthrough_endpoints is None or len(config_passthrough_endpoints) == 0: + return [] + + returned_endpoints: List[PassThroughGenericEndpoint] = [] + for endpoint in config_passthrough_endpoints: + try: + if isinstance(endpoint, dict): + endpoint_dict = dict(endpoint) + endpoint_dict["is_from_config"] = True + returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) + elif isinstance(endpoint, PassThroughGenericEndpoint): + # Create a copy with is_from_config=True + endpoint_dict = endpoint.model_dump() + endpoint_dict["is_from_config"] = True + returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) + except ValidationError as e: + verbose_proxy_logger.warning( + "Skipping malformed pass-through endpoint from config: %s", + e, + exc_info=False, + ) + + return returned_endpoints + + +async def _get_pass_through_endpoints_from_db( + endpoint_id: Optional[str] = None, + user_api_key_dict: Optional[UserAPIKeyAuth] = None, +) -> List[PassThroughGenericEndpoint]: + from litellm.proxy._types import LitellmUserRoles + from litellm.proxy.proxy_server import get_config_general_settings + + try: + if user_api_key_dict is None: + user_api_key_dict = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + return [] + + pass_through_endpoint_data: Optional[List] = response.field_value + if pass_through_endpoint_data is None: + return [] + + returned_endpoints: List[PassThroughGenericEndpoint] = [] + if endpoint_id is None: + # Return all endpoints from DB, mark as not from config + for endpoint in pass_through_endpoint_data: + if isinstance(endpoint, dict): + endpoint_dict = dict(endpoint) + endpoint_dict["is_from_config"] = False + returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) + elif isinstance(endpoint, PassThroughGenericEndpoint): + endpoint_dict = endpoint.model_dump() + endpoint_dict["is_from_config"] = False + returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) + else: + # Find specific endpoint by ID + found_endpoint = _find_endpoint_by_id(pass_through_endpoint_data, endpoint_id) + if found_endpoint is not None: + endpoint_dict = ( + found_endpoint.model_dump() + if isinstance(found_endpoint, PassThroughGenericEndpoint) + else dict(found_endpoint) + ) + endpoint_dict["is_from_config"] = False + returned_endpoints.append(PassThroughGenericEndpoint(**endpoint_dict)) + + return returned_endpoints + + +async def _filter_endpoints_by_team_allowed_routes( + team_id: str, + pass_through_endpoints: List[PassThroughGenericEndpoint], + prisma_client, +) -> List[PassThroughGenericEndpoint]: + """ + Filter pass-through endpoints based on team's allowed_passthrough_routes metadata. + + Args: + team_id: The team ID to check permissions for + pass_through_endpoints: List of endpoints to filter + prisma_client: Database client + + Returns: + Filtered list of endpoints based on team permissions + + Raises: + HTTPException: If team is not found + """ + # retrieve team from db + team = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id}, + ) + if team is None: + raise HTTPException( + status_code=404, + detail={"error": "Team not found"}, + ) + + # retrieve team metadata + team_metadata = team.metadata + if ( + team_metadata is not None + and team_metadata.get("allowed_passthrough_routes") is not None + ): + ## FILTER pass_through_endpoints by allowed_passthrough_routes + pass_through_endpoints = [ + endpoint + for endpoint in pass_through_endpoints + if endpoint.path in team_metadata.get("allowed_passthrough_routes") + ] + + return pass_through_endpoints + + +@router.get( + "/config/pass_through_endpoint", + dependencies=[Depends(user_api_key_auth)], + response_model=PassThroughEndpointResponse, +) +@router.get( + "/config/pass_through_endpoint/team/{team_id}", + dependencies=[Depends(user_api_key_auth)], + response_model=PassThroughEndpointResponse, +) +async def get_pass_through_endpoints( + endpoint_id: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + team_id: Optional[str] = None, +): + """ + GET configured pass through endpoint. + + If no endpoint_id given, return all configured endpoints. + """ ## Get existing pass-through endpoint field value + from litellm.proxy._types import CommonProxyErrors + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + # Get endpoints from DB (editable via UI) + db_endpoints = await _get_pass_through_endpoints_from_db( + endpoint_id=endpoint_id, user_api_key_dict=user_api_key_dict + ) + + # Get endpoints from config file (read-only, not editable via UI) + config_endpoints = _get_pass_through_endpoints_from_config() + + # Merge: config endpoints not in DB + all DB endpoints (DB overrides config for same path) + db_paths = {ep.path for ep in db_endpoints} + config_only_endpoints = [ep for ep in config_endpoints if ep.path not in db_paths] + if endpoint_id is not None: + # When filtering by endpoint_id, only return if found in DB (config endpoints use generated IDs) + pass_through_endpoints = db_endpoints + else: + pass_through_endpoints = config_only_endpoints + db_endpoints + + if team_id is not None: + pass_through_endpoints = await _filter_endpoints_by_team_allowed_routes( + team_id=team_id, + pass_through_endpoints=pass_through_endpoints, + prisma_client=prisma_client, + ) + + return PassThroughEndpointResponse(endpoints=pass_through_endpoints) + + +@router.post( + "/config/pass_through_endpoint/{endpoint_id}", + dependencies=[Depends(user_api_key_auth)], +) +async def update_pass_through_endpoints( + endpoint_id: str, + data: PassThroughGenericEndpoint, + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Update a pass-through endpoint by ID. + """ + from litellm.proxy.proxy_server import ( + get_config_general_settings, + update_config_general_settings, + ) + + ## Get existing pass-through endpoint field value + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + raise HTTPException( + status_code=404, + detail={"error": "No pass-through endpoints found"}, + ) + + pass_through_endpoint_data: Optional[List] = response.field_value + if pass_through_endpoint_data is None: + raise HTTPException( + status_code=404, + detail={"error": "No pass-through endpoints found"}, + ) + + # Find the endpoint to update + found_endpoint = _find_endpoint_by_id(pass_through_endpoint_data, endpoint_id) + + if found_endpoint is None: + raise HTTPException( + status_code=404, + detail={"error": f"Endpoint with ID '{endpoint_id}' not found"}, + ) + + # Find the index for updating the list + endpoint_index = None + for idx, endpoint in enumerate(pass_through_endpoint_data): + _endpoint = ( + PassThroughGenericEndpoint(**endpoint) + if isinstance(endpoint, dict) + else endpoint + ) + if _endpoint.id == endpoint_id: + endpoint_index = idx + break + + if endpoint_index is None: + raise HTTPException( + status_code=404, + detail={ + "error": f"Could not find index for endpoint with ID '{endpoint_id}'" + }, + ) + + # Get the update data as dict, excluding None values for partial updates + # Exclude is_from_config as it's a response-only field (computed at read time) + update_data = data.model_dump(exclude_none=True, exclude={"is_from_config"}) + + # Start with existing endpoint data + endpoint_dict = found_endpoint.model_dump() + + # Update with new data (only non-None values) + endpoint_dict.update(update_data) + + # Preserve existing ID if not provided in update and endpoint has ID + if "id" not in update_data and found_endpoint.id is not None: + endpoint_dict["id"] = found_endpoint.id + + # Remove is_from_config before saving - it's a response-only field (computed at read time) + endpoint_dict.pop("is_from_config", None) + + # Create updated endpoint object + updated_endpoint = PassThroughGenericEndpoint(**endpoint_dict) + + # Update the list + pass_through_endpoint_data[endpoint_index] = endpoint_dict + + # Remove old routes from registry before they get re-registered + InitPassThroughEndpointHelpers.remove_endpoint_routes(endpoint_id) + + ## Update db + updated_data = ConfigFieldUpdate( + field_name="pass_through_endpoints", + field_value=pass_through_endpoint_data, + config_type="general_settings", + ) + + await update_config_general_settings( + data=updated_data, user_api_key_dict=user_api_key_dict + ) + + # Re-register the route with updated headers + _custom_headers: Optional[dict] = updated_endpoint.headers or {} + _custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers) + + if updated_endpoint.include_subpath: + InitPassThroughEndpointHelpers.add_subpath_route( + app=request.app, + path=updated_endpoint.path, + target=updated_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, # Defaults not available in model? assuming None logic handles it + merge_query_params=None, + dependencies=None, + cost_per_request=updated_endpoint.cost_per_request, + endpoint_id=updated_endpoint.id or endpoint_id or "", + guardrails=getattr(updated_endpoint, "guardrails", None), + ) + else: + InitPassThroughEndpointHelpers.add_exact_path_route( + app=request.app, + path=updated_endpoint.path, + target=updated_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, + merge_query_params=None, + dependencies=None, + cost_per_request=updated_endpoint.cost_per_request, + endpoint_id=updated_endpoint.id or endpoint_id or "", + guardrails=getattr(updated_endpoint, "guardrails", None), + ) + + return PassThroughEndpointResponse( + endpoints=[updated_endpoint] if updated_endpoint else [] + ) + + +@router.post( + "/config/pass_through_endpoint", + dependencies=[Depends(user_api_key_auth)], +) +async def create_pass_through_endpoints( + data: PassThroughGenericEndpoint, + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create new pass-through endpoint + """ + from litellm._uuid import uuid + from litellm.proxy.proxy_server import ( + get_config_general_settings, + update_config_general_settings, + ) + + ## Get existing pass-through endpoint field value + + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + response = ConfigFieldInfo( + field_name="pass_through_endpoints", field_value=None + ) + + ## Auto-generate ID if not provided + # Exclude is_from_config as it's a response-only field (computed at read time) + data_dict = data.model_dump(exclude={"is_from_config"}) + if data_dict.get("id") is None: + data_dict["id"] = str(uuid.uuid4()) + + if response.field_value is None: + response.field_value = [data_dict] + elif isinstance(response.field_value, List): + response.field_value.append(data_dict) + + ## Update db + updated_data = ConfigFieldUpdate( + field_name="pass_through_endpoints", + field_value=response.field_value, + config_type="general_settings", + ) + await update_config_general_settings( + data=updated_data, user_api_key_dict=user_api_key_dict + ) + + # Return the created endpoint with the generated ID + created_endpoint = PassThroughGenericEndpoint(**data_dict) + + # Register the new route + _custom_headers: Optional[dict] = created_endpoint.headers or {} + _custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers) + + if created_endpoint.include_subpath: + InitPassThroughEndpointHelpers.add_subpath_route( + app=request.app, + path=created_endpoint.path, + target=created_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, + merge_query_params=None, + dependencies=None, + cost_per_request=created_endpoint.cost_per_request, + endpoint_id=created_endpoint.id or "", + guardrails=getattr(created_endpoint, "guardrails", None), + ) + else: + InitPassThroughEndpointHelpers.add_exact_path_route( + app=request.app, + path=created_endpoint.path, + target=created_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, + merge_query_params=None, + dependencies=None, + cost_per_request=created_endpoint.cost_per_request, + endpoint_id=created_endpoint.id or "", + guardrails=getattr(created_endpoint, "guardrails", None), + ) + + return PassThroughEndpointResponse(endpoints=[created_endpoint]) + + +@router.delete( + "/config/pass_through_endpoint", + dependencies=[Depends(user_api_key_auth)], + response_model=PassThroughEndpointResponse, +) +async def delete_pass_through_endpoints( + endpoint_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Delete a pass-through endpoint by ID. + + Returns - the deleted endpoint + """ + from litellm.proxy.proxy_server import ( + get_config_general_settings, + update_config_general_settings, + ) + + ## Get existing pass-through endpoint field value + + try: + response: ConfigFieldInfo = await get_config_general_settings( + field_name="pass_through_endpoints", user_api_key_dict=user_api_key_dict + ) + except Exception: + response = ConfigFieldInfo( + field_name="pass_through_endpoints", field_value=None + ) + + ## Update field by removing endpoint + pass_through_endpoint_data: Optional[List] = response.field_value + if response.field_value is None or pass_through_endpoint_data is None: + raise HTTPException( + status_code=400, + detail={"error": "There are no pass-through endpoints setup."}, + ) + + # Find the endpoint to delete + found_endpoint = _find_endpoint_by_id(pass_through_endpoint_data, endpoint_id) + + if found_endpoint is None: + raise HTTPException( + status_code=400, + detail={ + "error": "Endpoint with ID '{}' was not found in pass-through endpoint list.".format( + endpoint_id + ) + }, + ) + + # Find the index for deleting from the list + endpoint_index = None + for idx, endpoint in enumerate(pass_through_endpoint_data): + _endpoint = ( + PassThroughGenericEndpoint(**endpoint) + if isinstance(endpoint, dict) + else endpoint + ) + if _endpoint.id == endpoint_id: + endpoint_index = idx + break + + if endpoint_index is None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Could not find index for endpoint with ID '{endpoint_id}'" + }, + ) + + # Remove the endpoint + pass_through_endpoint_data.pop(endpoint_index) + response_obj = found_endpoint + + # Remove routes from registry + InitPassThroughEndpointHelpers.remove_endpoint_routes(endpoint_id) + + ## Update db + updated_data = ConfigFieldUpdate( + field_name="pass_through_endpoints", + field_value=pass_through_endpoint_data, + config_type="general_settings", + ) + await update_config_general_settings( + data=updated_data, user_api_key_dict=user_api_key_dict + ) + + return PassThroughEndpointResponse(endpoints=[response_obj]) + + +def _find_endpoint_by_id( + endpoints_data: List, + endpoint_id: str, +) -> Optional[PassThroughGenericEndpoint]: + """ + Find an endpoint by ID. + + Args: + endpoints_data: List of endpoint data (dicts or PassThroughGenericEndpoint objects) + endpoint_id: ID to search for + + Returns: + Found endpoint or None if not found + """ + for endpoint in endpoints_data: + _endpoint: Optional[PassThroughGenericEndpoint] = None + if isinstance(endpoint, dict): + _endpoint = PassThroughGenericEndpoint(**endpoint) + elif isinstance(endpoint, PassThroughGenericEndpoint): + _endpoint = endpoint + + # Only compare IDs to IDs + if _endpoint is not None and _endpoint.id == endpoint_id: + return _endpoint + + return None + + +async def initialize_pass_through_endpoints_in_db(): + """ + Gets all pass-through endpoints from db and initializes them in the proxy server. + """ + pass_through_endpoints = await _get_pass_through_endpoints_from_db() + await initialize_pass_through_endpoints( + pass_through_endpoints=pass_through_endpoints + )