From bb451cfcb061c58e85180255dcdb7719b2d2e2ec Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 12 Mar 2026 18:53:22 +0530 Subject: [PATCH] address greptile review feedback (greploop iteration 2) - Thread api_version through HTTP handlers to Azure realtime endpoints - Make expires_at optional in RealtimeClientSecretResponse - Fix test token expiry times to be in the future - Populate user_id and team_id in minimal_auth for spend tracking Made-with: Cursor --- .../base_llm/realtime/http_transformation.py | 4 ++-- litellm/llms/custom_httpx/llm_http_handler.py | 6 ++++-- .../openai/realtime/http_transformation.py | 4 ++-- litellm/proxy/realtime_endpoints/endpoints.py | 13 ++++++++++--- litellm/realtime_api/main.py | 9 ++++++--- litellm/types/realtime.py | 2 +- .../test_realtime_webrtc_endpoints.py | 19 ++++++++++++------- 7 files changed, 37 insertions(+), 20 deletions(-) diff --git a/litellm/llms/base_llm/realtime/http_transformation.py b/litellm/llms/base_llm/realtime/http_transformation.py index ccac7b0c68..7aadd49ffd 100644 --- a/litellm/llms/base_llm/realtime/http_transformation.py +++ b/litellm/llms/base_llm/realtime/http_transformation.py @@ -54,7 +54,7 @@ class BaseRealtimeHTTPConfig(ABC): # ------------------------------------------------------------------ # @abstractmethod - def get_complete_url(self, api_base: Optional[str], model: str) -> str: + def get_complete_url(self, api_base: Optional[str], model: str, api_version: Optional[str] = None) -> str: """Return the full URL for POST /realtime/client_secrets.""" @abstractmethod @@ -76,7 +76,7 @@ class BaseRealtimeHTTPConfig(ABC): # ------------------------------------------------------------------ # def get_realtime_calls_url( - self, api_base: Optional[str], model: str + self, api_base: Optional[str], model: str, api_version: Optional[str] = None ) -> str: """Return the full URL for POST /realtime/calls (SDP exchange).""" base = (api_base or "").rstrip("/") diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 8f49e79a72..2c0a9a4f6f 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -4746,6 +4746,7 @@ class BaseLLMHTTPHandler: model: Optional[str] = None, extra_headers: Optional[Dict[str, Any]] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + api_version: Optional[str] = None, ) -> httpx.Response: """ Forward POST /v1/realtime/client_secrets to upstream provider. @@ -4761,7 +4762,7 @@ class BaseLLMHTTPHandler: async_httpx_client = client if provider_config is not None: - url = provider_config.get_complete_url(api_base=api_base, model=model or "") + url = provider_config.get_complete_url(api_base=api_base, model=model or "", api_version=api_version) headers: Dict[str, Any] = provider_config.validate_environment( headers={}, model=model or "", api_key=api_key ) @@ -4811,6 +4812,7 @@ class BaseLLMHTTPHandler: session_config: Optional[Dict[str, Any]] = None, extra_headers: Optional[Dict[str, Any]] = None, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + api_version: Optional[str] = None, ) -> httpx.Response: """ Forward POST /v1/realtime/calls (SDP exchange) to upstream provider. @@ -4830,7 +4832,7 @@ class BaseLLMHTTPHandler: async_httpx_client = client if provider_config is not None: - url = provider_config.get_realtime_calls_url(api_base=api_base, model=model or "") + url = provider_config.get_realtime_calls_url(api_base=api_base, model=model or "", api_version=api_version) headers: Dict[str, Any] = provider_config.get_realtime_calls_headers( ephemeral_key=openai_ephemeral_key ) diff --git a/litellm/llms/openai/realtime/http_transformation.py b/litellm/llms/openai/realtime/http_transformation.py index 33d1cdf322..ff69ef987d 100644 --- a/litellm/llms/openai/realtime/http_transformation.py +++ b/litellm/llms/openai/realtime/http_transformation.py @@ -25,13 +25,13 @@ class OpenAIRealtimeHTTPConfig(BaseRealtimeHTTPConfig): or "" ) - def get_complete_url(self, api_base: Optional[str], model: str) -> str: + def get_complete_url(self, api_base: Optional[str], model: str, api_version: Optional[str] = None) -> str: base = self.get_api_base(api_base).rstrip("/") if base.endswith("/v1"): base = base[:-3] return f"{base}/v1/realtime/client_secrets" - def get_realtime_calls_url(self, api_base: Optional[str], model: str) -> str: + def get_realtime_calls_url(self, api_base: Optional[str], model: str, api_version: Optional[str] = None) -> str: base = self.get_api_base(api_base).rstrip("/") if base.endswith("/v1"): base = base[:-3] diff --git a/litellm/proxy/realtime_endpoints/endpoints.py b/litellm/proxy/realtime_endpoints/endpoints.py index 75587f0828..bb286d1fd0 100644 --- a/litellm/proxy/realtime_endpoints/endpoints.py +++ b/litellm/proxy/realtime_endpoints/endpoints.py @@ -282,14 +282,21 @@ async def proxy_realtime_calls( or request.query_params.get("model") or "gpt-4o-realtime-preview" ) + user_id = decoded_payload.get("user_id") or None + team_id = decoded_payload.get("team_id") or None else: # Backward compatibility: older tokens contained only encrypted upstream key. openai_ephemeral_key = decrypted_token_value model = request.query_params.get("model", "gpt-4o-realtime-preview") + user_id = None + team_id = None - # Build a minimal UserAPIKeyAuth so we can pass through the logging pipeline - # even though this endpoint uses the provider ephemeral key for auth. - minimal_auth = UserAPIKeyAuth() + # Build a minimal UserAPIKeyAuth with user/team IDs from the token + # so spend tracking and budget enforcement work correctly. + minimal_auth = UserAPIKeyAuth( + user_id=user_id, + team_id=team_id, + ) data: dict = {} try: diff --git a/litellm/realtime_api/main.py b/litellm/realtime_api/main.py index 01b76ad805..81f29ca6e3 100644 --- a/litellm/realtime_api/main.py +++ b/litellm/realtime_api/main.py @@ -4,8 +4,7 @@ import os from typing import Any, Dict, Optional, cast import litellm -from litellm.constants import request_timeout -from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES +from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES, request_timeout from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler @@ -56,7 +55,9 @@ def _get_realtime_http_provider_config( Uses ProviderConfigManager so each provider keeps its credential-resolution and URL-construction logic in its own transformation class. """ - from litellm.llms.base_llm.realtime.http_transformation import BaseRealtimeHTTPConfig + from litellm.llms.base_llm.realtime.http_transformation import ( + BaseRealtimeHTTPConfig, + ) provider_config: Optional[BaseRealtimeHTTPConfig] = None if custom_llm_provider in LlmProviders._member_map_.values(): @@ -138,6 +139,7 @@ async def acreate_realtime_client_secret( model=model_name, extra_headers=kwargs.get("extra_headers"), client=kwargs.get("client"), + api_version=litellm_params.api_version, ) @@ -182,6 +184,7 @@ async def arealtime_calls( session_config=session, extra_headers=kwargs.get("extra_headers"), client=kwargs.get("client"), + api_version=litellm_params.api_version, ) diff --git a/litellm/types/realtime.py b/litellm/types/realtime.py index d341a32654..62e4044061 100644 --- a/litellm/types/realtime.py +++ b/litellm/types/realtime.py @@ -112,6 +112,6 @@ class RealtimeClientSecretResponse(BaseModel): The `session` field is kept as a raw dict so unknown fields pass through. """ - expires_at: int + expires_at: Optional[int] = None value: str session: Optional[Dict[str, Any]] = None diff --git a/tests/test_litellm/proxy/realtime_endpoints/test_realtime_webrtc_endpoints.py b/tests/test_litellm/proxy/realtime_endpoints/test_realtime_webrtc_endpoints.py index 1ab876e7ff..3d82e4177a 100644 --- a/tests/test_litellm/proxy/realtime_endpoints/test_realtime_webrtc_endpoints.py +++ b/tests/test_litellm/proxy/realtime_endpoints/test_realtime_webrtc_endpoints.py @@ -7,6 +7,7 @@ Tests for LiteLLM proxy realtime WebRTC HTTP endpoints: import json import os import sys +import time from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -59,19 +60,20 @@ def test_encode_realtime_token_payload_none_optional_fields(): def test_decode_realtime_token_payload_valid(): + future_expires_at = int(time.time()) + 3600 payload = _encode_realtime_token_payload( ephemeral_key="epk_abc", model_id="gpt-4o", user_id=None, team_id=None, - expires_at=999, + expires_at=future_expires_at, ) decrypted = json.loads(payload) # simulate decrypted value result = _decode_realtime_token_payload(json.dumps(decrypted)) assert result is not None assert result["ephemeral_key"] == "epk_abc" assert result["model_id"] == "gpt-4o" - assert result["expires_at"] == 999 + assert result["expires_at"] == future_expires_at def test_decode_realtime_token_payload_invalid_version(): @@ -115,14 +117,15 @@ def proxy_app(): @pytest.fixture def mock_route_request_client_secrets(): """Mock route_request to return a fake upstream client_secrets response.""" + future_expires_at = int(time.time()) + 3600 mock_resp = MagicMock(spec=httpx.Response) mock_resp.status_code = 200 - mock_resp.text = '{"value":"upstream_ephemeral_key","expires_at":999}' - mock_resp.content = b'{"value":"upstream_ephemeral_key","expires_at":999}' + mock_resp.text = f'{{"value":"upstream_ephemeral_key","expires_at":{future_expires_at}}}' + mock_resp.content = f'{{"value":"upstream_ephemeral_key","expires_at":{future_expires_at}}}'.encode() mock_resp.headers = {} mock_resp.json.return_value = { "value": "upstream_ephemeral_key", - "expires_at": 999, + "expires_at": future_expires_at, } async def _mock_route(*args, **kwargs): @@ -215,7 +218,8 @@ async def test_client_secrets_success_with_mock( assert response.status_code == 200 data = response.json() assert "value" in data - assert data["expires_at"] == 999 + assert data["expires_at"] is not None + assert data["expires_at"] > int(time.time()) # Should be in the future # Proxy encrypts the upstream value, so returned value should differ assert data["value"] != "upstream_ephemeral_key" @@ -259,12 +263,13 @@ async def test_realtime_calls_success_with_valid_encrypted_token( proxy_server.master_key = "sk-test-master-key" # Build a valid encrypted token (same format as client_secrets returns) + future_expires_at = int(time.time()) + 3600 token_payload = _encode_realtime_token_payload( ephemeral_key="fake_upstream_epk", model_id="gpt-4o-realtime-preview", user_id=None, team_id=None, - expires_at=999, + expires_at=future_expires_at, ) encrypted_token = encrypt_value_helper(token_payload)