mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 05:28:02 +00:00
fix mypy error
This commit is contained in:
@@ -1,309 +1,311 @@
|
||||
"""
|
||||
PagerDuty Alerting Integration
|
||||
|
||||
Handles two types of alerts:
|
||||
- High LLM API Failure Rate. Configure X fails in Y seconds to trigger an alert.
|
||||
- High Number of Hanging LLM Requests. Configure X hangs in Y seconds to trigger an alert.
|
||||
|
||||
Note: This is a Free feature on the regular litellm docker image.
|
||||
|
||||
However, this is under the enterprise license
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.integrations.pagerduty import (
|
||||
AlertingConfig,
|
||||
PagerDutyInternalEvent,
|
||||
PagerDutyPayload,
|
||||
PagerDutyRequestBody,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
CallTypesLiteral,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingPayloadErrorInformation,
|
||||
)
|
||||
|
||||
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60
|
||||
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600
|
||||
|
||||
|
||||
class PagerDutyAlerting(SlackAlerting):
|
||||
"""
|
||||
Tracks failed requests and hanging requests separately.
|
||||
If threshold is crossed for either type, triggers a PagerDuty alert.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
_api_key = os.getenv("PAGERDUTY_API_KEY")
|
||||
if not _api_key:
|
||||
raise ValueError("PAGERDUTY_API_KEY is not set")
|
||||
|
||||
self.api_key: str = _api_key
|
||||
alerting_args = alerting_args or {}
|
||||
self.pagerduty_alerting_args: AlertingConfig = AlertingConfig(
|
||||
failure_threshold=alerting_args.get(
|
||||
"failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
|
||||
),
|
||||
failure_threshold_window_seconds=alerting_args.get(
|
||||
"failure_threshold_window_seconds",
|
||||
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
|
||||
),
|
||||
hanging_threshold_seconds=alerting_args.get(
|
||||
"hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
|
||||
),
|
||||
hanging_threshold_window_seconds=alerting_args.get(
|
||||
"hanging_threshold_window_seconds",
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
|
||||
),
|
||||
)
|
||||
|
||||
# Separate storage for failures vs. hangs
|
||||
self._failure_events: List[PagerDutyInternalEvent] = []
|
||||
self._hanging_events: List[PagerDutyInternalEvent] = []
|
||||
|
||||
# ------------------ MAIN LOGIC ------------------ #
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Record a failure event. Only send an alert to PagerDuty if the
|
||||
configured *failure* threshold is exceeded in the specified window.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if not standard_logging_payload:
|
||||
raise ValueError(
|
||||
"standard_logging_object is required for PagerDutyAlerting"
|
||||
)
|
||||
|
||||
# Extract error details
|
||||
error_info: Optional[StandardLoggingPayloadErrorInformation] = (
|
||||
standard_logging_payload.get("error_information") or {}
|
||||
)
|
||||
_meta = standard_logging_payload.get("metadata") or {}
|
||||
|
||||
self._failure_events.append(
|
||||
PagerDutyInternalEvent(
|
||||
failure_event_type="failed_response",
|
||||
timestamp=now,
|
||||
error_class=error_info.get("error_class"),
|
||||
error_code=error_info.get("error_code"),
|
||||
error_llm_provider=error_info.get("llm_provider"),
|
||||
user_api_key_hash=_meta.get("user_api_key_hash"),
|
||||
user_api_key_alias=_meta.get("user_api_key_alias"),
|
||||
user_api_key_spend=_meta.get("user_api_key_spend"),
|
||||
user_api_key_max_budget=_meta.get("user_api_key_max_budget"),
|
||||
user_api_key_budget_reset_at=_meta.get("user_api_key_budget_reset_at"),
|
||||
user_api_key_org_id=_meta.get("user_api_key_org_id"),
|
||||
user_api_key_team_id=_meta.get("user_api_key_team_id"),
|
||||
user_api_key_user_id=_meta.get("user_api_key_user_id"),
|
||||
user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
|
||||
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
|
||||
user_api_key_user_email=_meta.get("user_api_key_user_email"),
|
||||
user_api_key_request_route=_meta.get("user_api_key_request_route"),
|
||||
user_api_key_auth_metadata=_meta.get("user_api_key_auth_metadata"),
|
||||
)
|
||||
)
|
||||
|
||||
# Prune + Possibly alert
|
||||
window_seconds = self.pagerduty_alerting_args.get(
|
||||
"failure_threshold_window_seconds", 60
|
||||
)
|
||||
threshold = self.pagerduty_alerting_args.get("failure_threshold", 1)
|
||||
|
||||
# If threshold is crossed, send PD alert for failures
|
||||
await self._send_alert_if_thresholds_crossed(
|
||||
events=self._failure_events,
|
||||
window_seconds=window_seconds,
|
||||
threshold=threshold,
|
||||
alert_prefix="High LLM API Failure Rate",
|
||||
)
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: CallTypesLiteral,
|
||||
) -> Optional[Union[Exception, str, dict]]:
|
||||
"""
|
||||
Example of detecting hanging requests by waiting a given threshold.
|
||||
If the request didn't finish by then, we treat it as 'hanging'.
|
||||
"""
|
||||
verbose_logger.info("Inside Proxy Logging Pre-call hook!")
|
||||
asyncio.create_task(
|
||||
self.hanging_response_handler(
|
||||
request_data=data, user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
async def hanging_response_handler(
|
||||
self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
"""
|
||||
Checks if request completed by the time 'hanging_threshold_seconds' elapses.
|
||||
If not, we classify it as a hanging request.
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Inside Hanging Response Handler!..sleeping for {self.pagerduty_alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds"
|
||||
)
|
||||
await asyncio.sleep(
|
||||
self.pagerduty_alerting_args.get(
|
||||
"hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
|
||||
)
|
||||
)
|
||||
|
||||
if await self._request_is_completed(request_data=request_data):
|
||||
return # It's not hanging if completed
|
||||
|
||||
# Otherwise, record it as hanging
|
||||
self._hanging_events.append(
|
||||
PagerDutyInternalEvent(
|
||||
failure_event_type="hanging_response",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
error_class="HangingRequest",
|
||||
error_code="HangingRequest",
|
||||
error_llm_provider="HangingRequest",
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_spend=user_api_key_dict.spend,
|
||||
user_api_key_max_budget=user_api_key_dict.max_budget,
|
||||
user_api_key_budget_reset_at=(
|
||||
user_api_key_dict.budget_reset_at.isoformat()
|
||||
if user_api_key_dict.budget_reset_at
|
||||
else None
|
||||
),
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_request_route=user_api_key_dict.request_route,
|
||||
user_api_key_auth_metadata=user_api_key_dict.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Prune + Possibly alert
|
||||
window_seconds = self.pagerduty_alerting_args.get(
|
||||
"hanging_threshold_window_seconds",
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
|
||||
)
|
||||
threshold: int = self.pagerduty_alerting_args.get(
|
||||
"hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
|
||||
)
|
||||
|
||||
# If threshold is crossed, send PD alert for hangs
|
||||
await self._send_alert_if_thresholds_crossed(
|
||||
events=self._hanging_events,
|
||||
window_seconds=window_seconds,
|
||||
threshold=threshold,
|
||||
alert_prefix="High Number of Hanging LLM Requests",
|
||||
)
|
||||
|
||||
# ------------------ HELPERS ------------------ #
|
||||
|
||||
async def _send_alert_if_thresholds_crossed(
|
||||
self,
|
||||
events: List[PagerDutyInternalEvent],
|
||||
window_seconds: int,
|
||||
threshold: int,
|
||||
alert_prefix: str,
|
||||
):
|
||||
"""
|
||||
1. Prune old events
|
||||
2. If threshold is reached, build alert, send to PagerDuty
|
||||
3. Clear those events
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
|
||||
pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff]
|
||||
|
||||
# Update the reference list
|
||||
events.clear()
|
||||
events.extend(pruned)
|
||||
|
||||
# Check threshold
|
||||
verbose_logger.debug(
|
||||
f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}"
|
||||
)
|
||||
if len(events) >= threshold:
|
||||
# Build short summary of last N events
|
||||
error_summaries = self._build_error_summaries(events, max_errors=5)
|
||||
alert_message = (
|
||||
f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds."
|
||||
)
|
||||
custom_details = {"recent_errors": error_summaries}
|
||||
|
||||
await self.send_alert_to_pagerduty(
|
||||
alert_message=alert_message,
|
||||
custom_details=custom_details,
|
||||
)
|
||||
|
||||
# Clear them after sending an alert, so we don't spam
|
||||
events.clear()
|
||||
|
||||
def _build_error_summaries(
|
||||
self, events: List[PagerDutyInternalEvent], max_errors: int = 5
|
||||
) -> List[PagerDutyInternalEvent]:
|
||||
"""
|
||||
Build short text summaries for the last `max_errors`.
|
||||
Example: "ValueError (code: 500, provider: openai)"
|
||||
"""
|
||||
recent = events[-max_errors:]
|
||||
summaries = []
|
||||
for fe in recent:
|
||||
# If any of these is None, show "N/A" to avoid messing up the summary string
|
||||
fe.pop("timestamp")
|
||||
summaries.append(fe)
|
||||
return summaries
|
||||
|
||||
async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict):
|
||||
"""
|
||||
Send [critical] Alert to PagerDuty
|
||||
|
||||
https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}")
|
||||
async_client: AsyncHTTPHandler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
payload: PagerDutyRequestBody = PagerDutyRequestBody(
|
||||
payload=PagerDutyPayload(
|
||||
summary=alert_message,
|
||||
severity="critical",
|
||||
source="LiteLLM Alert",
|
||||
component="LiteLLM",
|
||||
custom_details=custom_details,
|
||||
),
|
||||
routing_key=self.api_key,
|
||||
event_action="trigger",
|
||||
)
|
||||
|
||||
return await async_client.post(
|
||||
url="https://events.pagerduty.com/v2/enqueue",
|
||||
json=dict(payload),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error sending alert to PagerDuty: {e}")
|
||||
"""
|
||||
PagerDuty Alerting Integration
|
||||
|
||||
Handles two types of alerts:
|
||||
- High LLM API Failure Rate. Configure X fails in Y seconds to trigger an alert.
|
||||
- High Number of Hanging LLM Requests. Configure X hangs in Y seconds to trigger an alert.
|
||||
|
||||
Note: This is a Free feature on the regular litellm docker image.
|
||||
|
||||
However, this is under the enterprise license
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.integrations.pagerduty import (
|
||||
AlertingConfig,
|
||||
PagerDutyInternalEvent,
|
||||
PagerDutyPayload,
|
||||
PagerDutyRequestBody,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
CallTypesLiteral,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingPayloadErrorInformation,
|
||||
)
|
||||
|
||||
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60
|
||||
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600
|
||||
|
||||
|
||||
class PagerDutyAlerting(SlackAlerting):
|
||||
"""
|
||||
Tracks failed requests and hanging requests separately.
|
||||
If threshold is crossed for either type, triggers a PagerDuty alert.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
_api_key = os.getenv("PAGERDUTY_API_KEY")
|
||||
if not _api_key:
|
||||
raise ValueError("PAGERDUTY_API_KEY is not set")
|
||||
|
||||
self.api_key: str = _api_key
|
||||
alerting_args = alerting_args or {}
|
||||
self.pagerduty_alerting_args: AlertingConfig = AlertingConfig(
|
||||
failure_threshold=alerting_args.get(
|
||||
"failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
|
||||
),
|
||||
failure_threshold_window_seconds=alerting_args.get(
|
||||
"failure_threshold_window_seconds",
|
||||
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
|
||||
),
|
||||
hanging_threshold_seconds=alerting_args.get(
|
||||
"hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
|
||||
),
|
||||
hanging_threshold_window_seconds=alerting_args.get(
|
||||
"hanging_threshold_window_seconds",
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
|
||||
),
|
||||
)
|
||||
|
||||
# Separate storage for failures vs. hangs
|
||||
self._failure_events: List[PagerDutyInternalEvent] = []
|
||||
self._hanging_events: List[PagerDutyInternalEvent] = []
|
||||
|
||||
# ------------------ MAIN LOGIC ------------------ #
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Record a failure event. Only send an alert to PagerDuty if the
|
||||
configured *failure* threshold is exceeded in the specified window.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if not standard_logging_payload:
|
||||
raise ValueError(
|
||||
"standard_logging_object is required for PagerDutyAlerting"
|
||||
)
|
||||
|
||||
# Extract error details
|
||||
error_info: Optional[StandardLoggingPayloadErrorInformation] = (
|
||||
standard_logging_payload.get("error_information") or {}
|
||||
)
|
||||
_meta = standard_logging_payload.get("metadata") or {}
|
||||
|
||||
self._failure_events.append(
|
||||
PagerDutyInternalEvent(
|
||||
failure_event_type="failed_response",
|
||||
timestamp=now,
|
||||
error_class=error_info.get("error_class"),
|
||||
error_code=error_info.get("error_code"),
|
||||
error_llm_provider=error_info.get("llm_provider"),
|
||||
user_api_key_hash=_meta.get("user_api_key_hash"),
|
||||
user_api_key_alias=_meta.get("user_api_key_alias"),
|
||||
user_api_key_spend=_meta.get("user_api_key_spend"),
|
||||
user_api_key_max_budget=_meta.get("user_api_key_max_budget"),
|
||||
user_api_key_budget_reset_at=_meta.get("user_api_key_budget_reset_at"),
|
||||
user_api_key_org_id=_meta.get("user_api_key_org_id"),
|
||||
user_api_key_team_id=_meta.get("user_api_key_team_id"),
|
||||
user_api_key_project_id=_meta.get("user_api_key_project_id"),
|
||||
user_api_key_user_id=_meta.get("user_api_key_user_id"),
|
||||
user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
|
||||
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
|
||||
user_api_key_user_email=_meta.get("user_api_key_user_email"),
|
||||
user_api_key_request_route=_meta.get("user_api_key_request_route"),
|
||||
user_api_key_auth_metadata=_meta.get("user_api_key_auth_metadata"),
|
||||
)
|
||||
)
|
||||
|
||||
# Prune + Possibly alert
|
||||
window_seconds = self.pagerduty_alerting_args.get(
|
||||
"failure_threshold_window_seconds", 60
|
||||
)
|
||||
threshold = self.pagerduty_alerting_args.get("failure_threshold", 1)
|
||||
|
||||
# If threshold is crossed, send PD alert for failures
|
||||
await self._send_alert_if_thresholds_crossed(
|
||||
events=self._failure_events,
|
||||
window_seconds=window_seconds,
|
||||
threshold=threshold,
|
||||
alert_prefix="High LLM API Failure Rate",
|
||||
)
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: CallTypesLiteral,
|
||||
) -> Optional[Union[Exception, str, dict]]:
|
||||
"""
|
||||
Example of detecting hanging requests by waiting a given threshold.
|
||||
If the request didn't finish by then, we treat it as 'hanging'.
|
||||
"""
|
||||
verbose_logger.info("Inside Proxy Logging Pre-call hook!")
|
||||
asyncio.create_task(
|
||||
self.hanging_response_handler(
|
||||
request_data=data, user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
async def hanging_response_handler(
|
||||
self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
"""
|
||||
Checks if request completed by the time 'hanging_threshold_seconds' elapses.
|
||||
If not, we classify it as a hanging request.
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
f"Inside Hanging Response Handler!..sleeping for {self.pagerduty_alerting_args.get('hanging_threshold_seconds', PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS)} seconds"
|
||||
)
|
||||
await asyncio.sleep(
|
||||
self.pagerduty_alerting_args.get(
|
||||
"hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
|
||||
)
|
||||
)
|
||||
|
||||
if await self._request_is_completed(request_data=request_data):
|
||||
return # It's not hanging if completed
|
||||
|
||||
# Otherwise, record it as hanging
|
||||
self._hanging_events.append(
|
||||
PagerDutyInternalEvent(
|
||||
failure_event_type="hanging_response",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
error_class="HangingRequest",
|
||||
error_code="HangingRequest",
|
||||
error_llm_provider="HangingRequest",
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_spend=user_api_key_dict.spend,
|
||||
user_api_key_max_budget=user_api_key_dict.max_budget,
|
||||
user_api_key_budget_reset_at=(
|
||||
user_api_key_dict.budget_reset_at.isoformat()
|
||||
if user_api_key_dict.budget_reset_at
|
||||
else None
|
||||
),
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_project_id=user_api_key_dict.project_id,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_request_route=user_api_key_dict.request_route,
|
||||
user_api_key_auth_metadata=user_api_key_dict.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Prune + Possibly alert
|
||||
window_seconds = self.pagerduty_alerting_args.get(
|
||||
"hanging_threshold_window_seconds",
|
||||
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
|
||||
)
|
||||
threshold: int = self.pagerduty_alerting_args.get(
|
||||
"hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
|
||||
)
|
||||
|
||||
# If threshold is crossed, send PD alert for hangs
|
||||
await self._send_alert_if_thresholds_crossed(
|
||||
events=self._hanging_events,
|
||||
window_seconds=window_seconds,
|
||||
threshold=threshold,
|
||||
alert_prefix="High Number of Hanging LLM Requests",
|
||||
)
|
||||
|
||||
# ------------------ HELPERS ------------------ #
|
||||
|
||||
async def _send_alert_if_thresholds_crossed(
|
||||
self,
|
||||
events: List[PagerDutyInternalEvent],
|
||||
window_seconds: int,
|
||||
threshold: int,
|
||||
alert_prefix: str,
|
||||
):
|
||||
"""
|
||||
1. Prune old events
|
||||
2. If threshold is reached, build alert, send to PagerDuty
|
||||
3. Clear those events
|
||||
"""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
|
||||
pruned = [e for e in events if e.get("timestamp", datetime.min) > cutoff]
|
||||
|
||||
# Update the reference list
|
||||
events.clear()
|
||||
events.extend(pruned)
|
||||
|
||||
# Check threshold
|
||||
verbose_logger.debug(
|
||||
f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}"
|
||||
)
|
||||
if len(events) >= threshold:
|
||||
# Build short summary of last N events
|
||||
error_summaries = self._build_error_summaries(events, max_errors=5)
|
||||
alert_message = (
|
||||
f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds."
|
||||
)
|
||||
custom_details = {"recent_errors": error_summaries}
|
||||
|
||||
await self.send_alert_to_pagerduty(
|
||||
alert_message=alert_message,
|
||||
custom_details=custom_details,
|
||||
)
|
||||
|
||||
# Clear them after sending an alert, so we don't spam
|
||||
events.clear()
|
||||
|
||||
def _build_error_summaries(
|
||||
self, events: List[PagerDutyInternalEvent], max_errors: int = 5
|
||||
) -> List[PagerDutyInternalEvent]:
|
||||
"""
|
||||
Build short text summaries for the last `max_errors`.
|
||||
Example: "ValueError (code: 500, provider: openai)"
|
||||
"""
|
||||
recent = events[-max_errors:]
|
||||
summaries = []
|
||||
for fe in recent:
|
||||
# If any of these is None, show "N/A" to avoid messing up the summary string
|
||||
fe.pop("timestamp")
|
||||
summaries.append(fe)
|
||||
return summaries
|
||||
|
||||
async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict):
|
||||
"""
|
||||
Send [critical] Alert to PagerDuty
|
||||
|
||||
https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}")
|
||||
async_client: AsyncHTTPHandler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
payload: PagerDutyRequestBody = PagerDutyRequestBody(
|
||||
payload=PagerDutyPayload(
|
||||
summary=alert_message,
|
||||
severity="critical",
|
||||
source="LiteLLM Alert",
|
||||
component="LiteLLM",
|
||||
custom_details=custom_details,
|
||||
),
|
||||
routing_key=self.api_key,
|
||||
event_action="trigger",
|
||||
)
|
||||
|
||||
return await async_client.post(
|
||||
url="https://events.pagerduty.com/v2/enqueue",
|
||||
json=dict(payload),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error sending alert to PagerDuty: {e}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,287 +1,295 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.utils import ProxyUpdateSpend
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
)
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
class _ProxyDBLogger(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
await self._PROXY_track_cost_callback(
|
||||
kwargs, response_obj, start_time, end_time
|
||||
)
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self,
|
||||
request_data: dict,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
traceback_str: Optional[str] = None,
|
||||
):
|
||||
request_route = user_api_key_dict.request_route
|
||||
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
||||
return
|
||||
elif request_route is not None and not RouteChecks.is_llm_api_route(
|
||||
route=request_route
|
||||
):
|
||||
return
|
||||
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
_metadata = dict(
|
||||
StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_spend=user_api_key_dict.spend,
|
||||
user_api_key_max_budget=user_api_key_dict.max_budget,
|
||||
user_api_key_budget_reset_at=(
|
||||
user_api_key_dict.budget_reset_at.isoformat()
|
||||
if user_api_key_dict.budget_reset_at
|
||||
else None
|
||||
),
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
user_api_key_request_route=user_api_key_dict.request_route,
|
||||
user_api_key_auth_metadata=user_api_key_dict.metadata,
|
||||
)
|
||||
)
|
||||
_metadata["user_api_key"] = user_api_key_dict.api_key
|
||||
_metadata["status"] = "failure"
|
||||
_metadata["error_information"] = (
|
||||
StandardLoggingPayloadSetup.get_error_information(
|
||||
original_exception=original_exception,
|
||||
traceback_str=traceback_str,
|
||||
)
|
||||
)
|
||||
|
||||
existing_metadata: dict = request_data.get("metadata", None) or {}
|
||||
existing_metadata.update(_metadata)
|
||||
|
||||
if "litellm_params" not in request_data:
|
||||
request_data["litellm_params"] = {}
|
||||
|
||||
existing_litellm_params = request_data.get("litellm_params", {})
|
||||
existing_litellm_metadata = existing_litellm_params.get("metadata", {}) or {}
|
||||
|
||||
# Preserve tags from existing metadata
|
||||
if existing_litellm_metadata.get("tags"):
|
||||
existing_metadata["tags"] = existing_litellm_metadata.get("tags")
|
||||
|
||||
request_data["litellm_params"]["proxy_server_request"] = (
|
||||
request_data.get("proxy_server_request") or existing_litellm_params.get("proxy_server_request") or {}
|
||||
)
|
||||
request_data["litellm_params"]["metadata"] = existing_metadata
|
||||
|
||||
# Preserve model name and custom_llm_provider
|
||||
if "model" not in request_data:
|
||||
request_data["model"] = existing_litellm_params.get("model") or request_data.get("model", "")
|
||||
if "custom_llm_provider" not in request_data:
|
||||
request_data["custom_llm_provider"] = existing_litellm_params.get("custom_llm_provider") or request_data.get("custom_llm_provider", "")
|
||||
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key_dict.api_key,
|
||||
response_cost=0.0,
|
||||
user_id=user_api_key_dict.user_id,
|
||||
end_user_id=user_api_key_dict.end_user_id,
|
||||
team_id=user_api_key_dict.team_id,
|
||||
kwargs=request_data,
|
||||
completion_response=original_exception,
|
||||
start_time=datetime.now(),
|
||||
end_time=datetime.now(),
|
||||
org_id=user_api_key_dict.org_id,
|
||||
)
|
||||
|
||||
@log_db_metrics
|
||||
async def _PROXY_track_cost_callback(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: Optional[
|
||||
Union[litellm.ModelResponse, Any]
|
||||
], # response from completion
|
||||
start_time=None,
|
||||
end_time=None, # start/end time for completion
|
||||
):
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj, update_cache
|
||||
|
||||
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
||||
)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
|
||||
team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
|
||||
org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
|
||||
key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
|
||||
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
|
||||
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
response_cost = (
|
||||
sl_object.get("response_cost", None)
|
||||
if sl_object is not None
|
||||
else kwargs.get("response_cost", None)
|
||||
)
|
||||
tags: Optional[List[str]] = (
|
||||
sl_object.get("request_tags", None) if sl_object is not None else None
|
||||
)
|
||||
|
||||
if response_cost is not None:
|
||||
user_api_key = metadata.get("user_api_key", None)
|
||||
if kwargs.get("cache_hit", False) is True:
|
||||
response_cost = 0.0
|
||||
verbose_proxy_logger.debug(
|
||||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, user_id {user_id}, team_id {team_id}, end_user_id {end_user_id}"
|
||||
)
|
||||
if _should_track_cost_callback(
|
||||
user_api_key=user_api_key,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
end_user_id=end_user_id,
|
||||
):
|
||||
## UPDATE DATABASE
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
team_id=team_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# update cache
|
||||
asyncio.create_task(
|
||||
update_cache(
|
||||
token=user_api_key,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
team_id=team_id,
|
||||
parent_otel_span=parent_otel_span,
|
||||
tags=tags,
|
||||
)
|
||||
)
|
||||
|
||||
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
||||
token=user_api_key,
|
||||
key_alias=key_alias,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
max_budget=end_user_max_budget,
|
||||
)
|
||||
else:
|
||||
# Non-model call types (health checks, afile_delete) have no model or standard_logging_object.
|
||||
# Use .get() for "stream" to avoid KeyError on health checks.
|
||||
if sl_object is None and not kwargs.get("model"):
|
||||
verbose_proxy_logger.warning(
|
||||
"Cost tracking - skipping, no standard_logging_object and no model for call_type=%s",
|
||||
kwargs.get("call_type", "unknown"),
|
||||
)
|
||||
return
|
||||
if kwargs.get("stream") is not True or (
|
||||
kwargs.get("stream") is True and "complete_streaming_response" in kwargs
|
||||
):
|
||||
if sl_object is not None:
|
||||
cost_tracking_failure_debug_info: Union[dict, str] = (
|
||||
sl_object["response_cost_failure_debug_info"] # type: ignore
|
||||
or "response_cost_failure_debug_info is None in standard_logging_object"
|
||||
)
|
||||
else:
|
||||
cost_tracking_failure_debug_info = (
|
||||
"standard_logging_object not found"
|
||||
)
|
||||
model = kwargs.get("model")
|
||||
raise Exception(
|
||||
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
|
||||
model = kwargs.get("model", "")
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
litellm_metadata = kwargs.get("litellm_params", {}).get(
|
||||
"litellm_metadata", {}
|
||||
)
|
||||
old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
call_type = kwargs.get("call_type", "")
|
||||
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.failed_tracking_alert(
|
||||
error_message=error_msg,
|
||||
failing_model=model,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.exception(
|
||||
"Error in tracking cost callback - %s", str(e)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_track_errors_in_db():
|
||||
"""
|
||||
Returns True if errors should be tracked in the database
|
||||
|
||||
By default, errors are tracked in the database
|
||||
|
||||
If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
if general_settings.get("disable_error_logs") is True:
|
||||
return False
|
||||
return
|
||||
|
||||
|
||||
def _should_track_cost_callback(
|
||||
user_api_key: Optional[str],
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if the cost callback should be tracked based on the kwargs
|
||||
"""
|
||||
|
||||
# don't run track cost callback if user opted into disabling spend
|
||||
if ProxyUpdateSpend.disable_spend_updates() is True:
|
||||
return False
|
||||
|
||||
if (
|
||||
user_api_key is not None
|
||||
or user_id is not None
|
||||
or team_id is not None
|
||||
or end_user_id is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.utils import ProxyUpdateSpend
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
)
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
class _ProxyDBLogger(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
await self._PROXY_track_cost_callback(
|
||||
kwargs, response_obj, start_time, end_time
|
||||
)
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self,
|
||||
request_data: dict,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
traceback_str: Optional[str] = None,
|
||||
):
|
||||
request_route = user_api_key_dict.request_route
|
||||
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
||||
return
|
||||
elif request_route is not None and not RouteChecks.is_llm_api_route(
|
||||
route=request_route
|
||||
):
|
||||
return
|
||||
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
_metadata = dict(
|
||||
StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_spend=user_api_key_dict.spend,
|
||||
user_api_key_max_budget=user_api_key_dict.max_budget,
|
||||
user_api_key_budget_reset_at=(
|
||||
user_api_key_dict.budget_reset_at.isoformat()
|
||||
if user_api_key_dict.budget_reset_at
|
||||
else None
|
||||
),
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_project_id=user_api_key_dict.project_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
user_api_key_request_route=user_api_key_dict.request_route,
|
||||
user_api_key_auth_metadata=user_api_key_dict.metadata,
|
||||
)
|
||||
)
|
||||
_metadata["user_api_key"] = user_api_key_dict.api_key
|
||||
_metadata["status"] = "failure"
|
||||
_metadata[
|
||||
"error_information"
|
||||
] = StandardLoggingPayloadSetup.get_error_information(
|
||||
original_exception=original_exception,
|
||||
traceback_str=traceback_str,
|
||||
)
|
||||
|
||||
existing_metadata: dict = request_data.get("metadata", None) or {}
|
||||
existing_metadata.update(_metadata)
|
||||
|
||||
if "litellm_params" not in request_data:
|
||||
request_data["litellm_params"] = {}
|
||||
|
||||
existing_litellm_params = request_data.get("litellm_params", {})
|
||||
existing_litellm_metadata = existing_litellm_params.get("metadata", {}) or {}
|
||||
|
||||
# Preserve tags from existing metadata
|
||||
if existing_litellm_metadata.get("tags"):
|
||||
existing_metadata["tags"] = existing_litellm_metadata.get("tags")
|
||||
|
||||
request_data["litellm_params"]["proxy_server_request"] = (
|
||||
request_data.get("proxy_server_request")
|
||||
or existing_litellm_params.get("proxy_server_request")
|
||||
or {}
|
||||
)
|
||||
request_data["litellm_params"]["metadata"] = existing_metadata
|
||||
|
||||
# Preserve model name and custom_llm_provider
|
||||
if "model" not in request_data:
|
||||
request_data["model"] = existing_litellm_params.get(
|
||||
"model"
|
||||
) or request_data.get("model", "")
|
||||
if "custom_llm_provider" not in request_data:
|
||||
request_data["custom_llm_provider"] = existing_litellm_params.get(
|
||||
"custom_llm_provider"
|
||||
) or request_data.get("custom_llm_provider", "")
|
||||
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key_dict.api_key,
|
||||
response_cost=0.0,
|
||||
user_id=user_api_key_dict.user_id,
|
||||
end_user_id=user_api_key_dict.end_user_id,
|
||||
team_id=user_api_key_dict.team_id,
|
||||
kwargs=request_data,
|
||||
completion_response=original_exception,
|
||||
start_time=datetime.now(),
|
||||
end_time=datetime.now(),
|
||||
org_id=user_api_key_dict.org_id,
|
||||
)
|
||||
|
||||
@log_db_metrics
|
||||
async def _PROXY_track_cost_callback(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: Optional[
|
||||
Union[litellm.ModelResponse, Any]
|
||||
], # response from completion
|
||||
start_time=None,
|
||||
end_time=None, # start/end time for completion
|
||||
):
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj, update_cache
|
||||
|
||||
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
||||
)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
|
||||
team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
|
||||
org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
|
||||
key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
|
||||
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
|
||||
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
response_cost = (
|
||||
sl_object.get("response_cost", None)
|
||||
if sl_object is not None
|
||||
else kwargs.get("response_cost", None)
|
||||
)
|
||||
tags: Optional[List[str]] = (
|
||||
sl_object.get("request_tags", None) if sl_object is not None else None
|
||||
)
|
||||
|
||||
if response_cost is not None:
|
||||
user_api_key = metadata.get("user_api_key", None)
|
||||
if kwargs.get("cache_hit", False) is True:
|
||||
response_cost = 0.0
|
||||
verbose_proxy_logger.debug(
|
||||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, user_id {user_id}, team_id {team_id}, end_user_id {end_user_id}"
|
||||
)
|
||||
if _should_track_cost_callback(
|
||||
user_api_key=user_api_key,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
end_user_id=end_user_id,
|
||||
):
|
||||
## UPDATE DATABASE
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
team_id=team_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# update cache
|
||||
asyncio.create_task(
|
||||
update_cache(
|
||||
token=user_api_key,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
team_id=team_id,
|
||||
parent_otel_span=parent_otel_span,
|
||||
tags=tags,
|
||||
)
|
||||
)
|
||||
|
||||
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
||||
token=user_api_key,
|
||||
key_alias=key_alias,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
max_budget=end_user_max_budget,
|
||||
)
|
||||
else:
|
||||
# Non-model call types (health checks, afile_delete) have no model or standard_logging_object.
|
||||
# Use .get() for "stream" to avoid KeyError on health checks.
|
||||
if sl_object is None and not kwargs.get("model"):
|
||||
verbose_proxy_logger.warning(
|
||||
"Cost tracking - skipping, no standard_logging_object and no model for call_type=%s",
|
||||
kwargs.get("call_type", "unknown"),
|
||||
)
|
||||
return
|
||||
if kwargs.get("stream") is not True or (
|
||||
kwargs.get("stream") is True
|
||||
and "complete_streaming_response" in kwargs
|
||||
):
|
||||
if sl_object is not None:
|
||||
cost_tracking_failure_debug_info: Union[dict, str] = (
|
||||
sl_object["response_cost_failure_debug_info"] # type: ignore
|
||||
or "response_cost_failure_debug_info is None in standard_logging_object"
|
||||
)
|
||||
else:
|
||||
cost_tracking_failure_debug_info = (
|
||||
"standard_logging_object not found"
|
||||
)
|
||||
model = kwargs.get("model")
|
||||
raise Exception(
|
||||
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
|
||||
model = kwargs.get("model", "")
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
litellm_metadata = kwargs.get("litellm_params", {}).get(
|
||||
"litellm_metadata", {}
|
||||
)
|
||||
old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
call_type = kwargs.get("call_type", "")
|
||||
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.failed_tracking_alert(
|
||||
error_message=error_msg,
|
||||
failing_model=model,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.exception(
|
||||
"Error in tracking cost callback - %s", str(e)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_track_errors_in_db():
|
||||
"""
|
||||
Returns True if errors should be tracked in the database
|
||||
|
||||
By default, errors are tracked in the database
|
||||
|
||||
If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
if general_settings.get("disable_error_logs") is True:
|
||||
return False
|
||||
return
|
||||
|
||||
|
||||
def _should_track_cost_callback(
|
||||
user_api_key: Optional[str],
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if the cost callback should be tracked based on the kwargs
|
||||
"""
|
||||
|
||||
# don't run track cost callback if user opted into disabling spend
|
||||
if ProxyUpdateSpend.disable_spend_updates() is True:
|
||||
return False
|
||||
|
||||
if (
|
||||
user_api_key is not None
|
||||
or user_id is not None
|
||||
or team_id is not None
|
||||
or end_user_id is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user