mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-18 03:31:23 +00:00
26bdf7b7a8
- Removed duplicate comment in test_router_endpoints.py
- Removed duplicate comment in logging.md
- Kept the clearer comment: 'Set litellm.callbacks = [proxy_handler_instance] on the proxy'
1130 lines
33 KiB
Python
1130 lines
33 KiB
Python
import sys
|
|
import os
|
|
import json
|
|
import traceback
|
|
from typing import Optional
|
|
from dotenv import load_dotenv
|
|
from fastapi import Request
|
|
from datetime import datetime
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
|
|
# Prepend the directory two levels above the *current working directory*
# (not this file's directory) to sys.path, so modules found there take
# precedence over installed copies on import.
sys.path.insert(0, os.path.abspath(os.path.join("..", "..")))
|
|
from litellm import Router, CustomLogger
|
|
from litellm.types.utils import StandardLoggingPayload
|
|
|
|
## Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)

# Sample audio clip stored next to this test file.
file_path = os.path.join(pwd, "gettysburg.wav")

# The handle is opened once at import time and shared by the tests in this
# module. Register a close at interpreter exit so the descriptor is not
# leaked (and no ResourceWarning is emitted at shutdown).
import atexit  # noqa: E402

audio_file = open(file_path, "rb")
atexit.register(audio_file.close)
|
|
from pathlib import Path
|
|
import litellm
|
|
import pytest
|
|
import asyncio
|
|
|
|
|
|
@pytest.fixture
def model_list():
    """Return the deployment list shared by several router endpoint tests.

    Each entry maps a public ``model_name`` to the ``litellm_params`` used to
    call it. The ``claude-3-5-sonnet-20240620`` alias is backed by a
    ``gpt-3.5-turbo`` deployment that carries a fixed ``mock_response``.
    """
    openai_key = os.getenv("OPENAI_API_KEY")

    def _deployment(name, **params):
        # Keyword order is preserved, so litellm_params keys keep their order.
        return {"model_name": name, "litellm_params": params}

    return [
        _deployment("gpt-3.5-turbo", model="gpt-3.5-turbo", api_key=openai_key),
        _deployment("gpt-4o", model="gpt-4o", api_key=openai_key),
        _deployment("dall-e-3", model="dall-e-3", api_key=openai_key),
        _deployment(
            "cohere-rerank",
            model="cohere/rerank-english-v3.0",
            api_key=os.getenv("COHERE_API_KEY"),
        ),
        _deployment(
            "claude-3-5-sonnet-20240620",
            model="gpt-3.5-turbo",
            mock_response="hi this is macintosh.",
        ),
    ]
