Files
litellm/tests/local_testing/test_pass_through_endpoints.py
T
Julio Quinteros Pro bb63de2f82 fix(tests): make RPM limit test sequential to avoid race condition
Concurrent requests via run_in_executor + asyncio.gather caused a race
condition where more requests slipped through the rate limiter than
expected, leading to flaky test failures (e.g. 3 successes instead of 2
with rpm_limit=2).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 16:34:52 -03:00

560 lines
20 KiB
Python

import os
import sys
from litellm._uuid import uuid
from functools import partial
from typing import Optional
from urllib.parse import urlparse, parse_qs
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
from unittest.mock import Mock
import httpx
from litellm.proxy.proxy_server import initialize_pass_through_endpoints
# Stand-in for the async client's ``request`` method used by the pass-through code
async def mock_request(*args, **kwargs):
    """Return a canned 200 JSON response, ignoring every argument.

    The response's ``request`` attribute is set to a ``Mock`` specced on
    ``httpx.Request`` so that code reading ``response.request`` does not fail.
    """
    canned = httpx.Response(200, json={"message": "Mocked response"})
    canned.request = Mock(spec=httpx.Request)
    return canned
def remove_rerank_route(app):
    """Remove every ``POST /v1/rerank`` route from ``app.routes`` in place.

    Args:
        app: Any object exposing a mutable ``routes`` list whose entries may
            carry ``path`` and ``methods`` attributes.

    Routes without a ``path`` attribute, or whose ``methods`` is missing or
    ``None``, are left untouched instead of raising.
    """
    # Iterate over a snapshot: removing from ``app.routes`` while looping over
    # that same list makes the iterator skip the entry after each removal.
    for route in list(app.routes):
        methods = getattr(route, "methods", None) or ()
        if getattr(route, "path", None) == "/v1/rerank" and "POST" in methods:
            # Mutate the existing list rather than rebinding ``app.routes``.
            app.routes.remove(route)
            print("Rerank route removed successfully")
    print("ALL Routes on app=", app.routes)
