mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 07:33:58 +00:00
address greptile review feedback (greploop iteration 2)
- Thread api_version through HTTP handlers to Azure realtime endpoints - Make expires_at optional in RealtimeClientSecretResponse - Fix test token expiry times to be in the future - Populate user_id and team_id in minimal_auth for spend tracking Made-with: Cursor
This commit is contained in:
@@ -54,7 +54,7 @@ class BaseRealtimeHTTPConfig(ABC):
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@abstractmethod
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
def get_complete_url(self, api_base: Optional[str], model: str, api_version: Optional[str] = None) -> str:
|
||||
"""Return the full URL for POST /realtime/client_secrets."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -76,7 +76,7 @@ class BaseRealtimeHTTPConfig(ABC):
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def get_realtime_calls_url(
|
||||
self, api_base: Optional[str], model: str
|
||||
self, api_base: Optional[str], model: str, api_version: Optional[str] = None
|
||||
) -> str:
|
||||
"""Return the full URL for POST /realtime/calls (SDP exchange)."""
|
||||
base = (api_base or "").rstrip("/")
|
||||
|
||||
@@ -4746,6 +4746,7 @@ class BaseLLMHTTPHandler:
|
||||
model: Optional[str] = None,
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
api_version: Optional[str] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Forward POST /v1/realtime/client_secrets to upstream provider.
|
||||
@@ -4761,7 +4762,7 @@ class BaseLLMHTTPHandler:
|
||||
async_httpx_client = client
|
||||
|
||||
if provider_config is not None:
|
||||
url = provider_config.get_complete_url(api_base=api_base, model=model or "")
|
||||
url = provider_config.get_complete_url(api_base=api_base, model=model or "", api_version=api_version)
|
||||
headers: Dict[str, Any] = provider_config.validate_environment(
|
||||
headers={}, model=model or "", api_key=api_key
|
||||
)
|
||||
@@ -4811,6 +4812,7 @@ class BaseLLMHTTPHandler:
|
||||
session_config: Optional[Dict[str, Any]] = None,
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
api_version: Optional[str] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Forward POST /v1/realtime/calls (SDP exchange) to upstream provider.
|
||||
@@ -4830,7 +4832,7 @@ class BaseLLMHTTPHandler:
|
||||
async_httpx_client = client
|
||||
|
||||
if provider_config is not None:
|
||||
url = provider_config.get_realtime_calls_url(api_base=api_base, model=model or "")
|
||||
url = provider_config.get_realtime_calls_url(api_base=api_base, model=model or "", api_version=api_version)
|
||||
headers: Dict[str, Any] = provider_config.get_realtime_calls_headers(
|
||||
ephemeral_key=openai_ephemeral_key
|
||||
)
|
||||
|
||||
@@ -25,13 +25,13 @@ class OpenAIRealtimeHTTPConfig(BaseRealtimeHTTPConfig):
|
||||
or ""
|
||||
)
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
def get_complete_url(self, api_base: Optional[str], model: str, api_version: Optional[str] = None) -> str:
|
||||
base = self.get_api_base(api_base).rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
return f"{base}/v1/realtime/client_secrets"
|
||||
|
||||
def get_realtime_calls_url(self, api_base: Optional[str], model: str) -> str:
|
||||
def get_realtime_calls_url(self, api_base: Optional[str], model: str, api_version: Optional[str] = None) -> str:
|
||||
base = self.get_api_base(api_base).rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
|
||||
@@ -282,14 +282,21 @@ async def proxy_realtime_calls(
|
||||
or request.query_params.get("model")
|
||||
or "gpt-4o-realtime-preview"
|
||||
)
|
||||
user_id = decoded_payload.get("user_id") or None
|
||||
team_id = decoded_payload.get("team_id") or None
|
||||
else:
|
||||
# Backward compatibility: older tokens contained only encrypted upstream key.
|
||||
openai_ephemeral_key = decrypted_token_value
|
||||
model = request.query_params.get("model", "gpt-4o-realtime-preview")
|
||||
user_id = None
|
||||
team_id = None
|
||||
|
||||
# Build a minimal UserAPIKeyAuth so we can pass through the logging pipeline
|
||||
# even though this endpoint uses the provider ephemeral key for auth.
|
||||
minimal_auth = UserAPIKeyAuth()
|
||||
# Build a minimal UserAPIKeyAuth with user/team IDs from the token
|
||||
# so spend tracking and budget enforcement work correctly.
|
||||
minimal_auth = UserAPIKeyAuth(
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
data: dict = {}
|
||||
try:
|
||||
|
||||
@@ -4,8 +4,7 @@ import os
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import litellm
|
||||
from litellm.constants import request_timeout
|
||||
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES
|
||||
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES, request_timeout
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
@@ -56,7 +55,9 @@ def _get_realtime_http_provider_config(
|
||||
Uses ProviderConfigManager so each provider keeps its credential-resolution
|
||||
and URL-construction logic in its own transformation class.
|
||||
"""
|
||||
from litellm.llms.base_llm.realtime.http_transformation import BaseRealtimeHTTPConfig
|
||||
from litellm.llms.base_llm.realtime.http_transformation import (
|
||||
BaseRealtimeHTTPConfig,
|
||||
)
|
||||
|
||||
provider_config: Optional[BaseRealtimeHTTPConfig] = None
|
||||
if custom_llm_provider in LlmProviders._member_map_.values():
|
||||
@@ -138,6 +139,7 @@ async def acreate_realtime_client_secret(
|
||||
model=model_name,
|
||||
extra_headers=kwargs.get("extra_headers"),
|
||||
client=kwargs.get("client"),
|
||||
api_version=litellm_params.api_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -182,6 +184,7 @@ async def arealtime_calls(
|
||||
session_config=session,
|
||||
extra_headers=kwargs.get("extra_headers"),
|
||||
client=kwargs.get("client"),
|
||||
api_version=litellm_params.api_version,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -112,6 +112,6 @@ class RealtimeClientSecretResponse(BaseModel):
|
||||
The `session` field is kept as a raw dict so unknown fields pass through.
|
||||
"""
|
||||
|
||||
expires_at: int
|
||||
expires_at: Optional[int] = None
|
||||
value: str
|
||||
session: Optional[Dict[str, Any]] = None
|
||||
|
||||
@@ -7,6 +7,7 @@ Tests for LiteLLM proxy realtime WebRTC HTTP endpoints:
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
@@ -59,19 +60,20 @@ def test_encode_realtime_token_payload_none_optional_fields():
|
||||
|
||||
|
||||
def test_decode_realtime_token_payload_valid():
|
||||
future_expires_at = int(time.time()) + 3600
|
||||
payload = _encode_realtime_token_payload(
|
||||
ephemeral_key="epk_abc",
|
||||
model_id="gpt-4o",
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
expires_at=999,
|
||||
expires_at=future_expires_at,
|
||||
)
|
||||
decrypted = json.loads(payload) # simulate decrypted value
|
||||
result = _decode_realtime_token_payload(json.dumps(decrypted))
|
||||
assert result is not None
|
||||
assert result["ephemeral_key"] == "epk_abc"
|
||||
assert result["model_id"] == "gpt-4o"
|
||||
assert result["expires_at"] == 999
|
||||
assert result["expires_at"] == future_expires_at
|
||||
|
||||
|
||||
def test_decode_realtime_token_payload_invalid_version():
|
||||
@@ -115,14 +117,15 @@ def proxy_app():
|
||||
@pytest.fixture
|
||||
def mock_route_request_client_secrets():
|
||||
"""Mock route_request to return a fake upstream client_secrets response."""
|
||||
future_expires_at = int(time.time()) + 3600
|
||||
mock_resp = MagicMock(spec=httpx.Response)
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.text = '{"value":"upstream_ephemeral_key","expires_at":999}'
|
||||
mock_resp.content = b'{"value":"upstream_ephemeral_key","expires_at":999}'
|
||||
mock_resp.text = f'{{"value":"upstream_ephemeral_key","expires_at":{future_expires_at}}}'
|
||||
mock_resp.content = f'{{"value":"upstream_ephemeral_key","expires_at":{future_expires_at}}}'.encode()
|
||||
mock_resp.headers = {}
|
||||
mock_resp.json.return_value = {
|
||||
"value": "upstream_ephemeral_key",
|
||||
"expires_at": 999,
|
||||
"expires_at": future_expires_at,
|
||||
}
|
||||
|
||||
async def _mock_route(*args, **kwargs):
|
||||
@@ -215,7 +218,8 @@ async def test_client_secrets_success_with_mock(
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "value" in data
|
||||
assert data["expires_at"] == 999
|
||||
assert data["expires_at"] is not None
|
||||
assert data["expires_at"] > int(time.time()) # Should be in the future
|
||||
# Proxy encrypts the upstream value, so returned value should differ
|
||||
assert data["value"] != "upstream_ephemeral_key"
|
||||
|
||||
@@ -259,12 +263,13 @@ async def test_realtime_calls_success_with_valid_encrypted_token(
|
||||
proxy_server.master_key = "sk-test-master-key"
|
||||
|
||||
# Build a valid encrypted token (same format as client_secrets returns)
|
||||
future_expires_at = int(time.time()) + 3600
|
||||
token_payload = _encode_realtime_token_payload(
|
||||
ephemeral_key="fake_upstream_epk",
|
||||
model_id="gpt-4o-realtime-preview",
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
expires_at=999,
|
||||
expires_at=future_expires_at,
|
||||
)
|
||||
encrypted_token = encrypt_value_helper(token_payload)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user