|
|
|
|
|
|
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger):
    """Test callback that records data from the latest successful async call.

    Attributes:
        openai_client: value of ``kwargs["client"]`` from the last success
            event, or ``None`` if no event has fired yet.
        standard_logging_object: value of ``kwargs["standard_logging_object"]``
            from the last success event, or ``None`` if no event has fired yet.
    """

    def __init__(self):
        # Let CustomLogger initialise its own state before adding ours.
        super().__init__()
        self.openai_client = None
        # Defined up front so tests can assert on it (e.g. `is not None`)
        # without hitting AttributeError when no success event was logged.
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Capture the client and standard logging payload from ``kwargs``."""
        try:
            # init logging config
            print("logging a transcript kwargs: ", kwargs)
            print("openai client=", kwargs.get("client"))
            self.openai_client = kwargs.get("client")
            self.standard_logging_object = kwargs.get("standard_logging_object")
        except Exception:
            # A logging callback should not break the request path, but the
            # failure is surfaced instead of being silently discarded.
            traceback.print_exc()
|
|
|
|
|
|
# Set litellm.callbacks = [proxy_handler_instance] on the proxy
@pytest.mark.asyncio
@pytest.mark.flaky(retries=6, delay=10)
async def test_transcription_on_router():
    """Router transcription must use the clients the router created per deployment.

    Runs ``atranscription`` (user-facing) and ``_atranscription`` (underlying)
    against a two-deployment "whisper" group. It then checks that the client
    reported to the success callback is one of the router-level clients.
    Global ``litellm.callbacks`` / ``litellm.set_verbose`` are restored
    afterwards so this handler does not leak into other tests.
    """
    proxy_handler_instance = MyCustomHandler()
    original_callbacks = litellm.callbacks
    original_set_verbose = litellm.set_verbose
    litellm.set_verbose = True
    litellm.callbacks = [proxy_handler_instance]
    print("\n Testing async transcription on router\n")
    try:
        model_list = [
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "whisper-1",
                },
            },
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "azure/azure-whisper",
                    "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
                    "api_key": os.getenv("AZURE_EUROPE_API_KEY"),
                    "api_version": "2024-02-15-preview",
                },
            },
        ]

        router = Router(model_list=model_list)

        # String forms of every client the router built for this group.
        router_level_clients = [
            str(
                router._get_client(
                    deployment=deployment,
                    kwargs={"model": "whisper-1"},
                    client_type="async",
                )
            )
            for deployment in router.model_list
        ]

        ## test 1: user facing function
        # Rewind the shared module-level handle: flaky retries and the earlier
        # request may have left the read position at EOF.
        audio_file.seek(0)
        response = await router.atranscription(
            model="whisper",
            file=audio_file,
        )
        assert response is not None

        ## test 2: underlying function
        audio_file.seek(0)
        response = await router._atranscription(
            model="whisper",
            file=audio_file,
        )
        print(response)

        # PROD Test
        # Ensure we ONLY use OpenAI/Azure client initialized on the router level
        await asyncio.sleep(5)
        print("OpenAI Client used= ", proxy_handler_instance.openai_client)
        print("all router level clients= ", router_level_clients)
        # router_level_clients holds str() forms, so compare str() of the
        # captured client too. This works whether the callback received the
        # client object itself or an already-stringified value.
        assert str(proxy_handler_instance.openai_client) in router_level_clients
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
    finally:
        litellm.callbacks = original_callbacks
        litellm.set_verbose = original_set_verbose
|
|
|
|
|
|
@pytest.mark.parametrize("mode", ["iterator"])  # "file",
@pytest.mark.asyncio
async def test_audio_speech_router(mode):
    """Router.aspeech returns binary audio and logs the router ``model_group``.

    ``mode`` is currently unused; only the "iterator" case is parametrized.
    Global litellm settings are restored afterwards.
    """
    from litellm.llms.openai.openai import HttpxBinaryResponseContent

    original_callbacks = litellm.callbacks
    original_set_verbose = litellm.set_verbose
    litellm.set_verbose = True
    test_logger = MyCustomHandler()
    litellm.callbacks = [test_logger]
    try:
        client = Router(
            model_list=[
                {
                    "model_name": "tts",
                    "litellm_params": {
                        "model": "openai/tts-1",
                    },
                },
            ]
        )

        response = await client.aspeech(
            model="tts",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        )

        # Give the async success callback time to run before inspecting it.
        await asyncio.sleep(3)

        assert isinstance(response, HttpxBinaryResponseContent)

        # getattr: fail with a clear assertion (not AttributeError) if the
        # callback never recorded a payload.
        logged_payload = getattr(test_logger, "standard_logging_object", None)
        assert logged_payload is not None
        print(
            "standard_logging_object=",
            json.dumps(logged_payload, indent=4),
        )
        assert logged_payload["model_group"] == "tts"
    finally:
        litellm.callbacks = original_callbacks
        litellm.set_verbose = original_set_verbose
|
|
|
|
|
|
@pytest.mark.asyncio()
async def test_rerank_endpoint(model_list):
    """Router rerank works through both ``arerank`` and ``_arerank``.

    Both the user-facing and the underlying function are validated; the
    original version only checked the second response.
    """
    from litellm.types.utils import RerankResponse

    router = Router(model_list=model_list)

    ## Test 1: user facing function, Test 2: underlying function
    for rerank_fn in (router.arerank, router._arerank):
        response = await rerank_fn(
            model="cohere-rerank",
            query="hello",
            documents=["hello", "world"],
            top_n=3,
        )

        print("async re rank response: ", response)

        assert response.id is not None
        assert response.results is not None

        RerankResponse.model_validate(response)
|
|
|
|
|
|
@pytest.mark.asyncio()
@pytest.mark.parametrize(
    "model", ["omni-moderation-latest", "openai/omni-moderation-latest", None]
)
async def test_moderation_endpoint(model):
    """Smoke-test Router.amoderation against wildcard deployments.

    Covers a bare model name, a provider-prefixed name, and no model at all.
    The original test only printed the response; it now asserts a result
    was returned.
    """
    litellm.set_verbose = True
    router = Router(
        model_list=[
            {
                "model_name": "openai/*",
                "litellm_params": {
                    "model": "openai/*",
                },
            },
            {
                "model_name": "*",
                "litellm_params": {
                    "model": "openai/*",
                },
            },
        ]
    )

    # Only pass `model` when the parametrization supplies one.
    moderation_kwargs = {"input": "hello this is a test"}
    if model is not None:
        moderation_kwargs["model"] = model
    response = await router.amoderation(**moderation_kwargs)

    print("moderation response: ", response)

    assert response is not None
    assert response.results
|
|
|
|
|
|
@pytest.mark.asyncio()
async def test_moderation_endpoint_with_api_base():
    """
    Test that the moderation endpoint respects api_base configuration.

    ``_get_openai_client`` is patched to return a mocked client, and the test
    checks that the deployment's ``api_base`` is passed to it.
    """
    # AsyncMock / MagicMock / patch come from the module-level import.
    custom_api_base = "https://us.api.openai.com/v1"

    router = Router(
        model_list=[
            {
                "model_name": "openai/omni-moderation-latest",
                "litellm_params": {
                    "model": "openai/omni-moderation-latest",
                    "api_base": custom_api_base,
                    "api_key": "test-key",
                },
            },
        ]
    )

    # Mock the OpenAI client to verify api_base is passed
    with patch("litellm.main.openai_chat_completions._get_openai_client") as mock_get_client:
        mock_client = AsyncMock()
        mock_response = MagicMock()
        mock_response.model_dump.return_value = {
            "id": "modr-123",
            "model": "omni-moderation-latest",
            "results": [
                {
                    "flagged": False,
                    "categories": {},
                    "category_scores": {},
                    "category_applied_input_types": {},
                }
            ],
        }
        mock_client.moderations.create = AsyncMock(return_value=mock_response)
        mock_get_client.return_value = mock_client

        response = await router.amoderation(
            model="openai/omni-moderation-latest",
            input="hello this is a test",
        )
        assert response is not None

        # Verify that _get_openai_client was called with the custom api_base
        mock_get_client.assert_called()
        call_kwargs = mock_get_client.call_args.kwargs
        assert call_kwargs.get("api_base") == custom_api_base, (
            f"Expected api_base to be {custom_api_base}, but got {call_kwargs.get('api_base')}"
        )

    print(f"✓ Moderation endpoint correctly uses api_base: {custom_api_base}")
|
|
|
|
|
|
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
    """Text completion through the router returns the mocked text.

    In sync mode only ``text_completion`` is exercised. In async mode both
    the public ``atext_completion`` and the underlying ``_atext_completion``
    are checked.
    """
    expected_text = "I'm fine, thank you!"
    router = Router(model_list=model_list)
    request = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, how are you?",
        "mock_response": expected_text,
    }

    if not sync_mode:
        ## Test 1: user facing function
        response = await router.atext_completion(**request)

        ## Test 2: underlying function
        underlying_response = await router._atext_completion(**request)
        assert underlying_response.choices[0].text == expected_text
    else:
        response = router.text_completion(**request)

    assert response.choices[0].text == expected_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_router_with_empty_choices(model_list):
    """
    https://github.com/BerriAI/litellm/issues/8306

    A mocked completion whose ``choices`` list is empty must still be
    returned by ``Router.acompletion``.
    """
    router = Router(model_list=model_list)

    token_usage = litellm.Usage(
        prompt_tokens=10,
        completion_tokens=10,
        total_tokens=20,
    )
    empty_choices_payload = litellm.ModelResponse(
        choices=[],
        usage=token_usage,
        model="gpt-3.5-turbo",
        object="chat.completion",
        created=1723081200,
    ).model_dump()

    result = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response=empty_choices_payload,
    )
    assert result is not None
|
|
|
|
|
|
@pytest.mark.parametrize("sync_mode", [True, False])
def test_generic_api_call_with_fallbacks_basic(sync_mode):
    """
    Test both the sync and async versions of generic_api_call_with_fallbacks with a basic successful call.

    A mock stands in for the wrapped function. The test checks it is invoked
    exactly once and that its return value is passed straight back.
    """
    from unittest.mock import Mock

    # Sync path wraps a plain Mock; async path needs an awaitable AsyncMock.
    fake_fn = Mock() if sync_mode else AsyncMock()
    # Give the mock a function-like name (presumably read by the router).
    fake_fn.__name__ = "test_function"

    expected = {
        "id": "resp_123456",
        "role": "assistant",
        "content": "This is a test response",
        "model": "test-model",
        "usage": {"input_tokens": 10, "output_tokens": 20},
    }
    fake_fn.return_value = expected

    router = Router(
        model_list=[
            {
                "model_name": "test-model-alias",
                "litellm_params": {
                    "model": "anthropic/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    call_kwargs = {
        "model": "test-model-alias",
        "original_function": fake_fn,
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
    }

    if sync_mode:
        result = router._generic_api_call_with_fallbacks(**call_kwargs)
    else:
        result = asyncio.run(router._ageneric_api_call_with_fallbacks(**call_kwargs))

    fake_fn.assert_called_once()
    assert result == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_aadapter_completion():
    """
    Test the aadapter_completion method which uses async_function_with_fallbacks.

    ``async_function_with_fallbacks`` is replaced with a mock. The test checks
    that the public method forwards the caller's kwargs, the private
    ``_aadapter_completion`` as ``original_function``, and metadata carrying
    the model group.
    """
    canned_response = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    with patch.object(
        Router, "_aadapter_completion", new_callable=AsyncMock
    ) as patched_impl:
        patched_impl.return_value = canned_response

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        fallback_spy = AsyncMock(return_value=canned_response)
        router.async_function_with_fallbacks = fallback_spy

        result = await router.aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        assert result == canned_response

        fallback_spy.assert_called_once()
        forwarded = fallback_spy.call_args.kwargs
        for key, expected_value in (
            ("adapter_id", "test-adapter-id"),
            ("model", "test-adapter-model"),
            ("prompt", "This is a test prompt"),
            ("max_tokens", 100),
        ):
            assert forwarded[key] == expected_value
        assert forwarded["original_function"] == router._aadapter_completion
        assert "metadata" in forwarded
        assert forwarded["metadata"]["model_group"] == "test-adapter-model"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test__aadapter_completion():
    """
    Test the _aadapter_completion method directly.

    Deployment selection and pre-call checks are mocked. The test checks the
    selected deployment's params reach ``litellm.aadapter_completion``, the
    call is counted as a success, and pre-call checks ran once.
    """
    adapter_payload = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    with patch(
        "litellm.aadapter_completion", new_callable=AsyncMock
    ) as adapter_spy:
        adapter_spy.return_value = adapter_payload

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        # Force the router to pick this deployment.
        picked_deployment = {
            "model_name": "test-adapter-model",
            "litellm_params": {
                "model": "test-model",
                "api_key": "fake-api-key",
            },
            "model_info": {
                "id": "test-unique-id",
            },
        }
        router.async_get_available_deployment = AsyncMock(return_value=picked_deployment)
        router.async_routing_strategy_pre_call_checks = AsyncMock()

        result = await router._aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        assert result == adapter_payload

        adapter_spy.assert_called_once()
        sent = adapter_spy.call_args.kwargs
        for key, expected_value in (
            ("adapter_id", "test-adapter-id"),
            ("model", "test-model"),
            ("prompt", "This is a test prompt"),
            ("max_tokens", 100),
            ("api_key", "fake-api-key"),
        ):
            assert sent[key] == expected_value
        assert sent["caching"] == router.cache_responses

        # Bookkeeping is keyed by the deployment's underlying model name.
        assert router.success_calls["test-model"] == 1
        assert router.total_calls["test-model"] == 1

        router.async_routing_strategy_pre_call_checks.assert_called_once()
|
|
|
|
|
|
def test_initialize_router_endpoints():
    """
    Test that initialize_router_endpoints correctly sets up all router endpoints.

    Presence of every endpoint is checked first, then that each is callable.
    """
    router = Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {
                    "model": "anthropic/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    router.initialize_router_endpoints()

    expected_endpoints = (
        "amoderation",
        "aanthropic_messages",
        "aresponses",
        "responses",
        "aget_responses",
        "adelete_responses",
    )
    for endpoint_name in expected_endpoints:
        assert hasattr(router, endpoint_name)
    for endpoint_name in expected_endpoints:
        assert callable(getattr(router, endpoint_name))
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_responses_api_endpoints():
    """
    A simpler test for _init_responses_api_endpoints that focuses on the basic functionality.

    Without a ``response_id`` the kwargs reach the fallback dispatcher without
    a ``model``. With one, the value returned by
    ``ResponsesAPIRequestUtils.get_model_id_from_response_id`` is forwarded as
    ``model``.
    """
    from litellm.responses.utils import ResponsesAPIRequestUtils

    router = Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {
                    "model": "openai/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    # Just mock the _ageneric_api_call_with_fallbacks method
    router._ageneric_api_call_with_fallbacks = AsyncMock()
    dispatcher = router._ageneric_api_call_with_fallbacks

    # Patch the class-level helper only for the duration of this test.
    # Assigning the mock directly to the class attribute (as before) was
    # never undone and leaked into every later test in the session.
    with patch.object(
        ResponsesAPIRequestUtils, "get_model_id_from_response_id", return_value=None
    ) as mock_get_model_id:
        # Call without a response_id (no model extraction should happen)
        await router._init_responses_api_endpoints(
            original_function=AsyncMock(),
            thread_id="thread_xyz",
        )

        # Verify _ageneric_api_call_with_fallbacks was called but model wasn't set
        dispatcher.assert_called_once()
        first_call_kwargs = dispatcher.call_args.kwargs
        assert "model" not in first_call_kwargs
        assert first_call_kwargs["thread_id"] == "thread_xyz"

        dispatcher.reset_mock()

        # Change the decoded model id for the second call
        mock_get_model_id.return_value = "claude-3-sonnet"

        # Call with a response_id
        await router._init_responses_api_endpoints(
            original_function=AsyncMock(),
            response_id="resp_claude_123",
        )

        # Verify model was injected into the kwargs
        dispatcher.assert_called_once()
        second_call_kwargs = dispatcher.call_args.kwargs
        assert second_call_kwargs["model"] == "claude-3-sonnet"
        assert second_call_kwargs["response_id"] == "resp_claude_123"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_vector_store_api_endpoints():
    """
    Test that _init_vector_store_api_endpoints correctly passes custom_llm_provider to kwargs.

    The first call (no provider) must forward only the caller's kwargs. The
    second call must forward ``custom_llm_provider`` alongside them.
    """
    router = Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {
                    "model": "openai/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    vector_store_fn = AsyncMock(return_value={"status": "success"})

    # Without custom_llm_provider
    outcome = await router._init_vector_store_api_endpoints(
        original_function=vector_store_fn,
        vector_store_id="test-store",
    )
    vector_store_fn.assert_called_once_with(vector_store_id="test-store")
    assert outcome == {"status": "success"}

    vector_store_fn.reset_mock()

    # With custom_llm_provider
    await router._init_vector_store_api_endpoints(
        original_function=vector_store_fn,
        custom_llm_provider="openai",
        vector_store_id="test-store",
    )
    vector_store_fn.assert_called_once_with(
        vector_store_id="test-store",
        custom_llm_provider="openai",
    )
|
|
|
|
|
|
def test_apply_default_settings():
    """
    Test the apply_default_settings method.

    Checks that it returns None and leaves existing ``optional_callbacks``
    intact. It also checks that it runs during Router construction and that
    it delegates to ``add_optional_pre_call_checks`` with an empty list.
    """
    # 1. Fresh router: optional_callbacks is unchanged by the call
    #    (default_pre_call_checks is an empty list).
    fresh_router = Router()
    callbacks_before = fresh_router.optional_callbacks
    assert fresh_router.apply_default_settings() is None
    assert fresh_router.optional_callbacks == callbacks_before

    # 2. Router with pre-existing optional_callbacks: they are preserved.
    populated_router = Router()
    existing_callback = MagicMock()
    populated_router.optional_callbacks = [existing_callback]
    assert populated_router.apply_default_settings() is None
    assert existing_callback in populated_router.optional_callbacks

    # 3. Router construction invokes apply_default_settings exactly once.
    with patch.object(Router, 'apply_default_settings') as apply_spy:
        Router()
        apply_spy.assert_called_once()

    # 4. apply_default_settings delegates to add_optional_pre_call_checks([]).
    delegating_router = Router()
    with patch.object(delegating_router, 'add_optional_pre_call_checks') as add_checks_spy:
        delegating_router.apply_default_settings()
        add_checks_spy.assert_called_once_with([])
|
|
|
|
|
|
|
|
|
|
def test_initialize_core_endpoints():
    """
    Test that _initialize_core_endpoints correctly sets up all core router endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "anthropic/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_core_endpoints()

    expected_endpoints = (
        "amoderation",
        "aanthropic_messages",
        "agenerate_content",
        "aadapter_generate_content",
        "aresponses",
        "afile_delete",
        "afile_content",
        "responses",
        "aget_responses",
        "acancel_responses",
        "adelete_responses",
        "alist_input_items",
        "_arealtime",
        "acreate_fine_tuning_job",
        "acancel_fine_tuning_job",
        "alist_fine_tuning_jobs",
        "aretrieve_fine_tuning_job",
        "afile_list",
        "aimage_edit",
        "allm_passthrough_route",
    )

    for endpoint_name in expected_endpoints:
        # A missing attribute yields None, and callable(None) is False, so
        # this single check also covers presence.
        assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_specialized_endpoints():
    """
    Test that _initialize_specialized_endpoints correctly sets up specialized endpoints.

    Covers vector store, Google GenAI, OCR/search, video, container and
    skills endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_specialized_endpoints()

    expected_endpoints = (
        # vector store
        "avector_store_search",
        "avector_store_create",
        "vector_store_search",
        "vector_store_create",
        # google genai
        "agenerate_content",
        "generate_content",
        "agenerate_content_stream",
        "generate_content_stream",
        # ocr / search
        "aocr",
        "ocr",
        "asearch",
        "search",
        # video
        "avideo_generation",
        "video_generation",
        "avideo_list",
        "video_list",
        "avideo_status",
        "video_status",
        "avideo_content",
        "video_content",
        "avideo_remix",
        "video_remix",
        # containers
        "acreate_container",
        "create_container",
        "alist_containers",
        "list_containers",
        "aretrieve_container",
        "retrieve_container",
        "adelete_container",
        "delete_container",
        # skills
        "acreate_skill",
        "alist_skills",
        "aget_skill",
        "adelete_skill",
    )

    for endpoint_name in expected_endpoints:
        # Missing attribute -> None -> not callable, so presence is covered too.
        assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_vector_store_endpoints():
    """
    Test that _initialize_vector_store_endpoints correctly sets up vector store endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_vector_store_endpoints()

    # Each operation is exposed in an async ("a"-prefixed) and a sync form.
    for operation in ("vector_store_search", "vector_store_create"):
        for endpoint_name in ("a" + operation, operation):
            assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_vector_store_file_endpoints():
    """
    Test that _initialize_vector_store_file_endpoints correctly sets up vector store file endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_vector_store_file_endpoints()

    # Each vector-store-file operation has an async ("a"-prefixed) and a sync form.
    operations = ("create", "list", "retrieve", "content", "update", "delete")
    for operation in operations:
        sync_name = f"vector_store_file_{operation}"
        for endpoint_name in ("a" + sync_name, sync_name):
            assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_google_genai_endpoints():
    """
    Test that _initialize_google_genai_endpoints correctly sets up Google GenAI endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_google_genai_endpoints()

    # Async ("a"-prefixed) and sync variants of each generate-content call.
    for operation in ("generate_content", "generate_content_stream"):
        for endpoint_name in ("a" + operation, operation):
            assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_ocr_search_endpoints():
    """
    Test that _initialize_ocr_search_endpoints correctly sets up OCR and search endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_ocr_search_endpoints()

    # Async ("a"-prefixed) and sync variants of OCR and search.
    for operation in ("ocr", "search"):
        for endpoint_name in ("a" + operation, operation):
            assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_video_endpoints():
    """
    Test that _initialize_video_endpoints correctly sets up video endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_video_endpoints()

    # Async ("a"-prefixed) and sync variants of every video operation.
    operations = ("generation", "list", "status", "content", "remix")
    for operation in operations:
        sync_name = f"video_{operation}"
        for endpoint_name in ("a" + sync_name, sync_name):
            assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_container_endpoints():
    """
    Test that _initialize_container_endpoints correctly sets up container endpoints.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_container_endpoints()

    # Async ("a"-prefixed) and sync variants of every container operation.
    for sync_name in (
        "create_container",
        "list_containers",
        "retrieve_container",
        "delete_container",
    ):
        for endpoint_name in ("a" + sync_name, sync_name):
            assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
def test_initialize_skills_endpoints():
    """
    Test that _initialize_skills_endpoints correctly sets up skills endpoints.

    Only the async variants are expected here.
    """
    deployment = {
        "model_name": "test-model",
        "litellm_params": {"model": "anthropic/test-model", "api_key": "fake-api-key"},
    }
    router = Router(model_list=[deployment])

    router._initialize_skills_endpoints()

    for endpoint_name in ("acreate_skill", "alist_skills", "aget_skill", "adelete_skill"):
        # Missing attribute -> None -> not callable, so presence is covered too.
        assert callable(getattr(router, endpoint_name, None))
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_init_containers_api_endpoints():
    """
    Test that _init_containers_api_endpoints calls the original function
    directly without model-based routing.

    The router has no deployments, yet the wrapped function must still
    receive exactly the caller's kwargs, and its result must be returned.
    """
    router = Router(model_list=[])

    expected_container = {"id": "cntr_test", "name": "Test Container"}
    container_fn = AsyncMock(return_value=expected_container)
    request_kwargs = {"custom_llm_provider": "openai", "name": "Test Container"}

    outcome = await router._init_containers_api_endpoints(
        original_function=container_fn,
        **request_kwargs,
    )

    container_fn.assert_called_once_with(**request_kwargs)
    assert outcome == expected_container
|