@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the proxy app with the native rerank route removed."""
    from litellm.proxy.proxy_server import app as proxy_app

    # Drop the proxy's own /v1/rerank route first - these tests exercise the
    # pass-through endpoints registered on that path instead.
    remove_rerank_route(app=proxy_app)
    return TestClient(proxy_app)
@pytest.mark.asyncio
async def test_pass_through_endpoint_no_headers(client, monkeypatch):
    """A pass-through endpoint defined without a ``headers`` entry returns the mocked reply."""
    import litellm

    # Replace httpx.AsyncClient.request with the canned-response coroutine
    monkeypatch.setattr("httpx.AsyncClient.request", mock_request)

    # Endpoint definition: path and target only
    endpoint_config = [
        {
            "path": "/test-endpoint",
            "target": "https://api.example.com/v1/chat/completions",
        }
    ]

    # Initialize the pass-through endpoint
    await initialize_pass_through_endpoints(endpoint_config)

    # Publish the same definition through the proxy's general settings
    proxy_server = litellm.proxy.proxy_server
    settings: dict = getattr(proxy_server, "general_settings", {}) or {}
    settings["pass_through_endpoints"] = endpoint_config
    proxy_server.general_settings = settings

    # Call the endpoint through the test client
    reply = client.post("/test-endpoint", json={"prompt": "Hello, world!"})

    # The body must be exactly what mock_request produced
    assert reply.status_code == 200
    assert reply.json() == {"message": "Mocked response"}
@pytest.mark.asyncio
async def test_pass_through_endpoint(client, monkeypatch):
    """A pass-through endpoint defined with a custom ``Authorization`` header returns the mocked reply."""
    import litellm

    # Replace httpx.AsyncClient.request with the canned-response coroutine
    monkeypatch.setattr("httpx.AsyncClient.request", mock_request)

    # Endpoint definition, including a header entry
    endpoint_config = [
        {
            "path": "/test-endpoint",
            "target": "https://api.example.com/v1/chat/completions",
            "headers": {"Authorization": "Bearer test-token"},
        }
    ]

    # Initialize the pass-through endpoint
    await initialize_pass_through_endpoints(endpoint_config)

    # Publish the same definition through the proxy's general settings
    proxy_server = litellm.proxy.proxy_server
    settings: Optional[dict] = getattr(proxy_server, "general_settings", {}) or {}
    settings["pass_through_endpoints"] = endpoint_config
    proxy_server.general_settings = settings

    # Call the endpoint through the test client
    reply = client.post("/test-endpoint", json={"prompt": "Hello, world!"})

    # The body must be exactly what mock_request produced
    assert reply.status_code == 200
    assert reply.json() == {"message": "Mocked response"}
@pytest.mark.asyncio
async def test_pass_through_endpoint_rerank(client):
    """POST a rerank payload through a ``/v1/rerank`` pass-through and expect HTTP 200.

    No HTTP mock is installed in this test; the target is the Cohere rerank
    URL and the bearer token is read from the ``COHERE_API_KEY`` environment
    variable.
    """
    import litellm

    cohere_key = os.environ.get("COHERE_API_KEY")

    # Endpoint definition forwarding /v1/rerank to Cohere
    endpoint_config = [
        {
            "path": "/v1/rerank",
            "target": "https://api.cohere.com/v1/rerank",
            "headers": {"Authorization": f"Bearer {cohere_key}"},
        }
    ]

    # Initialize the pass-through endpoint
    await initialize_pass_through_endpoints(endpoint_config)

    # Publish the same definition through the proxy's general settings
    proxy_server = litellm.proxy.proxy_server
    settings: Optional[dict] = getattr(proxy_server, "general_settings", {}) or {}
    settings["pass_through_endpoints"] = endpoint_config
    proxy_server.general_settings = settings

    payload = {
        "model": "rerank-english-v3.0",
        "query": "What is the capital of the United States?",
        "top_n": 3,
        "documents": [
            "Carson City is the capital city of the American state of Nevada."
        ],
    }

    # Call the endpoint through the test client
    reply = client.post("/v1/rerank", json=payload)
    # NOTE: this prints the request payload, not the response body
    print("JSON response: ", payload)

    assert reply.status_code == 200
@pytest.mark.parametrize(
    "auth, rpm_limit, requests_to_make, expected_status_codes, num_users",
    [
        # Single user tests
        (True, 0, 1, [429], 1),
        (True, 1, 1, [200], 1),
        (True, 1, 2, [200, 429], 1),
        (True, 2, 4, [200, 200, 429, 429], 1),
        (True, 3, 4, [200, 200, 200, 429], 1),
        (True, 4, 4, [200, 200, 200, 200], 1),
        (False, 0, 1, [200], 1),
        (False, 0, 4, [200, 200, 200, 200], 1),
        # Multiple user tests (same parameters as single user)
        (True, 0, 1, [429], 2),
        (True, 1, 1, [200], 2),
        (True, 1, 2, [200, 429], 2),
        (True, 2, 4, [200, 200, 429, 429], 2),
        (True, 3, 4, [200, 200, 200, 429], 2),
        (True, 4, 4, [200, 200, 200, 200], 2),
        (False, 0, 1, [200], 2),
        (False, 0, 4, [200, 200, 200, 200], 2),
    ],
)
@pytest.mark.asyncio
async def test_pass_through_endpoint_rpm_limit(
    client, monkeypatch, auth, rpm_limit, requests_to_make, expected_status_codes, num_users
):
    """Check per-key RPM limiting on a pass-through endpoint.

    ``num_users`` API keys are created, each cached with ``rpm_limit``. Every
    key then sends ``requests_to_make`` sequential requests, and the sorted
    status codes seen by each key must equal the sorted
    ``expected_status_codes``.

    Args:
        auth: Value of the endpoint's ``"auth"`` setting.
        rpm_limit: RPM limit stored on each cached ``UserAPIKeyAuth``.
        requests_to_make: Number of requests sent per key.
        expected_status_codes: Expected status codes for one key, in any order.
        num_users: Number of distinct API keys to exercise.
    """
    monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
    import litellm
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache

    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
    proxy_logging_obj._init_litellm_callbacks()
    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)

    # Define a pass-through endpoint
    _cohere_api_key = os.environ.get("COHERE_API_KEY")
    pass_through_endpoints = [
        {
            "path": "/v1/rerank",
            "target": "https://api.cohere.com/v1/rerank",
            "auth": auth,
            "headers": {"Authorization": f"Bearer {_cohere_api_key}"},
        }
    ]

    # Initialize the pass-through endpoint
    await initialize_pass_through_endpoints(pass_through_endpoints)
    general_settings: Optional[dict] = (
        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
    )
    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)

    # Setup API keys and cache
    mock_api_keys = [f"sk-test-{uuid.uuid4().hex}" for _ in range(num_users)]
    for mock_api_key in mock_api_keys:
        cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
        user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)

    _json_data = {
        "model": "rerank-english-v3.0",
        "query": "What is the capital of the United States?",
        "top_n": 3,
        "documents": [
            "Carson City is the capital city of the American state of Nevada."
        ],
    }

    # Make requests sequentially to avoid race conditions in rate limiter
    # Concurrent requests can slip through before the counter is updated
    status_codes_by_key = {}
    for mock_api_key in mock_api_keys:
        auth_headers = {"Authorization": "Bearer {}".format(mock_api_key)}
        status_codes_by_key[mock_api_key] = [
            client.post("/v1/rerank", json=_json_data, headers=auth_headers).status_code
            for _ in range(requests_to_make)
        ]

    # Compare against a sorted copy: the parametrized list is shared by pytest
    # and must not be mutated in place.
    expected_sorted = sorted(expected_status_codes)
    for mock_api_key, status_codes in status_codes_by_key.items():
        # Each key is held to its own limit, so each key is checked on its own.
        assert sorted(status_codes) == expected_sorted, (
            f"unexpected status codes for key {mock_api_key}: {status_codes}"
        )
    print("JSON response: ", _json_data)
@pytest.mark.parametrize(
    "auth, rpm_limit, requests_to_make, expected_status_codes",
    [
        # Multiple user tests (same parameters as single user)
        (True, 0, 1, [429]),
        (True, 1, 1, [200]),
        (True, 1, 2, [200, 429]),
        (True, 2, 4, [200, 200, 429, 429]),
        (True, 3, 4, [200, 200, 200, 429]),
        (True, 4, 4, [200, 200, 200, 200]),
        (False, 0, 1, [200]),
        (False, 0, 4, [200, 200, 200, 200]),
    ],
)
@pytest.mark.asyncio
async def test_pass_through_endpoint_sequential_rpm_limit(
    client, monkeypatch, auth, rpm_limit, requests_to_make, expected_status_codes
):
    """Check per-key RPM limiting with two keys sending requests in rounds.

    Two API keys are created, each cached with ``rpm_limit``. There are
    ``requests_to_make`` rounds; in each round one request per key is
    submitted to the default executor and both are awaited together, so the
    rounds are sequential while the two keys' requests within a round are not.
    The sorted status codes seen by each key must equal the sorted
    ``expected_status_codes``.

    Args:
        auth: Value of the endpoint's ``"auth"`` setting.
        rpm_limit: RPM limit stored on each cached ``UserAPIKeyAuth``.
        requests_to_make: Number of rounds (requests per key).
        expected_status_codes: Expected status codes for one key, in any order.
    """
    monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
    import litellm
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache

    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
    proxy_logging_obj._init_litellm_callbacks()
    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)

    # Define a pass-through endpoint
    _cohere_api_key = os.environ.get("COHERE_API_KEY")
    pass_through_endpoints = [
        {
            "path": "/v1/rerank",
            "target": "https://api.cohere.com/v1/rerank",
            "auth": auth,
            "headers": {"Authorization": f"Bearer {_cohere_api_key}"},
        }
    ]

    # Initialize the pass-through endpoint
    await initialize_pass_through_endpoints(pass_through_endpoints)
    general_settings: Optional[dict] = (
        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
    )
    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)

    # Setup API keys and cache
    mock_api_keys = [f"sk-test-{uuid.uuid4().hex}" for _ in range(2)]
    for mock_api_key in mock_api_keys:
        cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
        user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)

    _json_data = {
        "model": "rerank-english-v3.0",
        "query": "What is the capital of the United States?",
        "top_n": 3,
        "documents": [
            "Carson City is the capital city of the American state of Nevada."
        ],
    }

    # Make a request to the pass-through endpoint
    first_user_responses = []
    second_user_responses = []
    loop = asyncio.get_running_loop()
    for _ in range(requests_to_make):
        # client.post is blocking, so each call is pushed to the default executor.
        pending = [
            loop.run_in_executor(
                None,
                partial(
                    client.post,
                    "/v1/rerank",
                    json=_json_data,
                    headers={"Authorization": "Bearer {}".format(mock_api_key)},
                ),
            )
            for mock_api_key in mock_api_keys
        ]
        # gather preserves submission order, which matches mock_api_keys order.
        first_user_response, second_user_response = await asyncio.gather(*pending)
        first_user_responses.append(first_user_response)
        second_user_responses.append(second_user_response)

    first_user_status_codes = sorted(response.status_code for response in first_user_responses)
    second_user_status_codes = sorted(response.status_code for response in second_user_responses)

    # Compare against a sorted copy: the parametrized list is shared by pytest
    # and must not be mutated in place.
    expected_sorted = sorted(expected_status_codes)
    assert first_user_status_codes == expected_sorted
    assert second_user_status_codes == expected_sorted
    print("JSON response: ", _json_data)
@pytest.mark.parametrize(
    "auth, rpm_limit, expected_error_code",
    [(True, 0, 429), (True, 2, 207), (False, 0, 207)],
)
@pytest.mark.asyncio
async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
    auth, expected_error_code, rpm_limit
):
    """Send a Langfuse ingestion batch through a pass-through endpoint using ``custom_auth_parser``.

    No HTTP mock is installed in this test; the target is the Langfuse cloud
    ingestion URL. The proxy-server globals touched here, including
    ``general_settings``, are put back in ``finally`` so a failing assertion
    does not leave them modified.

    Args:
        auth: Value of the endpoint's ``"auth"`` setting.
        expected_error_code: HTTP status code the response must carry.
        rpm_limit: RPM limit stored on the cached ``UserAPIKeyAuth``.
    """
    import base64

    from litellm.proxy.proxy_server import app

    client = TestClient(app)
    import litellm
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache

    # Store original values
    original_user_api_key_cache = getattr(
        litellm.proxy.proxy_server, "user_api_key_cache", None
    )
    original_master_key = getattr(litellm.proxy.proxy_server, "master_key", None)
    original_prisma_client = getattr(litellm.proxy.proxy_server, "prisma_client", None)
    original_proxy_logging_obj = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj", None
    )
    # Shallow copy taken BEFORE the dict is updated below; keeping a bare
    # reference would alias the mutated dict and make the restore a no-op.
    old_general_settings = dict(
        getattr(litellm.proxy.proxy_server, "general_settings", None) or {}
    )
    try:
        mock_api_key = "sk-my-test-key"
        cache_value = UserAPIKeyAuth(
            token=hash_token(mock_api_key), rpm_limit=rpm_limit
        )
        user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
        proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
        proxy_logging_obj._init_litellm_callbacks()
        setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
        setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
        setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
        setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)

        # Define a pass-through endpoint
        pass_through_endpoints = [
            {
                "path": "/api/public/ingestion",
                "target": "https://us.cloud.langfuse.com/api/public/ingestion",
                "auth": auth,
                "custom_auth_parser": "langfuse",
                "headers": {
                    "LANGFUSE_PUBLIC_KEY": "os.environ/LANGFUSE_PUBLIC_KEY",
                    "LANGFUSE_SECRET_KEY": "os.environ/LANGFUSE_SECRET_KEY",
                },
            }
        ]

        # Initialize the pass-through endpoint
        await initialize_pass_through_endpoints(pass_through_endpoints)
        general_settings: Optional[dict] = (
            getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
        )
        general_settings.update({"pass_through_endpoints": pass_through_endpoints})
        setattr(litellm.proxy.proxy_server, "general_settings", general_settings)

        _json_data = {
            "batch": [
                {
                    "id": "80e2141f-0ca6-47b7-9c06-dde5e97de690",
                    "type": "trace-create",
                    "body": {
                        "id": "0687af7b-4a75-4de8-a4f6-cba1cdc00865",
                        "timestamp": "2024-08-14T02:38:56.092950Z",
                        "name": "test-trace-litellm-proxy-passthrough",
                    },
                    "timestamp": "2024-08-14T02:38:56.093352Z",
                }
            ],
            "metadata": {
                "batch_size": 1,
                "sdk_integration": "default",
                "sdk_name": "python",
                "sdk_version": "2.27.0",
                "public_key": "anything",
            },
        }

        # Make a request to the pass-through endpoint
        # For langfuse custom_auth_parser, the Authorization header must be valid base64
        # Format: base64(public_key:secret_key) where public_key is the LiteLLM API key
        auth_token = base64.b64encode(f"{mock_api_key}:anything".encode()).decode()
        response = client.post(
            "/api/public/ingestion",
            json=_json_data,
            # Derived from mock_api_key so the header cannot drift from the cached key.
            headers={"Authorization": f"Basic {auth_token}"},
        )
        print("JSON response: ", _json_data)
        print("RESPONSE RECEIVED - {}".format(response.text))

        # Assert the response
        assert response.status_code == expected_error_code
    finally:
        # Reset to original values
        setattr(litellm.proxy.proxy_server, "general_settings", old_general_settings)
        setattr(
            litellm.proxy.proxy_server,
            "user_api_key_cache",
            original_user_api_key_cache,
        )
        setattr(litellm.proxy.proxy_server, "master_key", original_master_key)
        setattr(litellm.proxy.proxy_server, "prisma_client", original_prisma_client)
        setattr(
            litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj
        )
@pytest.mark.asyncio
async def test_pass_through_endpoint_bing(client, monkeypatch):
    """Check how caller query params are combined with a target URL's own query string.

    Two endpoints share the same target URL (which already carries
    ``setLang`` and ``mkt``). The first sets ``merge_query_params``; the
    second does not. The upstream URL captured for the first must also
    contain the caller's ``q`` parameter, while the second must contain only
    the target's own parameters.
    """
    import litellm

    captured_requests = []

    async def mock_bing_request(*args, **kwargs):
        """Record the call arguments and return a canned Bing-style search response."""
        captured_requests.append((args, kwargs))
        mock_response = httpx.Response(
            200,
            json={
                "_type": "SearchResponse",
                "queryContext": {"originalQuery": "bob barker"},
                "webPages": {
                    "webSearchUrl": "https://www.bing.com/search?q=bob+barker",
                    "totalEstimatedMatches": 12000000,
                    "value": [],
                },
            },
        )
        mock_response.request = Mock(spec=httpx.Request)
        return mock_response

    monkeypatch.setattr("httpx.AsyncClient.request", mock_bing_request)

    # Define a pass-through endpoint
    pass_through_endpoints = [
        {
            "path": "/bing/search",
            "target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US",
            "headers": {"Ocp-Apim-Subscription-Key": "XX"},
            "forward_headers": True,
            # Additional settings
            "merge_query_params": True,
            "auth": True,
        },
        {
            "path": "/bing/search-no-merge-params",
            "target": "https://api.bing.microsoft.com/v7.0/search?setLang=en-US&mkt=en-US",
            "headers": {"Ocp-Apim-Subscription-Key": "XX"},
            "forward_headers": True,
        },
    ]

    # Initialize the pass-through endpoint
    await initialize_pass_through_endpoints(pass_through_endpoints)
    general_settings: Optional[dict] = (
        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
    )
    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)

    # Make 2 requests thru the pass-through endpoint
    client.get("/bing/search?q=bob+barker")
    client.get("/bing/search-no-merge-params?q=bob+barker")

    # Fail with a readable message instead of an IndexError when an upstream
    # call never reached the mock.
    assert len(captured_requests) >= 2, (
        f"expected 2 upstream requests, captured {len(captured_requests)}"
    )
    first_transformed_url = captured_requests[0][1]["url"]
    second_transformed_url = captured_requests[1][1]["url"]

    # Parse URLs to compare query params order-independently
    parsed_first = urlparse(str(first_transformed_url))
    first_params = parse_qs(parsed_first.query)
    parsed_second = urlparse(str(second_transformed_url))
    second_params = parse_qs(parsed_second.query)

    # Expected values (parse_qs decodes + as space)
    expected_first_params = {"q": ["bob barker"], "setLang": ["en-US"], "mkt": ["en-US"]}
    expected_second_params = {"setLang": ["en-US"], "mkt": ["en-US"]}

    # Assert the response - compare base URL and params separately, one
    # condition per assert so a failure names the exact mismatch.
    for parsed in (parsed_first, parsed_second):
        assert parsed.scheme == "https"
        assert parsed.netloc == "api.bing.microsoft.com"
        assert parsed.path == "/v7.0/search"
    assert first_params == expected_first_params
    assert second_params == expected_second_params