Merge branch 'BerriAI:main' into fix-today-selector-date-mutation-bug

This commit is contained in:
Cole McIntosh
2025-06-25 10:51:53 -06:00
committed by GitHub
66 changed files with 3672 additions and 339 deletions
+89
View File
@@ -0,0 +1,89 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Development Commands
### Installation
- `make install-dev` - Install core development dependencies
- `make install-proxy-dev` - Install proxy development dependencies with full feature set
- `make install-test-deps` - Install all test dependencies
### Testing
- `make test` - Run all tests
- `make test-unit` - Run unit tests (tests/test_litellm) with 4 parallel workers
- `make test-integration` - Run integration tests (excludes unit tests)
- `pytest tests/` - Direct pytest execution
### Code Quality
- `make lint` - Run all linting (Ruff, MyPy, Black, circular imports, import safety)
- `make format` - Apply Black code formatting
- `make lint-ruff` - Run Ruff linting only
- `make lint-mypy` - Run MyPy type checking only
### Single Test Files
- `poetry run pytest tests/path/to/test_file.py -v` - Run specific test file
- `poetry run pytest tests/path/to/test_file.py::test_function -v` - Run specific test
## Architecture Overview
LiteLLM is a unified interface for 100+ LLM providers with two main components:
### Core Library (`litellm/`)
- **Main entry point**: `litellm/main.py` - Contains core completion() function
- **Provider implementations**: `litellm/llms/` - Each provider has its own subdirectory
- **Router system**: `litellm/router.py` + `litellm/router_utils/` - Load balancing and fallback logic
- **Type definitions**: `litellm/types/` - Pydantic models and type hints
- **Integrations**: `litellm/integrations/` - Third-party observability, caching, logging
- **Caching**: `litellm/caching/` - Multiple cache backends (Redis, in-memory, S3, etc.)
### Proxy Server (`litellm/proxy/`)
- **Main server**: `proxy_server.py` - FastAPI application
- **Authentication**: `auth/` - API key management, JWT, OAuth2
- **Database**: `db/` - Prisma ORM with PostgreSQL/SQLite support
- **Management endpoints**: `management_endpoints/` - Admin APIs for keys, teams, models
- **Pass-through endpoints**: `pass_through_endpoints/` - Provider-specific API forwarding
- **Guardrails**: `guardrails/` - Safety and content filtering hooks
- **UI Dashboard**: Served from `_experimental/out/` (Next.js build)
## Key Patterns
### Provider Implementation
- Providers inherit from base classes in `litellm/llms/base.py`
- Each provider has transformation functions for input/output formatting
- Support both sync and async operations
- Handle streaming responses and function calling
### Error Handling
- Provider-specific exceptions mapped to OpenAI-compatible errors
- Fallback logic handled by Router system
- Comprehensive logging through `litellm/_logging.py`
### Configuration
- YAML config files for proxy server (see `proxy/example_config_yaml/`)
- Environment variables for API keys and settings
- Database schema managed via Prisma (`proxy/schema.prisma`)
## Development Notes
### Code Style
- Uses Black formatter, Ruff linter, MyPy type checker
- Pydantic v2 for data validation
- Async/await patterns throughout
- Type hints required for all public APIs
### Testing Strategy
- Unit tests in `tests/test_litellm/`
- Integration tests for each provider in `tests/llm_translation/`
- Proxy tests in `tests/proxy_unit_tests/`
- Load tests in `tests/load_tests/`
### Database Migrations
- Prisma handles schema migrations
- Migration files auto-generated with `prisma migrate dev`
- Always test migrations against both PostgreSQL and SQLite
### Enterprise Features
- Enterprise-specific code in `enterprise/` directory
- Optional features enabled via environment variables
- Separate licensing and authentication for enterprise features
@@ -17,6 +17,8 @@ LiteLLM integrates with vector stores, allowing your models to access your organ
## Supported Vector Stores
- [Bedrock Knowledge Bases](https://aws.amazon.com/bedrock/knowledge-bases/)
- [OpenAI Vector Stores](https://platform.openai.com/docs/api-reference/vector-stores/search)
- [Azure Vector Stores](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/file-search?tabs=python#vector-stores
## Quick Start
@@ -233,3 +233,63 @@ for event in response:
</TabItem>
</Tabs>
## Calling via `/chat/completions`
You can also call the Azure Responses API via the `/chat/completions` endpoint.
<Tabs>
<TabItem value="litellm-sdk" label="LiteLLM SDK">
```python showLineNumbers
from litellm import completion
import os
os.environ["AZURE_API_BASE"] = "https://my-endpoint-sweden-berri992.openai.azure.com/"
os.environ["AZURE_API_VERSION"] = "2023-03-15-preview"
os.environ["AZURE_API_KEY"] = "my-api-key"
response = completion(
model="azure/responses/my-custom-o1-pro",
messages=[{"role": "user", "content": "Hello world"}],
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="OpenAI SDK with LiteLLM Proxy">
1. Setup config.yaml
```yaml showLineNumbers
model_list:
- model_name: my-custom-o1-pro
litellm_params:
model: azure/responses/my-custom-o1-pro
api_key: os.environ/AZURE_API_KEY
api_base: https://my-endpoint-sweden-berri992.openai.azure.com/
api_version: 2023-03-15-preview
```
2. Start LiteLLM proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl http://localhost:4000/v1/chat/completions \
-X POST \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-d '{
"model": "my-custom-o1-pro",
"messages": [{"role": "user", "content": "Hello world"}]
}'
```
</TabItem>
</Tabs>
@@ -13,11 +13,12 @@ Call your custom torch-serve / internal LLM APIs via LiteLLM
:::
Supported Routes:
- `/v1/chat/completions` -> `litellm.completion`
- `/v1/completions` -> `litellm.text_completion`
- `/v1/embeddings` -> `litellm.embedding`
- `/v1/images/generations` -> `litellm.image_generation`
- `/v1/chat/completions` -> `litellm.acompletion`
- `/v1/completions` -> `litellm.atext_completion`
- `/v1/embeddings` -> `litellm.aembedding`
- `/v1/images/generations` -> `litellm.aimage_generation`
- `/v1/messages` -> `litellm.acompletion`
## Quick Start
@@ -262,6 +263,102 @@ Expected Response
}
```
## Anthropic `/v1/messages`
- Write the integration for .acompletion
- litellm will transform it to /v1/messages
1. Setup your `custom_handler.py` file
```python
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
my_custom_llm = MyCustomLLM()
```
2. Add to `config.yaml`
In the config below, we pass
python_filename: `custom_handler.py`
custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1
custom_handler: `custom_handler.my_custom_llm`
```yaml
model_list:
- model_name: "test-model"
litellm_params:
model: "openai/text-embedding-ada-002"
- model_name: "my-custom-model"
litellm_params:
model: "my-custom-llm/my-model"
litellm_settings:
custom_provider_map:
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
```
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
-H 'anthropic-version: 2023-06-01' \
-H 'content-type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "my-custom-model",
"max_tokens": 1024,
"messages": [{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key findings in this document 12?"
}]
}]
}'
```
Expected Response
```json
{
"id": "chatcmpl-Bm4qEp4h4vCe7Zi4Gud1MAxTWgibO",
"type": "message",
"role": "assistant",
"model": "gpt-3.5-turbo-0125",
"stop_sequence": null,
"usage": {
"input_tokens": 18,
"output_tokens": 44
},
"content": [
{
"type": "text",
"text": "Without the specific document being provided, it is not possible to determine the key findings within it. If you can provide the content or a summary of document 12, I would be happy to help identify the key findings."
}
],
"stop_reason": "end_turn"
}
```
## Additional Parameters
Additional parameters are passed inside `optional_params` key in the `completion` or `image_generation` function.
+11 -6
View File
@@ -292,13 +292,18 @@ Let's also set the default models to `no-default-models`. This means a user can
</TabItem>
<TabItem value="yaml" label="YAML">
:::info
Team must be created before setting it as the default team.
:::
```yaml
default_internal_user_params: # Default Params used when a new user signs in Via SSO
user_role: "internal_user" # one of "internal_user", "internal_user_viewer",
models: ["no-default-models"] # Optional[List[str]], optional): models to be used by the user
teams: # Optional[List[NewUserRequestTeam]], optional): teams to be used by the user
- team_id: "team_id_1" # Required[str]: team_id to be used by the user
user_role: "user" # Optional[str], optional): Default role in the team. Values: "user" or "admin". Defaults to "user"
litellm_settings:
default_internal_user_params: # Default Params used when a new user signs in Via SSO
user_role: "internal_user" # one of "internal_user", "internal_user_viewer",
models: ["no-default-models"] # Optional[List[str]], optional): models to be used by the user
teams: # Optional[List[NewUserRequestTeam]], optional): teams to be used by the user
- team_id: "team_id_1" # Required[str]: team_id to be used by the user
user_role: "user" # Optional[str], optional): Default role in the team. Values: "user" or "admin". Defaults to "user"
```
</TabItem>
@@ -0,0 +1,30 @@
from typing import Optional
from litellm.proxy._types import GenerateKeyRequest, LiteLLM_TeamTable
def add_team_member_key_duration(
team_table: Optional[LiteLLM_TeamTable],
data: GenerateKeyRequest,
) -> GenerateKeyRequest:
if team_table is None:
return data
if data.user_id is None: # only apply for team member keys, not service accounts
return data
if (
team_table.metadata is not None
and team_table.metadata.get("team_member_key_duration") is not None
):
data.duration = team_table.metadata["team_member_key_duration"]
return data
def apply_enterprise_key_management_params(
data: GenerateKeyRequest,
team_table: Optional[LiteLLM_TeamTable],
) -> GenerateKeyRequest:
data = add_team_member_key_duration(team_table, data)
return data
+1
View File
@@ -1141,6 +1141,7 @@ from .router import Router
from .assistants.main import *
from .batches.main import *
from .images.main import *
from .vector_stores import *
from .batch_completion.main import * # type: ignore
from .rerank_api.main import *
from .llms.anthropic.experimental_pass_through.messages.handler import *
@@ -0,0 +1,409 @@
# +-------------------------------------------------------------+
#
# Add Bedrock Knowledge Base Context to your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.vector_store_integrations.base_vector_store import (
BaseVectorStore,
)
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.rag.bedrock_knowledgebase import (
BedrockKBContent,
BedrockKBGuardrailConfiguration,
BedrockKBRequest,
BedrockKBResponse,
BedrockKBRetrievalConfiguration,
BedrockKBRetrievalQuery,
BedrockKBRetrievalResult,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.types.utils import StandardLoggingVectorStoreRequest
from litellm.types.vector_stores import (
VectorStoreResultContent,
VectorStoreSearchResponse,
VectorStoreSearchResult,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams
else:
StandardCallbackDynamicParams = Any
class BedrockVectorStore(BaseVectorStore, BaseAWSLLM):
CONTENT_PREFIX_STRING = "Context: \n\n"
CUSTOM_LLM_PROVIDER = "bedrock"
def __init__(
self,
**kwargs,
):
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
# store kwargs as optional_params
self.optional_params = kwargs
super().__init__(**kwargs)
BaseAWSLLM.__init__(self)
async def async_get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: Optional[str],
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
litellm_logging_obj: LiteLLMLoggingObj,
tools: Optional[List[Dict]] = None,
prompt_label: Optional[str] = None,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Retrieves the context from the Bedrock Knowledge Base and appends it to the messages.
"""
if litellm.vector_store_registry is None:
return model, messages, non_default_params
vector_store_ids = litellm.vector_store_registry.pop_vector_store_ids_to_run(
non_default_params=non_default_params, tools=tools
)
vector_store_request_metadata: List[StandardLoggingVectorStoreRequest] = []
if vector_store_ids:
for vector_store_id in vector_store_ids:
start_time = datetime.now()
query = self._get_kb_query_from_messages(messages)
bedrock_kb_response = await self.make_bedrock_kb_retrieve_request(
knowledge_base_id=vector_store_id,
query=query,
non_default_params=non_default_params,
)
verbose_logger.debug(
f"Bedrock Knowledge Base Response: {bedrock_kb_response}"
)
(
context_message,
context_string,
) = self.get_chat_completion_message_from_bedrock_kb_response(
bedrock_kb_response
)
if context_message is not None:
messages.append(context_message)
#################################################################################################
########## LOGGING for Standard Logging Payload, Langfuse, s3, LiteLLM DB etc. ##################
#################################################################################################
vector_store_search_response: VectorStoreSearchResponse = (
self.transform_bedrock_kb_response_to_vector_store_search_response(
bedrock_kb_response=bedrock_kb_response, query=query
)
)
vector_store_request_metadata.append(
StandardLoggingVectorStoreRequest(
vector_store_id=vector_store_id,
query=query,
vector_store_search_response=vector_store_search_response,
custom_llm_provider=self.CUSTOM_LLM_PROVIDER,
start_time=start_time.timestamp(),
end_time=datetime.now().timestamp(),
)
)
litellm_logging_obj.model_call_details[
"vector_store_request_metadata"
] = vector_store_request_metadata
return model, messages, non_default_params
def transform_bedrock_kb_response_to_vector_store_search_response(
self,
bedrock_kb_response: BedrockKBResponse,
query: str,
) -> VectorStoreSearchResponse:
"""
Transform a BedrockKBResponse to a VectorStoreSearchResponse
"""
retrieval_results: Optional[
List[BedrockKBRetrievalResult]
] = bedrock_kb_response.get("retrievalResults", None)
vector_store_search_response: VectorStoreSearchResponse = (
VectorStoreSearchResponse(search_query=query, data=[])
)
if retrieval_results is None:
return vector_store_search_response
vector_search_response_data: List[VectorStoreSearchResult] = []
for retrieval_result in retrieval_results:
content: Optional[BedrockKBContent] = retrieval_result.get("content", None)
if content is None:
continue
content_text: Optional[str] = content.get("text", None)
if content_text is None:
continue
vector_store_search_result: VectorStoreSearchResult = (
VectorStoreSearchResult(
score=retrieval_result.get("score", None),
content=[VectorStoreResultContent(text=content_text, type="text")],
)
)
vector_search_response_data.append(vector_store_search_result)
vector_store_search_response["data"] = vector_search_response_data
return vector_store_search_response
def _get_kb_query_from_messages(self, messages: List[AllMessageValues]) -> str:
"""
Uses the text `content` field of the last message in the list of messages
"""
if len(messages) == 0:
return ""
last_message = messages[-1]
last_message_content = last_message.get("content", None)
if last_message_content is None:
return ""
if isinstance(last_message_content, str):
return last_message_content
elif isinstance(last_message_content, list):
return "\n".join([item.get("text", "") for item in last_message_content])
return ""
def _prepare_request(
self,
credentials: Any,
data: BedrockKBRequest,
optional_params: dict,
aws_region_name: str,
api_base: str,
extra_headers: Optional[dict] = None,
) -> Any:
"""
Prepare a signed AWS request.
Args:
credentials: AWS credentials
data: Request data
optional_params: Additional parameters
aws_region_name: AWS region name
api_base: Base API URL
extra_headers: Additional headers
Returns:
AWSRequest: A signed AWS request
"""
try:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
encoded_data = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
)
sigv4.add_auth(request)
if extra_headers is not None and "Authorization" in extra_headers:
# prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
return request.prepare()
async def make_bedrock_kb_retrieve_request(
self,
knowledge_base_id: str,
query: str,
guardrail_id: Optional[str] = None,
guardrail_version: Optional[str] = None,
next_token: Optional[str] = None,
retrieval_configuration: Optional[BedrockKBRetrievalConfiguration] = None,
non_default_params: Optional[dict] = None,
) -> BedrockKBResponse:
"""
Make a Bedrock Knowledge Base retrieve request.
Args:
knowledge_base_id (str): The unique identifier of the knowledge base to query
query (str): The query text to search for
guardrail_id (Optional[str]): The guardrail ID to apply
guardrail_version (Optional[str]): The version of the guardrail to apply
next_token (Optional[str]): Token for pagination
retrieval_configuration (Optional[BedrockKBRetrievalConfiguration]): Configuration for the retrieval process
Returns:
BedrockKBRetrievalResponse: A typed response object containing the retrieval results
"""
from fastapi import HTTPException
non_default_params = non_default_params or {}
credentials_dict: Dict[str, Any] = {}
if litellm.vector_store_registry is not None:
credentials_dict = (
litellm.vector_store_registry.get_credentials_for_vector_store(
knowledge_base_id
)
)
credentials = self.get_credentials(
aws_access_key_id=credentials_dict.get(
"aws_access_key_id", non_default_params.get("aws_access_key_id", None)
),
aws_secret_access_key=credentials_dict.get(
"aws_secret_access_key",
non_default_params.get("aws_secret_access_key", None),
),
aws_session_token=credentials_dict.get(
"aws_session_token", non_default_params.get("aws_session_token", None)
),
aws_region_name=credentials_dict.get(
"aws_region_name", non_default_params.get("aws_region_name", None)
),
aws_session_name=credentials_dict.get(
"aws_session_name", non_default_params.get("aws_session_name", None)
),
aws_profile_name=credentials_dict.get(
"aws_profile_name", non_default_params.get("aws_profile_name", None)
),
aws_role_name=credentials_dict.get(
"aws_role_name", non_default_params.get("aws_role_name", None)
),
aws_web_identity_token=credentials_dict.get(
"aws_web_identity_token",
non_default_params.get("aws_web_identity_token", None),
),
aws_sts_endpoint=credentials_dict.get(
"aws_sts_endpoint", non_default_params.get("aws_sts_endpoint", None)
),
)
aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
aws_region_name=credentials_dict.get(
"aws_region_name", non_default_params.get("aws_region_name", None)
),
)
# Prepare request data
request_data: BedrockKBRequest = BedrockKBRequest(
retrievalQuery=BedrockKBRetrievalQuery(text=query),
)
if next_token:
request_data["nextToken"] = next_token
if retrieval_configuration:
request_data["retrievalConfiguration"] = retrieval_configuration
if guardrail_id and guardrail_version:
request_data["guardrailConfiguration"] = BedrockKBGuardrailConfiguration(
guardrailId=guardrail_id, guardrailVersion=guardrail_version
)
verbose_logger.debug(
f"Request Data: {json.dumps(request_data, indent=4, default=str)}"
)
# Prepare the request
api_base = f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com/knowledgebases/{knowledge_base_id}/retrieve"
prepared_request = self._prepare_request(
credentials=credentials,
data=request_data,
optional_params=self.optional_params,
aws_region_name=aws_region_name,
api_base=api_base,
)
verbose_proxy_logger.debug(
"Bedrock Knowledge Base request body: %s, url %s, headers: %s",
request_data,
prepared_request.url,
prepared_request.headers,
)
response = await self.async_handler.post(
url=prepared_request.url,
data=prepared_request.body, # type: ignore
headers=prepared_request.headers, # type: ignore
)
verbose_proxy_logger.debug("Bedrock Knowledge Base response: %s", response.text)
if response.status_code == 200:
response_data = response.json()
return BedrockKBResponse(**response_data)
else:
verbose_proxy_logger.error(
"Bedrock Knowledge Base: error in response. Status code: %s, response: %s",
response.status_code,
response.text,
)
raise HTTPException(
status_code=response.status_code,
detail={
"error": "Error calling Bedrock Knowledge Base",
"response": response.text,
},
)
@staticmethod
def get_initialized_custom_logger() -> Optional[CustomLogger]:
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
return _init_custom_logger_compatible_class(
logging_integration="bedrock_vector_store",
internal_usage_cache=None,
llm_router=None,
)
@staticmethod
def get_chat_completion_message_from_bedrock_kb_response(
response: BedrockKBResponse,
) -> Tuple[Optional[ChatCompletionUserMessage], str]:
"""
Retrieves the context from the Bedrock Knowledge Base response and returns a ChatCompletionUserMessage object.
"""
retrieval_results: Optional[List[BedrockKBRetrievalResult]] = response.get(
"retrievalResults", None
)
if retrieval_results is None:
return None, ""
# string to combine the context from the knowledge base
context_string: str = BedrockVectorStore.CONTENT_PREFIX_STRING
for retrieval_result in retrieval_results:
retrieval_result_content: Optional[BedrockKBContent] = (
retrieval_result.get("content", None) or {}
)
if retrieval_result_content is None:
continue
retrieval_result_text: Optional[str] = retrieval_result_content.get(
"text", None
)
if retrieval_result_text is None:
continue
context_string += retrieval_result_text
message = ChatCompletionUserMessage(
role="user",
content=context_string,
)
return message, context_string
@@ -12,7 +12,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.vector_stores.base_vector_store import BaseVectorStore
from litellm.integrations.vector_store_integrations.base_vector_store import (
BaseVectorStore,
)
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
@@ -32,7 +32,9 @@ from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.integrations.opik.opik import OpikLogger
from litellm.integrations.prometheus import PrometheusLogger
from litellm.integrations.s3_v2 import S3Logger
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
from litellm.integrations.vector_store_integrations.bedrock_vector_store import (
BedrockVectorStore,
)
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
@@ -54,7 +54,9 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.deepeval.deepeval import DeepEvalLogger
from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
from litellm.integrations.vector_store_integrations.bedrock_vector_store import (
BedrockVectorStore,
)
from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking,
@@ -16,7 +16,6 @@ from typing import (
Optional,
Union,
cast,
Set,
)
from litellm.types.llms.openai import (
@@ -508,6 +507,7 @@ def unpack_defs(schema: dict, defs: dict) -> None:
"""
import copy
from collections import deque
# Combine the defs handed down by the caller with defs/definitions found on
# the current node. Local keys shadow parent keys to match JSON-schema
@@ -518,14 +518,17 @@ def unpack_defs(schema: dict, defs: dict) -> None:
**schema.get("definitions", {}),
}
def _walk_and_resolve(node: Any, active_defs: dict, seen: Set[int]): # type: ignore[name-defined]
"""Depth-first resolver that replaces ``{"$ref": "#/defs/Foo"}`` with
the *actual* ``Foo`` schema.
"""
# Avoid infinite recursion on self-referential schemas
# Use iterative approach with queue to avoid recursion
# Each item in queue is (node, parent_container, key/index, active_defs, seen_ids)
queue: deque[tuple[Any, Union[dict, list, None], Union[str, int, None], dict, set]] = deque([(schema, None, None, root_defs, set())])
while queue:
node, parent, key, active_defs, seen = queue.popleft()
# Avoid infinite loops on self-referential schemas
if id(node) in seen:
return node
continue
seen = seen.copy() # Create new set for this branch
seen.add(id(node))
# ----------------------------- dict -----------------------------
@@ -536,7 +539,7 @@ def unpack_defs(schema: dict, defs: dict) -> None:
target_schema = active_defs.get(ref_name)
# Unknown reference leave untouched
if target_schema is None:
return node
continue
# Merge defs from the target to capture nested definitions
child_defs = {
@@ -545,12 +548,24 @@ def unpack_defs(schema: dict, defs: dict) -> None:
**target_schema.get("definitions", {}),
}
# Recursively resolve the target *copy* to avoid mutating the
# shared definition map.
resolved = _walk_and_resolve(copy.deepcopy(target_schema), child_defs, seen)
return resolved
# Replace the reference with resolved copy
resolved = copy.deepcopy(target_schema)
if parent is not None and key is not None:
if isinstance(parent, dict) and isinstance(key, str):
parent[key] = resolved
elif isinstance(parent, list) and isinstance(key, int):
parent[key] = resolved
else:
# This is the root schema itself
schema.clear()
schema.update(resolved)
resolved = schema
# Add resolved node to queue for further processing
queue.append((resolved, parent, key, child_defs, seen))
continue
# --- Case 2: regular dict recurse into its values ---
# --- Case 2: regular dict process its values ---
# Update defs with any nested $defs/definitions present *here*.
current_defs = {
**active_defs,
@@ -558,32 +573,15 @@ def unpack_defs(schema: dict, defs: dict) -> None:
**node.get("definitions", {}),
}
for key, val in list(node.items()):
node[key] = _walk_and_resolve(val, current_defs, seen)
return node
# Add all dict values to queue
for k, v in node.items():
queue.append((v, node, k, current_defs, seen))
# ---------------------------- list ------------------------------
if isinstance(node, list):
elif isinstance(node, list):
# Add all list items to queue
for idx, item in enumerate(node):
node[idx] = _walk_and_resolve(item, active_defs, seen)
return node
# -------------------------- primitive ---------------------------
return node
# Kick off traversal
resolved_root = _walk_and_resolve(schema, root_defs, set())
# If the resolver returned a *different* dict (e.g., the root itself was a
# $ref), mirror the changes back into the original object so that callers
# holding a reference to ``schema`` see the updated structure.
if resolved_root is not schema:
schema.clear()
if isinstance(resolved_root, dict):
schema.update(resolved_root)
else:
# In the very unlikely case the root was resolved to a non-dict
# (e.g., a primitive), replace in-place via a sentinel key.
schema["__resolved_value__"] = resolved_root # type: ignore
queue.append((item, node, idx, active_defs, seen))
def _get_image_mime_type_from_url(url: str) -> Optional[str]:
@@ -119,6 +119,8 @@ def anthropic_messages_handler(
"""
Makes Anthropic `/v1/messages` API calls In the Anthropic API Spec
"""
from litellm.types.utils import LlmProviders
local_vars = locals()
is_async = kwargs.pop("is_async", False)
# Use provided client or create a new one
@@ -141,12 +143,17 @@ def anthropic_messages_handler(
api_key=litellm_params.api_key,
)
anthropic_messages_provider_config: Optional[
BaseAnthropicMessagesConfig
] = ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = None
if custom_llm_provider is not None and custom_llm_provider in [
provider.value for provider in LlmProviders
]:
anthropic_messages_provider_config = (
ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
)
if anthropic_messages_provider_config is None:
# Handle non-Anthropic models using the adapter
return (
+78 -2
View File
@@ -1,12 +1,11 @@
import json
import os
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union, cast
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm.types.router import GenericLiteLLMParams
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.base_llm.chat.transformation import BaseLLMException
@@ -15,6 +14,8 @@ from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import _add_path_to_api_base
azure_ad_cache = DualCache()
@@ -613,3 +614,78 @@ class BaseAzureLLM(BaseOpenAILLM):
else:
client = AzureOpenAI(**azure_client_params) # type: ignore
return client
@staticmethod
def _base_validate_azure_environment(
headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
litellm_params = litellm_params or GenericLiteLLMParams()
api_key = (
litellm_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
if api_key:
headers["api-key"] = api_key
return headers
### Fallback to Azure AD token-based authentication if no API key is available
### Retrieves Azure AD token and adds it to the Authorization header
azure_ad_token = get_azure_ad_token(litellm_params)
if azure_ad_token:
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
@staticmethod
def _get_base_azure_url(
api_base: Optional[str],
litellm_params: Optional[Union[GenericLiteLLMParams, Dict[str, Any]]],
route: Literal["/openai/responses", "/openai/vector_stores"]
) -> str:
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
if api_base is None:
raise ValueError(
f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
)
original_url = httpx.URL(api_base)
# Extract api_version or use default
litellm_params = litellm_params or {}
api_version = cast(Optional[str], litellm_params.get("api_version"))
# Create a new dictionary with existing params
query_params = dict(original_url.params)
# Add api_version if needed
if "api-version" not in query_params and api_version:
query_params["api-version"] = api_version
# Add the path to the base URL
if route not in api_base:
new_url = _add_path_to_api_base(
api_base=api_base, ending_path=route
)
else:
new_url = api_base
if BaseAzureLLM._is_azure_v1_api_version(api_version):
# ensure the request go to /openai/v1 and not just /openai
if "/openai/v1" not in new_url:
parsed_url = httpx.URL(new_url)
new_url = str(parsed_url.copy_with(path=parsed_url.path.replace("/openai", "/openai/v1")))
# Use the new query_params dictionary
final_url = httpx.URL(new_url).copy_with(params=query_params)
return str(final_url)
@staticmethod
def _is_azure_v1_api_version(api_version: Optional[str]) -> bool:
if api_version is None:
return False
return api_version == "preview" or api_version == "latest"
+10 -66
View File
@@ -1,16 +1,11 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.azure.common_utils import get_azure_ad_token
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import *
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import _add_path_to_api_base
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -24,27 +19,11 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
def validate_environment(
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
litellm_params = litellm_params or GenericLiteLLMParams()
api_key = (
litellm_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
return BaseAzureLLM._base_validate_azure_environment(
headers=headers,
litellm_params=litellm_params
)
if api_key:
headers["api-key"] = api_key
return headers
### Fallback to Azure AD token-based authentication if no API key is available
### Retrieves Azure AD token and adds it to the Authorization header
azure_ad_token = get_azure_ad_token(litellm_params)
if azure_ad_token:
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
def get_complete_url(
self,
api_base: Optional[str],
@@ -66,47 +45,12 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
- A complete URL string, e.g.,
"https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
"""
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
if api_base is None:
raise ValueError(
f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
)
original_url = httpx.URL(api_base)
# Extract api_version or use default
api_version = cast(Optional[str], litellm_params.get("api_version"))
# Create a new dictionary with existing params
query_params = dict(original_url.params)
# Add api_version if needed
if "api-version" not in query_params and api_version:
query_params["api-version"] = api_version
# Add the path to the base URL
if "/openai/responses" not in api_base:
new_url = _add_path_to_api_base(
api_base=api_base, ending_path="/openai/responses"
)
else:
new_url = api_base
if self._is_azure_v1_api_version(api_version):
# ensure the request go to /openai/v1 and not just /openai
if "/openai/v1" not in new_url:
parsed_url = httpx.URL(new_url)
new_url = str(parsed_url.copy_with(path=parsed_url.path.replace("/openai", "/openai/v1")))
# Use the new query_params dictionary
final_url = httpx.URL(new_url).copy_with(params=query_params)
return str(final_url)
return BaseAzureLLM._get_base_azure_url(
api_base=api_base,
litellm_params=litellm_params,
route="/openai/responses"
)
def _is_azure_v1_api_version(self, api_version: Optional[str]) -> bool:
if api_version is None:
return False
return api_version == "preview" or api_version == "latest"
#########################################################
########## DELETE RESPONSE API TRANSFORMATION ##############
@@ -0,0 +1,27 @@
from typing import Optional
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.vector_stores.transformation import OpenAIVectorStoreConfig
from litellm.types.router import GenericLiteLLMParams
class AzureOpenAIVectorStoreConfig(OpenAIVectorStoreConfig):
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
return BaseAzureLLM._get_base_azure_url(
api_base=api_base,
litellm_params=litellm_params,
route="/openai/vector_stores"
)
def validate_environment(
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
return BaseAzureLLM._base_validate_azure_environment(
headers=headers,
litellm_params=litellm_params
)
@@ -0,0 +1,86 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateResponse,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ..chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseLLMException = Any
class BaseVectorStoreConfig:
@abstractmethod
def transform_search_vector_store_request(
self,
vector_store_id: str,
query: Union[str, List[str]],
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
api_base: str,
) -> Tuple[str, Dict]:
pass
@abstractmethod
def transform_search_vector_store_response(self, response: httpx.Response) -> VectorStoreSearchResponse:
pass
@abstractmethod
def transform_create_vector_store_request(
self,
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
api_base: str,
) -> Tuple[str, Dict]:
pass
@abstractmethod
def transform_create_vector_store_response(self, response: httpx.Response) -> VectorStoreCreateResponse:
pass
@abstractmethod
def validate_environment(
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
return {}
@abstractmethod
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
if api_base is None:
raise ValueError("api_base is required")
return api_base
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
from ..chat.transformation import BaseLLMException
raise BaseLLMException(
status_code=status_code,
message=error_message,
headers=headers,
)
+276 -1
View File
@@ -35,6 +35,7 @@ from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@@ -60,6 +61,12 @@ from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.responses.main import DeleteResponseResult
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse
from litellm.types.vector_stores import (
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateResponse,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
)
from litellm.utils import (
CustomStreamWrapper,
ImageResponse,
@@ -2342,7 +2349,7 @@ class BaseLLMHTTPHandler:
self,
e: Exception,
provider_config: Union[
BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig, BaseImageEditConfig
BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig, BaseImageEditConfig, BaseVectorStoreConfig
],
):
status_code = getattr(e, "status_code", 500)
@@ -2613,3 +2620,271 @@ class BaseLLMHTTPHandler:
raw_response=response,
logging_obj=logging_obj,
)
###### VECTOR STORE HANDLER ######
async def async_vector_store_search_handler(
self,
vector_store_id: str,
query: Union[str, List[str]],
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
vector_store_provider_config: BaseVectorStoreConfig,
custom_llm_provider: str,
litellm_params: GenericLiteLLMParams,
logging_obj: LiteLLMLoggingObj,
extra_headers: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
_is_async: bool = False,
) -> VectorStoreSearchResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider),
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
)
else:
async_httpx_client = client
headers = vector_store_provider_config.validate_environment(
headers=extra_headers or {},
litellm_params=litellm_params
)
if extra_headers:
headers.update(extra_headers)
api_base = vector_store_provider_config.get_complete_url(
api_base=litellm_params.api_base,
litellm_params=dict(litellm_params),
)
url, request_body = vector_store_provider_config.transform_search_vector_store_request(
vector_store_id=vector_store_id,
query=query,
vector_store_search_optional_params=vector_store_search_optional_params,
api_base=api_base,
)
logging_obj.pre_call(
input="",
api_key="",
additional_args={
"complete_input_dict": request_body,
"api_base": api_base,
"headers": headers,
},
)
try:
response = await async_httpx_client.post(url=url, headers=headers, json=request_body, timeout=timeout)
except Exception as e:
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
return vector_store_provider_config.transform_search_vector_store_response(
response=response,
)
def vector_store_search_handler(
self,
vector_store_id: str,
query: Union[str, List[str]],
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
vector_store_provider_config: BaseVectorStoreConfig,
custom_llm_provider: str,
litellm_params: GenericLiteLLMParams,
logging_obj: LiteLLMLoggingObj,
extra_headers: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
_is_async: bool = False,
) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
if _is_async:
return self.async_vector_store_search_handler(
vector_store_id=vector_store_id,
query=query,
vector_store_search_optional_params=vector_store_search_optional_params,
vector_store_provider_config=vector_store_provider_config,
litellm_params=litellm_params,
logging_obj=logging_obj,
custom_llm_provider=custom_llm_provider,
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
client=client,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client(
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
)
else:
sync_httpx_client = client
headers = vector_store_provider_config.validate_environment(
headers=extra_headers or {},
litellm_params=litellm_params
)
if extra_headers:
headers.update(extra_headers)
api_base = vector_store_provider_config.get_complete_url(
api_base=litellm_params.api_base,
litellm_params=dict(litellm_params),
)
url, request_body = vector_store_provider_config.transform_search_vector_store_request(
vector_store_id=vector_store_id,
query=query,
vector_store_search_optional_params=vector_store_search_optional_params,
api_base=api_base,
)
logging_obj.pre_call(
input="",
api_key="",
additional_args={
"complete_input_dict": request_body,
"api_base": api_base,
"headers": headers,
},
)
try:
response = sync_httpx_client.post(url=url, headers=headers, json=request_body)
except Exception as e:
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
return vector_store_provider_config.transform_search_vector_store_response(
response=response,
)
async def async_vector_store_create_handler(
self,
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
vector_store_provider_config: BaseVectorStoreConfig,
custom_llm_provider: str,
litellm_params: GenericLiteLLMParams,
logging_obj: LiteLLMLoggingObj,
extra_headers: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
_is_async: bool = False,
) -> VectorStoreCreateResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider),
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
)
else:
async_httpx_client = client
headers = vector_store_provider_config.validate_environment(
headers=extra_headers or {},
litellm_params=litellm_params
)
if extra_headers:
headers.update(extra_headers)
api_base = vector_store_provider_config.get_complete_url(
api_base=litellm_params.api_base,
litellm_params=dict(litellm_params),
)
url, request_body = vector_store_provider_config.transform_create_vector_store_request(
vector_store_create_optional_params=vector_store_create_optional_params,
api_base=api_base,
)
logging_obj.pre_call(
input="",
api_key="",
additional_args={
"complete_input_dict": request_body,
"api_base": api_base,
"headers": headers,
},
)
try:
response = await async_httpx_client.post(url=url, headers=headers, json=request_body, timeout=timeout)
except Exception as e:
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
return vector_store_provider_config.transform_create_vector_store_response(
response=response,
)
def vector_store_create_handler(
self,
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
vector_store_provider_config: BaseVectorStoreConfig,
custom_llm_provider: str,
litellm_params: GenericLiteLLMParams,
logging_obj: LiteLLMLoggingObj,
extra_headers: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
_is_async: bool = False,
) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
if _is_async:
return self.async_vector_store_create_handler(
vector_store_create_optional_params=vector_store_create_optional_params,
vector_store_provider_config=vector_store_provider_config,
litellm_params=litellm_params,
logging_obj=logging_obj,
custom_llm_provider=custom_llm_provider,
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
client=client,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client(
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
)
else:
sync_httpx_client = client
headers = vector_store_provider_config.validate_environment(
headers=extra_headers or {},
litellm_params=litellm_params
)
if extra_headers:
headers.update(extra_headers)
api_base = vector_store_provider_config.get_complete_url(
api_base=litellm_params.api_base,
litellm_params=dict(litellm_params),
)
url, request_body = vector_store_provider_config.transform_create_vector_store_request(
vector_store_create_optional_params=vector_store_create_optional_params,
api_base=api_base,
)
logging_obj.pre_call(
input="",
api_key="",
additional_args={
"complete_input_dict": request_body,
"api_base": api_base,
"headers": headers,
},
)
try:
response = sync_httpx_client.post(url=url, headers=headers, json=request_body)
except Exception as e:
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
return vector_store_provider_config.transform_create_vector_store_response(
response=response,
)
@@ -6,7 +6,7 @@ Why separate file? Make it easy to see how transformation works
Docs - https://docs.mistral.ai/api/
"""
from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, overload, cast
from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, cast, overload
from litellm.litellm_core_utils.prompt_templates.common_utils import (
handle_messages_with_content_list_to_str_conversion,
@@ -107,15 +107,28 @@ class MistralConfig(OpenAIGPTConfig):
def _get_mistral_reasoning_system_prompt() -> str:
"""
Returns the system prompt for Mistral reasoning models.
Based on Mistral's documentation: https://docs.mistral.ai/capabilities/reasoning/
Based on Mistral's documentation: https://huggingface.co/mistralai/Magistral-Small-2506
Mistral recommends the following system prompt for reasoning:
"""
return """When solving problems, think step-by-step in <think> tags before providing your final answer. Use the following format:
return """
<s>[SYSTEM_PROMPT]system_prompt
A user will ask you to solve a task. You should first draft your thinking process (inner monologue) until you have derived the final answer. Afterwards, write a self-contained summary of your thoughts (i.e. your summary should be succinct but contain all the critical steps you needed to reach the conclusion). You should use Markdown to format your response. Write both your thoughts and summary in the same language as the task posed by the user. NEVER use \boxed{} in your response.
<think>
Your step-by-step reasoning process. Be thorough and work through the problem carefully.
</think>
Your thinking process must follow the template below:
<think>
Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate a correct answer.
</think>
Then provide a clear, concise answer based on your reasoning."""
Here, provide a concise summary that reflects your reasoning and presents a clear final answer to the user. Don't mention that this is a summary.
Problem:
[/SYSTEM_PROMPT][INST]user_message[/INST]<think>
reasoning_traces
</think>
assistant_response</s>[INST]user_message[/INST]
"""
def map_openai_params(
self,
+17 -10
View File
@@ -5,15 +5,15 @@ Ollama /chat/completion calls handled in llm_http_handler.py
"""
from typing import Any, Dict, List
import litellm
from litellm.types.utils import EmbeddingResponse
def _prepare_ollama_embedding_payload(
model: str,
prompts: List[str],
optional_params: Dict[str, Any]
model: str, prompts: List[str], optional_params: Dict[str, Any]
) -> Dict[str, Any]:
data: Dict[str, Any] = {"model": model, "input": prompts}
special_optional_params = ["truncate", "options", "keep_alive"]
@@ -26,13 +26,14 @@ def _prepare_ollama_embedding_payload(
data["options"].update({k: v})
return data
def _process_ollama_embedding_response(
response_json: dict,
prompts: List[str],
model: str,
model_response: EmbeddingResponse,
logging_obj: Any,
encoding: Any
encoding: Any,
) -> EmbeddingResponse:
output_data = []
embeddings: List[List[float]] = response_json["embeddings"]
@@ -46,11 +47,15 @@ def _process_ollama_embedding_response(
if encoding is not None:
input_tokens = len(encoding.encode("".join(prompts)))
if logging_obj:
logging_obj.debug("Ollama response missing prompt_eval_count; estimated with encoding.")
logging_obj.debug(
"Ollama response missing prompt_eval_count; estimated with encoding."
)
else:
input_tokens = 0
if logging_obj:
logging_obj.warning("Missing prompt_eval_count and no encoding provided; defaulted to 0.")
logging_obj.warning(
"Missing prompt_eval_count and no encoding provided; defaulted to 0."
)
model_response.object = "list"
model_response.data = output_data
@@ -64,6 +69,7 @@ def _process_ollama_embedding_response(
)
return model_response
async def ollama_aembeddings(
api_base: str,
model: str,
@@ -79,7 +85,7 @@ async def ollama_aembeddings(
data = _prepare_ollama_embedding_payload(model, prompts, optional_params)
response = await litellm.module_level_aclient.post(url=api_base, json=data)
response_json = await response.json()
response_json = response.json()
return _process_ollama_embedding_response(
response_json=response_json,
@@ -87,9 +93,10 @@ async def ollama_aembeddings(
model=model,
model_response=model_response,
logging_obj=logging_obj,
encoding=encoding
encoding=encoding,
)
def ollama_embeddings(
api_base: str,
model: str,
@@ -113,5 +120,5 @@ def ollama_embeddings(
model=model,
model_response=model_response,
logging_obj=logging_obj,
encoding=encoding
encoding=encoding,
)
@@ -0,0 +1,140 @@
from typing import Dict, List, Optional, Tuple, Union, cast
import httpx
import litellm
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateRequest,
VectorStoreCreateResponse,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchRequest,
VectorStoreSearchResponse,
)
class OpenAIVectorStoreConfig(BaseVectorStoreConfig):
ASSISTANTS_HEADER_KEY = "OpenAI-Beta"
ASSISTANTS_HEADER_VALUE = "assistants=v2"
def validate_environment(
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
litellm_params = litellm_params or GenericLiteLLMParams()
api_key = (
litellm_params.api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
headers.update(
{
"Authorization": f"Bearer {api_key}",
}
)
#########################################################
# Ensure OpenAI Assistants header is includes
#########################################################
if self.ASSISTANTS_HEADER_KEY not in headers:
headers.update(
{
self.ASSISTANTS_HEADER_KEY: self.ASSISTANTS_HEADER_VALUE,
}
)
return headers
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""
Get the Base endpoint for OpenAI Vector Stores API
"""
api_base = (
api_base
or litellm.api_base
or get_secret_str("OPENAI_BASE_URL")
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
# Remove trailing slashes
api_base = api_base.rstrip("/")
return f"{api_base}/vector_stores"
def transform_search_vector_store_request(
self,
vector_store_id: str,
query: Union[str, List[str]],
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
api_base: str,
) -> Tuple[str, Dict]:
url = f"{api_base}/{vector_store_id}/search"
typed_request_body = VectorStoreSearchRequest(
query=query,
filters=vector_store_search_optional_params.get("filters", None),
max_num_results=vector_store_search_optional_params.get("max_num_results", None),
ranking_options=vector_store_search_optional_params.get("ranking_options", None),
rewrite_query=vector_store_search_optional_params.get("rewrite_query", None),
)
dict_request_body = cast(dict, typed_request_body)
return url, dict_request_body
def transform_search_vector_store_response(self, response: httpx.Response) -> VectorStoreSearchResponse:
try:
response_json = response.json()
return VectorStoreSearchResponse(
**response_json
)
except Exception as e:
raise self.get_error_class(
error_message=str(e),
status_code=response.status_code,
headers=response.headers
)
def transform_create_vector_store_request(
self,
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
api_base: str,
) -> Tuple[str, Dict]:
url = api_base # Base URL for creating vector stores
typed_request_body = VectorStoreCreateRequest(
name=vector_store_create_optional_params.get("name", None),
file_ids=vector_store_create_optional_params.get("file_ids", None),
expires_after=vector_store_create_optional_params.get("expires_after", None),
chunking_strategy=vector_store_create_optional_params.get("chunking_strategy", None),
metadata=vector_store_create_optional_params.get("metadata", None),
)
dict_request_body = cast(dict, typed_request_body)
return url, dict_request_body
def transform_create_vector_store_response(self, response: httpx.Response) -> VectorStoreCreateResponse:
try:
response_json = response.json()
return VectorStoreCreateResponse(
**response_json
)
except Exception as e:
raise self.get_error_class(
error_message=str(e),
status_code=response.status_code,
headers=response.headers
)
+23
View File
@@ -63,3 +63,26 @@ class VolcEngineConfig(OpenAILikeChatConfig):
"extra_headers",
"thinking",
] # works across all models
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
replace_max_completion_tokens_with_max_tokens: bool = True,
) -> dict:
optional_params = super().map_openai_params(
non_default_params,
optional_params,
model,
drop_params,
replace_max_completion_tokens_with_max_tokens,
)
if "thinking" in optional_params:
optional_params.setdefault("extra_body", {})["thinking"] = (
optional_params.pop("thinking")
)
return optional_params
+6
View File
@@ -1297,6 +1297,12 @@ def completion( # type: ignore # noqa: PLR0915
except Exception as e:
verbose_logger.debug("Error getting model info: {}".format(e))
model_info = {}
if model.startswith(
"responses/"
): # handle azure models - `azure/responses/<deployment-name>`
model = model.split("/")[1]
mode = "responses"
model_info["mode"] = mode
if model_info.get("mode") == "responses":
from litellm.completion_extras import responses_api_bridge
@@ -4046,7 +4046,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-small": {
"max_tokens": 8191,
@@ -4058,7 +4059,8 @@
"supports_function_calling": true,
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-small-latest": {
"max_tokens": 8191,
@@ -4070,7 +4072,8 @@
"supports_function_calling": true,
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium": {
"max_tokens": 8191,
@@ -4081,7 +4084,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium-latest": {
"max_tokens": 8191,
@@ -4093,7 +4097,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium-2505": {
"max_tokens": 8191,
@@ -4105,7 +4110,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium-2312": {
"max_tokens": 8191,
@@ -4116,7 +4122,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-latest": {
"max_tokens": 128000,
@@ -4128,7 +4135,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-2411": {
"max_tokens": 128000,
@@ -4140,7 +4148,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-2402": {
"max_tokens": 8191,
@@ -4152,7 +4161,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-2407": {
"max_tokens": 128000,
@@ -4164,7 +4174,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/pixtral-large-latest": {
"max_tokens": 128000,
@@ -4177,7 +4188,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_vision": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/pixtral-large-2411": {
"max_tokens": 128000,
@@ -4190,7 +4202,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_vision": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/pixtral-12b-2409": {
"max_tokens": 128000,
@@ -4203,7 +4216,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_vision": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mistral-7b": {
"max_tokens": 8191,
@@ -4214,7 +4228,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mixtral-8x7b": {
"max_tokens": 8191,
@@ -4226,7 +4241,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mixtral-8x22b": {
"max_tokens": 8191,
@@ -4238,7 +4254,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/codestral-latest": {
"max_tokens": 8191,
@@ -4249,7 +4266,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/codestral-2405": {
"max_tokens": 8191,
@@ -4260,7 +4278,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mistral-nemo": {
"max_tokens": 128000,
@@ -4272,7 +4291,8 @@
"mode": "chat",
"source": "https://mistral.ai/technology/",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mistral-nemo-2407": {
"max_tokens": 128000,
@@ -4284,7 +4304,8 @@
"mode": "chat",
"source": "https://mistral.ai/technology/",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-codestral-mamba": {
"max_tokens": 256000,
@@ -4321,7 +4342,8 @@
"source": "https://mistral.ai/news/devstral",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/magistral-medium-latest": {
"max_tokens": 40000,
@@ -4335,7 +4357,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/magistral-medium-2506": {
"max_tokens": 40000,
@@ -4349,7 +4372,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/magistral-small-latest": {
"max_tokens": 40000,
@@ -4363,7 +4387,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/magistral-small-2506": {
"max_tokens": 40000,
@@ -4377,7 +4402,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/mistral-embed": {
"max_tokens": 8192,
+1 -1
View File
@@ -1,4 +1,4 @@
model_list:
- model_name: gemini-2.5-pro
litellm_params:
model: gemini/gemini-2.5-pro
model: gemini/gemini-2.5-pro
+3
View File
@@ -1115,6 +1115,7 @@ class NewTeamRequest(TeamBase):
team_member_budget: Optional[float] = (
None # allow user to set a budget for all team members
)
team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m"
model_config = ConfigDict(protected_namespaces=())
@@ -1157,6 +1158,7 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase):
guardrails: Optional[List[str]] = None
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
team_member_budget: Optional[float] = None
team_member_key_duration: Optional[str] = None
class ResetTeamBudgetRequest(LiteLLMPydanticObjectBase):
@@ -2792,6 +2794,7 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
LiteLLM_ManagementEndpoint_MetadataFields_Premium = [
"guardrails",
"tags",
"team_member_key_duration",
]
+2
View File
@@ -164,7 +164,9 @@ class JWTHandler:
self.litellm_jwtauth.team_ids_jwt_field is not None
and token.get(self.litellm_jwtauth.team_ids_jwt_field) is not None
):
return token[self.litellm_jwtauth.team_ids_jwt_field]
return []
def get_end_user_id(
+357
View File
@@ -0,0 +1,357 @@
18:10:09 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle
18:10:11 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist!
18:10:11 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:10:11 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:10:23 - LiteLLM Proxy:INFO: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback
18:10:27 - LiteLLM Proxy:INFO: ui_sso.py:495 - Starting SSO callback
18:10:27 - LiteLLM Proxy:INFO: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback
18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[]
18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:671 - user_defined_values for creating ui key: {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': 'proxy_admin', 'budget_duration': None}
18:10:28 - LiteLLM Proxy:INFO: utils.py:1856 - Data Inserted into Keys Table
18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:761 - user_id: krrishd; jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoia3JyaXNoZCIsImtleSI6InNrLTVvOXVVc0ZaaTVBRFBiWERoanhCZlEiLCJ1c2VyX2VtYWlsIjoia3JyaXNoZGhvbGFraWFAZ21haWwuY29tIiwidXNlcl9yb2xlIjoicHJveHlfYWRtaW4iLCJsb2dpbl9tZXRob2QiOiJzc28iLCJwcmVtaXVtX3VzZXIiOnRydWUsImF1dGhfaGVhZGVyX25hbWUiOiJBdXRob3JpemF0aW9uIiwiZGlzYWJsZWRfbm9uX2FkbWluX3BlcnNvbmFsX2tleV9jcmVhdGlvbiI6ZmFsc2UsInNlcnZlcl9yb290X3BhdGgiOiIvIn0.OiZdFjZ2wiMhFbMCwu2cZYXh7oV5BB8Vta-Ysk5JBQU
18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:764 - Redirecting to http://localhost:4000/ui/?login=success
18:10:30 - LiteLLM Proxy:ERROR: key_management_endpoints.py:2275 - Error in list_keys: Server disconnected without sending a response.
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
yield
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
resp = await self._pool.handle_async_request(req)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
raise exc from None
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
response = await connection.handle_async_request(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 101, in handle_async_request
return await self._connection.handle_async_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 143, in handle_async_request
raise exc
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 113, in handle_async_request
) = await self._receive_response_headers(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 186, in _receive_response_headers
event = await self._receive_event(timeout=timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 238, in _receive_event
raise RemoteProtocolError(msg)
httpcore.RemoteProtocolError: Server disconnected without sending a response.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/management_endpoints/key_management_endpoints.py", line 2255, in list_keys
response = await _list_key_helper(
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/management_endpoints/key_management_endpoints.py", line 2434, in _list_key_helper
total_count = await prisma_client.db.litellm_verificationtoken.count(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 10157, in count
resp = await self._client._execute(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
return await self._engine.query(builder.build(), tx_id=self._tx_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
return await self.request(
^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
response = await self.session.request(method, url, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
return Response(await self.session.request(method, url, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
response = await self._send_handling_auth(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
response = await self._send_handling_redirects(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
response = await self._send_single_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
response = await transport.handle_async_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
with map_httpcore_exceptions():
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
self.gen.throw(typ, value, traceback)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
raise mapped_exc(message) from exc
httpx.RemoteProtocolError: Server disconnected without sending a response.
18:10:30 - LiteLLM Proxy:ERROR: proxy_server.py:2730 - litellm.proxy_server.py::add_deployment() - Error getting new models from DB - All connection attempts failed
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
yield
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
resp = await self._pool.handle_async_request(req)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
raise exc from None
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
response = await connection.handle_async_request(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 99, in handle_async_request
raise exc
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 76, in handle_async_request
stream = await self._connect(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 122, in _connect
stream = await self._network_backend.connect_tcp(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/auto.py", line 30, in connect_tcp
return await self._backend.connect_tcp(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 112, in connect_tcp
with map_exceptions(exc_map):
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
self.gen.throw(typ, value, traceback)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_exceptions.py", line 14, in map_exceptions
raise to_exc(exc) from exc
httpcore.ConnectError: All connection attempts failed
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2728, in _get_models_from_db
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 2540, in find_many
resp = await self._client._execute(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
return await self._engine.query(builder.build(), tx_id=self._tx_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
return await self.request(
^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
response = await self.session.request(method, url, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
return Response(await self.session.request(method, url, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
response = await self._send_handling_auth(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
response = await self._send_handling_redirects(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
response = await self._send_single_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
response = await transport.handle_async_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
with map_httpcore_exceptions():
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
self.gen.throw(typ, value, traceback)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
raise mapped_exc(message) from exc
httpx.ConnectError: All connection attempts failed
18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
18:10:30 - LiteLLM Proxy:ERROR: proxy_server.py:2778 - litellm.proxy.proxy_server.py::ProxyConfig:add_deployment - All connection attempts failed
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
yield
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
resp = await self._pool.handle_async_request(req)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
raise exc from None
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
response = await connection.handle_async_request(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 99, in handle_async_request
raise exc
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 76, in handle_async_request
stream = await self._connect(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 122, in _connect
stream = await self._network_backend.connect_tcp(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/auto.py", line 30, in connect_tcp
return await self._backend.connect_tcp(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 112, in connect_tcp
with map_exceptions(exc_map):
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
self.gen.throw(typ, value, traceback)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_exceptions.py", line 14, in map_exceptions
raise to_exc(exc) from exc
httpcore.ConnectError: All connection attempts failed
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2760, in add_deployment
await self._update_llm_router(
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2418, in _update_llm_router
config_data = await proxy_config.get_config()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 1584, in get_config
config = await self._update_config_from_db(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2706, in _update_config_from_db
responses = await asyncio.gather(*_tasks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/db/log_db_metrics.py", line 99, in wrapper
raise e
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/db/log_db_metrics.py", line 42, in wrapper
result = await func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/backoff/_async.py", line 151, in retry
ret = await target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/utils.py", line 1418, in get_generic_data
raise e
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/utils.py", line 1392, in get_generic_data
response = await self.db.litellm_config.find_first( # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 11822, in find_first
resp = await self._client._execute(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
return await self._engine.query(builder.build(), tx_id=self._tx_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
return await self.request(
^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
response = await self.session.request(method, url, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
return Response(await self.session.request(method, url, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
response = await self._send_handling_auth(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
response = await self._send_handling_redirects(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
response = await self._send_single_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
response = await transport.handle_async_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
with map_httpcore_exceptions():
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
self.gen.throw(typ, value, traceback)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
raise mapped_exc(message) from exc
httpx.ConnectError: All connection attempts failed
18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
18:10:30 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
18:11:47 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle
18:11:49 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist!
18:11:50 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:11:50 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:12:00 - LiteLLM Proxy:ERROR: proxy_server.py:2925 - litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - Server disconnected without sending a response.
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
yield
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
resp = await self._pool.handle_async_request(req)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
raise exc from None
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
response = await connection.handle_async_request(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 101, in handle_async_request
return await self._connection.handle_async_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 143, in handle_async_request
raise exc
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 113, in handle_async_request
) = await self._receive_response_headers(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 186, in _receive_response_headers
event = await self._receive_event(timeout=timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 238, in _receive_event
raise RemoteProtocolError(msg)
httpcore.RemoteProtocolError: Server disconnected without sending a response.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2916, in get_credentials
credentials = await prisma_client.db.litellm_credentialstable.find_many()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 1502, in find_many
resp = await self._client._execute(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
return await self._engine.query(builder.build(), tx_id=self._tx_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
return await self.request(
^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
response = await self.session.request(method, url, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
return Response(await self.session.request(method, url, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
response = await self._send_handling_auth(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
response = await self._send_handling_redirects(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
response = await self._send_single_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
response = await transport.handle_async_request(request)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
with map_httpcore_exceptions():
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
self.gen.throw(typ, value, traceback)
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
raise mapped_exc(message) from exc
httpx.RemoteProtocolError: Server disconnected without sending a response.
18:12:01 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
18:12:14 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle
18:12:16 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist!
18:12:16 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:12:16 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:12:21 - LiteLLM Proxy:INFO: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback
18:12:26 - LiteLLM Proxy:INFO: ui_sso.py:495 - Starting SSO callback
18:12:26 - LiteLLM Proxy:INFO: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback
18:12:26 - LiteLLM Proxy:INFO: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[]
18:12:27 - LiteLLM Proxy:INFO: ui_sso.py:672 - user_defined_values for creating ui key: {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': 'proxy_admin', 'budget_duration': None}
18:12:27 - LiteLLM Proxy:INFO: utils.py:1856 - Data Inserted into Keys Table
18:12:27 - LiteLLM Proxy:INFO: ui_sso.py:762 - user_id: krrishd; jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoia3JyaXNoZCIsImtleSI6InNrLUQzMEFpdW9lckU3YlMyakFXWVFLd1EiLCJ1c2VyX2VtYWlsIjoia3JyaXNoZGhvbGFraWFAZ21haWwuY29tIiwidXNlcl9yb2xlIjoicHJveHlfYWRtaW4iLCJsb2dpbl9tZXRob2QiOiJzc28iLCJwcmVtaXVtX3VzZXIiOnRydWUsImF1dGhfaGVhZGVyX25hbWUiOiJBdXRob3JpemF0aW9uIiwiZGlzYWJsZWRfbm9uX2FkbWluX3BlcnNvbmFsX2tleV9jcmVhdGlvbiI6ZmFsc2UsInNlcnZlcl9yb290X3BhdGgiOiIvIn0.EzYP86hw12J4WHLe6ZZz4YgVNGPnxM_PHqLjINH2_-U
18:12:27 - LiteLLM Proxy:INFO: ui_sso.py:765 - Redirecting to http://localhost:4000/ui/?login=success
18:12:31 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
18:15:07 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle
18:15:09 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist!
18:15:09 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:15:09 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
18:15:17 - LiteLLM Proxy:INFO: utils.py:1916 - Data Inserted into Config Table
18:15:28 - LiteLLM Proxy:INFO: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback
18:15:32 - LiteLLM Proxy:INFO: ui_sso.py:495 - Starting SSO callback
18:15:32 - LiteLLM Proxy:INFO: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback
18:15:32 - LiteLLM Proxy:INFO: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[]
18:15:37 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
@@ -512,6 +512,20 @@ async def generate_key_fn( # noqa: PLR0915
},
)
# APPLY ENTERPRISE KEY MANAGEMENT PARAMS
try:
from litellm_enterprise.proxy.management_endpoints.key_management_endpoints import (
apply_enterprise_key_management_params,
)
data = apply_enterprise_key_management_params(data, team_table)
except Exception as e:
verbose_proxy_logger.info(
"litellm.proxy.proxy_server.generate_key_fn(): Enterprise key management params not applied - {}".format(
str(e)
)
)
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
_budget_id = data.budget_id
if prisma_client is not None and data.soft_budget is not None:
@@ -536,7 +550,7 @@ async def generate_key_fn( # noqa: PLR0915
# ADD METADATA FIELDS
# Set Management Endpoint Metadata Fields
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
if getattr(data, field) is not None:
if getattr(data, field, None) is not None:
_set_object_metadata_field(
object_data=data,
field_name=field,
@@ -589,9 +603,9 @@ async def generate_key_fn( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)
response[
"soft_budget"
] = data.soft_budget # include the user-input soft budget in the response
response["soft_budget"] = (
data.soft_budget
) # include the user-input soft budget in the response
response = GenerateKeyResponse(**response)
@@ -667,9 +681,9 @@ async def _set_object_permission(
data=data_json["object_permission"],
)
)
data_json[
"object_permission_id"
] = created_object_permission.object_permission_id
data_json["object_permission_id"] = (
created_object_permission.object_permission_id
)
# delete the object_permission from the data_json
data_json.pop("object_permission")
@@ -1652,10 +1666,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[
LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
)
if len(_keys_being_deleted) == 0:
@@ -1763,9 +1777,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config
try:
models: Optional[
List
] = await prisma_client.db.litellm_proxymodeltable.find_many()
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
except Exception:
models = None
# 2. process model table
@@ -2057,11 +2071,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[
BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
)
if complete_user_info_db_obj is None:
@@ -2147,10 +2161,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[
List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
)
if teams is None:
return []
@@ -2403,12 +2417,14 @@ async def _list_key_helper(
where=where, # type: ignore
skip=skip, # type: ignore
take=size, # type: ignore
order=order_by
if order_by
else [
{"created_at": "desc"},
{"token": "desc"}, # fallback sort
],
order=(
order_by
if order_by
else [
{"created_at": "desc"},
{"token": "desc"}, # fallback sort
]
),
include={"object_permission": True},
)
@@ -18,6 +18,7 @@ from fastapi import (
Response,
)
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import (
@@ -361,6 +362,18 @@ async def create_user(
# Create user in database
user_id = user.userName or str(uuid.uuid4())
metadata = _build_scim_metadata(user_data["given_name"], user_data["family_name"])
default_role: Optional[
Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
]
] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
if litellm.default_internal_user_params:
default_role = litellm.default_internal_user_params.get("user_role")
new_user_request = NewUserRequest(
user_id=user_id,
user_email=user_data["user_email"],
@@ -368,6 +381,7 @@ async def create_user(
teams=user_data["teams"],
metadata=metadata,
auto_create_key=False,
user_role=default_role,
)
# Check if user with email already exists and update if found
@@ -1,9 +1,14 @@
from typing import Dict, Union
from litellm.proxy._types import LitellmUserRoles
def check_is_admin_only_access(ui_access_mode: str) -> bool:
def check_is_admin_only_access(ui_access_mode: Union[str, Dict]) -> bool:
"""Checks ui access mode is admin_only"""
return ui_access_mode == "admin_only"
if isinstance(ui_access_mode, str):
return ui_access_mode == "admin_only"
else:
return False
def has_admin_ui_access(user_role: str) -> bool:
@@ -262,6 +262,7 @@ async def new_team( # noqa: PLR0915
- guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails)
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission.
- team_member_budget: Optional[float] - The maximum budget allocated to an individual team member.
- team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo"
Returns:
- team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id.
@@ -688,6 +689,7 @@ async def update_team(
- guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails)
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission.
- team_member_budget: Optional[float] - The maximum budget allocated to an individual team member.
- team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo"
Example - update team TPM Limit
```
+113 -8
View File
@@ -145,7 +145,11 @@ async def google_login(request: Request): # noqa: PLR0915
return HTMLResponse(content=html_form, status_code=200)
def generic_response_convertor(response, jwt_handler: JWTHandler):
def generic_response_convertor(
response,
jwt_handler: JWTHandler,
sso_jwt_handler: Optional[JWTHandler] = None,
):
generic_user_id_attribute_name = os.getenv(
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
)
@@ -171,6 +175,13 @@ def generic_response_convertor(response, jwt_handler: JWTHandler):
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
)
all_teams = []
if sso_jwt_handler is not None:
team_ids = sso_jwt_handler.get_team_ids_from_jwt(cast(dict, response))
all_teams.extend(team_ids)
team_ids = jwt_handler.get_team_ids_from_jwt(cast(dict, response))
all_teams.extend(team_ids)
return CustomOpenID(
id=response.get(generic_user_id_attribute_name),
display_name=response.get(generic_user_display_name_attribute_name),
@@ -178,20 +189,24 @@ def generic_response_convertor(response, jwt_handler: JWTHandler):
first_name=response.get(generic_user_first_name_attribute_name),
last_name=response.get(generic_user_last_name_attribute_name),
provider=response.get(generic_provider_attribute_name),
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
team_ids=all_teams,
)
async def get_generic_sso_response(
request: Request,
jwt_handler: JWTHandler,
sso_jwt_handler: Optional[
JWTHandler
], # sso specific jwt handler - used for restricted sso group access control
generic_client_id: str,
redirect_url: str,
) -> Union[OpenID, dict]:
) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response
# make generic sso provider
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
received_response: Optional[dict] = None
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
@@ -242,9 +257,12 @@ async def get_generic_sso_response(
)
def response_convertor(response, client):
nonlocal received_response # return for user debugging
received_response = response
return generic_response_convertor(
response=response,
jwt_handler=jwt_handler,
sso_jwt_handler=sso_jwt_handler,
)
SSOProvider = create_provider(
@@ -284,7 +302,7 @@ async def get_generic_sso_response(
)
raise e
verbose_proxy_logger.debug("generic result: %s", result)
return result or {}
return result or {}, received_response
async def create_team_member_add_task(team_id, user_info):
@@ -480,6 +498,8 @@ async def check_and_update_if_proxy_admin_id(
async def auth_callback(request: Request): # noqa: PLR0915
"""Verify login"""
verbose_proxy_logger.info("Starting SSO callback")
from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_helper_fn,
)
@@ -490,7 +510,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
premium_user,
prisma_client,
proxy_logging_obj,
ui_access_mode,
user_api_key_cache,
user_custom_sso,
)
@@ -502,9 +521,25 @@ async def auth_callback(request: Request): # noqa: PLR0915
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
)
sso_jwt_handler: Optional[JWTHandler] = None
ui_access_mode = general_settings.get("ui_access_mode", None)
if ui_access_mode is not None and isinstance(ui_access_mode, dict):
sso_jwt_handler = JWTHandler()
sso_jwt_handler.update_environment(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
litellm_jwtauth=LiteLLM_JWTAuth(
team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
"sso_group_jwt_field", None
),
),
leeway=0,
)
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
received_response: Optional[dict] = None
# get url from request
if master_key is None:
raise ProxyException(
@@ -532,11 +567,12 @@ async def auth_callback(request: Request): # noqa: PLR0915
redirect_url=redirect_url,
)
elif generic_client_id is not None:
result = await get_generic_sso_response(
result, received_response = await get_generic_sso_response(
request=request,
jwt_handler=jwt_handler,
generic_client_id=generic_client_id,
redirect_url=redirect_url,
sso_jwt_handler=sso_jwt_handler,
)
if result is None:
@@ -547,6 +583,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
# User is Authe'd in - generate key for the UI to access Proxy
verbose_proxy_logger.info(f"SSO callback result: {result}")
user_email: Optional[str] = getattr(result, "email", None)
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
@@ -612,6 +649,13 @@ async def auth_callback(request: Request): # noqa: PLR0915
budget_duration=internal_user_budget_duration,
)
# (IF SET) Verify user is in restricted SSO group
SSOAuthenticationHandler.verify_user_in_restricted_sso_group(
general_settings=general_settings,
result=result,
received_response=received_response,
)
user_info = await get_user_info_from_db(
result=result,
prisma_client=prisma_client,
@@ -1055,6 +1099,44 @@ class SSOAuthenticationHandler:
sso_teams = getattr(result, "team_ids", [])
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
@staticmethod
def verify_user_in_restricted_sso_group(
general_settings: Dict,
result: Optional[Union[CustomOpenID, OpenID, dict]],
received_response: Optional[dict],
) -> Literal[True]:
"""
when ui_access_mode.type == "restricted_sso_group":
- result.team_ids should contain the restricted_sso_group
- if not, raise a ProxyException
- if so, return True
- if result.team_ids is None, return False
- if result.team_ids is an empty list, return False
- if result.team_ids is a list, return True if the restricted_sso_group is in the list, otherwise return False
"""
ui_access_mode = cast(
Optional[Union[Dict, str]], general_settings.get("ui_access_mode")
)
if ui_access_mode is None:
return True
if isinstance(ui_access_mode, str):
return True
team_ids = getattr(result, "team_ids", [])
if ui_access_mode.get("type") == "restricted_sso_group":
restricted_sso_group = ui_access_mode.get("restricted_sso_group")
if restricted_sso_group not in team_ids:
raise ProxyException(
message=f"User is not in the restricted SSO group: {restricted_sso_group}. User groups: {team_ids}. Received SSO response: {received_response}",
type=ProxyErrorTypes.auth_error,
param="restricted_sso_group",
code=status.HTTP_403_FORBIDDEN,
)
return True
@staticmethod
async def create_litellm_team_from_sso_group(
litellm_team_id: str,
@@ -1551,7 +1633,29 @@ async def debug_sso_callback(request: Request):
from fastapi.responses import HTMLResponse
from litellm.proxy.proxy_server import jwt_handler
from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.proxy_server import (
general_settings,
jwt_handler,
prisma_client,
user_api_key_cache,
)
sso_jwt_handler: Optional[JWTHandler] = None
ui_access_mode = general_settings.get("ui_access_mode", None)
if ui_access_mode is not None and isinstance(ui_access_mode, dict):
sso_jwt_handler = JWTHandler()
sso_jwt_handler.update_environment(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
litellm_jwtauth=LiteLLM_JWTAuth(
team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
"sso_group_jwt_field", None
),
),
leeway=0,
)
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
@@ -1580,11 +1684,12 @@ async def debug_sso_callback(request: Request):
)
elif generic_client_id is not None:
result = await get_generic_sso_response(
result, _ = await get_generic_sso_response(
request=request,
jwt_handler=jwt_handler,
generic_client_id=generic_client_id,
redirect_url=redirect_url,
sso_jwt_handler=sso_jwt_handler,
)
# If result is None, return a basic error message
+7 -1
View File
@@ -905,7 +905,7 @@ health_check_results: Dict[str, Union[int, List[Dict[str, Any]]]] = {}
queue: List = []
litellm_proxy_budget_name = "litellm-proxy-budget"
litellm_proxy_admin_name = LITELLM_PROXY_ADMIN_NAME
ui_access_mode: Literal["admin", "all"] = "all"
ui_access_mode: Union[Literal["admin", "all"], Dict] = "all"
proxy_budget_rescheduler_min_time = PROXY_BUDGET_RESCHEDULER_MIN_TIME
proxy_budget_rescheduler_max_time = PROXY_BUDGET_RESCHEDULER_MAX_TIME
proxy_batch_write_at = PROXY_BATCH_WRITE_AT
@@ -1435,11 +1435,13 @@ class ProxyConfig:
- Do not write restricted params like 'api_key' to the database
- if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
"""
if prisma_client is not None and (
general_settings.get("store_model_in_db", False) is True
or store_model_in_db
):
# if using - db for config - models are in ModelTable
new_config.pop("model_list", None)
await prisma_client.insert_data(data=new_config, table_name="config")
else:
@@ -2625,6 +2627,10 @@ class ProxyConfig:
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
## UI ACCESS MODE ##
if "ui_access_mode" in _general_settings:
general_settings["ui_access_mode"] = _general_settings["ui_access_mode"]
def _update_config_fields(
self,
current_config: dict,
+4 -1
View File
@@ -81,7 +81,10 @@ async def route_request(
team_id = get_team_id_from_data(data)
router_model_names = llm_router.model_names if llm_router is not None else []
if "api_key" in data or "api_base" in data:
return getattr(llm_router, f"{route_type}")(**data)
if llm_router is not None:
return getattr(llm_router, f"{route_type}")(**data)
else:
return getattr(litellm, f"{route_type}")(**data)
elif "user_config" in data:
router_config = data.pop("user_config")
@@ -7,7 +7,10 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.proxy.management_endpoints.ui_sso import DefaultTeamSSOParams, SSOConfig
from litellm.types.proxy.management_endpoints.ui_sso import (
DefaultTeamSSOParams,
SSOConfig,
)
router = APIRouter()
@@ -18,26 +21,29 @@ class IPAddress(BaseModel):
class SettingsResponse(BaseModel):
"""Base response model for settings with values and schema information"""
values: Dict[str, Any]
"""The current configuration values"""
field_schema: Dict[str, Any]
"""Schema information including descriptions and property types for UI display"""
class SSOSettingsResponse(SettingsResponse):
"""Response model for SSO settings"""
pass
class InternalUserSettingsResponse(SettingsResponse):
"""Response model for internal user settings"""
pass
class DefaultTeamSettingsResponse(SettingsResponse):
"""Response model for default team settings"""
pass
@@ -166,7 +172,10 @@ async def _get_settings_with_schema(
# Add descriptions to the response
result = {
"values": settings_dict,
"field_schema": {"description": schema.get("description", ""), "properties": {}},
"field_schema": {
"description": schema.get("description", ""),
"properties": {},
},
}
# Add property descriptions
@@ -322,20 +331,21 @@ async def get_sso_settings():
Returns a structured object with values and descriptions for UI display.
"""
import os
from litellm.proxy.proxy_server import proxy_config
# Load existing config to get both environment variables and general settings
config = await proxy_config.get_config()
general_settings = config.get("general_settings", {}) or {}
environment_variables = config.get("environment_variables", {}) or {}
# Get user_email from general_settings
proxy_admin_email = general_settings.get("proxy_admin_email", None)
# Helper function to get env var value (first from config, then from environment)
def get_env_value(env_var_name: str):
return environment_variables.get(env_var_name) or os.getenv(env_var_name)
# Get current environment variables for SSO
sso_config = SSOConfig(
google_client_id=get_env_value("GOOGLE_CLIENT_ID"),
@@ -351,27 +361,31 @@ async def get_sso_settings():
proxy_base_url=get_env_value("PROXY_BASE_URL"),
user_email=proxy_admin_email, # Get from config instead of environment
)
# Get the schema for UI display
from pydantic import TypeAdapter
schema = TypeAdapter(SSOConfig).json_schema(by_alias=True)
# Convert to dict for response
sso_dict = sso_config.model_dump()
# Add descriptions to the response
result = {
"values": sso_dict,
"field_schema": {"description": schema.get("description", ""), "properties": {}},
"field_schema": {
"description": schema.get("description", ""),
"properties": {},
},
}
# Add property descriptions
for field_name, field_info in schema["properties"].items():
result["field_schema"]["properties"][field_name] = {
"description": field_info.get("description", ""),
"type": field_info.get("type", "string"),
}
return result
@@ -384,51 +398,56 @@ async def update_sso_settings(sso_config: SSOConfig):
"""
Update SSO configuration by saving to both environment variables and config file.
"""
from litellm.proxy.proxy_server import proxy_config
import os
from litellm.proxy.proxy_server import proxy_config
# Update environment variables
env_var_mapping = {
'google_client_id': 'GOOGLE_CLIENT_ID',
'google_client_secret': 'GOOGLE_CLIENT_SECRET',
'microsoft_client_id': 'MICROSOFT_CLIENT_ID',
'microsoft_client_secret': 'MICROSOFT_CLIENT_SECRET',
'microsoft_tenant': 'MICROSOFT_TENANT',
'generic_client_id': 'GENERIC_CLIENT_ID',
'generic_client_secret': 'GENERIC_CLIENT_SECRET',
'generic_authorization_endpoint': 'GENERIC_AUTHORIZATION_ENDPOINT',
'generic_token_endpoint': 'GENERIC_TOKEN_ENDPOINT',
'generic_userinfo_endpoint': 'GENERIC_USERINFO_ENDPOINT',
'proxy_base_url': 'PROXY_BASE_URL',
"google_client_id": "GOOGLE_CLIENT_ID",
"google_client_secret": "GOOGLE_CLIENT_SECRET",
"microsoft_client_id": "MICROSOFT_CLIENT_ID",
"microsoft_client_secret": "MICROSOFT_CLIENT_SECRET",
"microsoft_tenant": "MICROSOFT_TENANT",
"generic_client_id": "GENERIC_CLIENT_ID",
"generic_client_secret": "GENERIC_CLIENT_SECRET",
"generic_authorization_endpoint": "GENERIC_AUTHORIZATION_ENDPOINT",
"generic_token_endpoint": "GENERIC_TOKEN_ENDPOINT",
"generic_userinfo_endpoint": "GENERIC_USERINFO_ENDPOINT",
"proxy_base_url": "PROXY_BASE_URL",
}
# Load existing config
config = await proxy_config.get_config()
# Update config with new environment variables
if "environment_variables" not in config:
config["environment_variables"] = {}
# Update general_settings for user_email (admin email)
if "general_settings" not in config:
config["general_settings"] = {}
# Update environment variables in config and in memory
sso_data = sso_config.model_dump(exclude_none=True)
for field_name, value in sso_data.items():
if field_name == 'user_email' and value is not None:
if field_name == "user_email" and value is not None:
# Store user_email in general_settings instead of environment variables
config["general_settings"]["proxy_admin_email"] = value
elif field_name == "ui_access_mode" and value is not None:
config["general_settings"]["ui_access_mode"] = value
elif field_name in env_var_mapping and value is not None:
env_var_name = env_var_mapping[field_name]
# Update in config
config["environment_variables"][env_var_name] = value
# Update in runtime environment
os.environ[env_var_name] = value
# Save the updated config
await proxy_config.save_config(new_config=config)
return {
"message": "SSO settings updated successfully",
"status": "success",
+18 -12
View File
@@ -4108,20 +4108,26 @@ class Router:
original_exception=exception
)
_time_to_cooldown = kwargs.get("litellm_params", {}).get(
"cooldown_time", self.cooldown_time
)
# Determine cooldown time with priority: deployment config > response header > router default
deployment_cooldown = kwargs.get("litellm_params", {}).get("cooldown_time", None)
header_cooldown = None
if exception_headers is not None:
_time_to_cooldown = (
litellm.utils._get_retry_after_from_exception_header(
response_headers=exception_headers
)
header_cooldown = litellm.utils._get_retry_after_from_exception_header(
response_headers=exception_headers
)
if _time_to_cooldown is None or _time_to_cooldown < 0:
# if the response headers did not read it -> set to default cooldown time
_time_to_cooldown = self.cooldown_time
##############################################
# Logic to determine cooldown time
# 1. Check if a cooldown time is set in the deployment config
# 2. Check if a cooldown time is set in the response header
# 3. If no cooldown time is set, use the router default cooldown time
##############################################
if deployment_cooldown is not None and deployment_cooldown >= 0:
_time_to_cooldown = deployment_cooldown
elif header_cooldown is not None and header_cooldown >= 0:
_time_to_cooldown = header_cooldown
else:
_time_to_cooldown = self.cooldown_time
if isinstance(_model_info, dict):
deployment_id = _model_info.get("id", None)
@@ -1,4 +1,4 @@
from typing import List, Literal, Optional, TypedDict
from typing import List, Literal, Optional, TypedDict, Union
from pydantic import Field
@@ -31,6 +31,14 @@ class MicrosoftServicePrincipalTeam(TypedDict, total=False):
principalId: Optional[str]
class AccessControl_UI_AccessMode(LiteLLMPydanticObjectBase):
"""Model for Controlling UI Access Mode via SSO Groups"""
type: Literal["restricted_sso_group"]
restricted_sso_group: str
sso_group_jwt_field: str
class SSOConfig(LiteLLMPydanticObjectBase):
"""
Configuration for SSO environment variables and settings
@@ -45,7 +53,7 @@ class SSOConfig(LiteLLMPydanticObjectBase):
default=None,
description="Google OAuth Client Secret for SSO authentication",
)
# Microsoft SSO
microsoft_client_id: Optional[str] = Field(
default=None,
@@ -59,7 +67,7 @@ class SSOConfig(LiteLLMPydanticObjectBase):
default=None,
description="Microsoft Azure Tenant ID for SSO authentication",
)
# Generic/Okta SSO
generic_client_id: Optional[str] = Field(
default=None,
@@ -81,7 +89,7 @@ class SSOConfig(LiteLLMPydanticObjectBase):
default=None,
description="User info endpoint URL for generic OAuth provider",
)
# Common settings
proxy_base_url: Optional[str] = Field(
default=None,
@@ -92,6 +100,12 @@ class SSOConfig(LiteLLMPydanticObjectBase):
description="Email of the proxy admin user",
)
# Access Mode
ui_access_mode: Optional[Union[AccessControl_UI_AccessMode, str]] = Field(
default=None,
description="Access mode for the UI",
)
class DefaultTeamSSOParams(LiteLLMPydanticObjectBase):
"""
+80
View File
@@ -85,3 +85,83 @@ class VectorStoreSearchResponse(TypedDict, total=False):
] # Always "vector_store.search_results.page"
search_query: Optional[str]
data: Optional[List[VectorStoreSearchResult]]
class VectorStoreSearchOptionalRequestParams(TypedDict, total=False):
"""TypedDict for Optional parameters supported by the vector store search API."""
filters: Optional[Dict]
max_num_results: Optional[int]
ranking_options: Optional[Dict]
rewrite_query: Optional[bool]
class VectorStoreSearchRequest(VectorStoreSearchOptionalRequestParams, total=False):
"""Request body for searching a vector store"""
query: Union[str, List[str]]
# Vector Store Creation Types
class VectorStoreExpirationPolicy(TypedDict, total=False):
"""The expiration policy for a vector store"""
anchor: Literal["last_active_at"] # Anchor timestamp after which the expiration policy applies
days: int # Number of days after anchor time that the vector store will expire
class VectorStoreAutoChunkingStrategy(TypedDict, total=False):
"""Auto chunking strategy configuration"""
type: Literal["auto"] # Always "auto"
class VectorStoreStaticChunkingStrategyConfig(TypedDict, total=False):
"""Static chunking strategy configuration"""
max_chunk_size_tokens: int # Maximum number of tokens per chunk
chunk_overlap_tokens: int # Number of tokens to overlap between chunks
class VectorStoreStaticChunkingStrategy(TypedDict, total=False):
"""Static chunking strategy"""
type: Literal["static"] # Always "static"
static: VectorStoreStaticChunkingStrategyConfig
class VectorStoreChunkingStrategy(TypedDict, total=False):
"""Union type for chunking strategies"""
# This can be either auto or static
type: Literal["auto", "static"]
static: Optional[VectorStoreStaticChunkingStrategyConfig]
class VectorStoreFileCounts(TypedDict, total=False):
"""File counts for a vector store"""
in_progress: int
completed: int
failed: int
cancelled: int
total: int
class VectorStoreCreateOptionalRequestParams(TypedDict, total=False):
"""TypedDict for Optional parameters supported by the vector store create API."""
name: Optional[str] # Name of the vector store
file_ids: Optional[List[str]] # List of File IDs that the vector store should use
expires_after: Optional[VectorStoreExpirationPolicy] # Expiration policy for the vector store
chunking_strategy: Optional[VectorStoreChunkingStrategy] # Chunking strategy for the files
metadata: Optional[Dict[str, str]] # Set of key-value pairs for metadata
class VectorStoreCreateRequest(VectorStoreCreateOptionalRequestParams, total=False):
"""Request body for creating a vector store"""
pass # All fields are optional for vector store creation
class VectorStoreCreateResponse(TypedDict, total=False):
"""Response after creating a vector store"""
id: str # ID of the vector store
object: Literal["vector_store"] # Always "vector_store"
created_at: int # Unix timestamp of when the vector store was created
name: Optional[str] # Name of the vector store
bytes: int # Size of the vector store in bytes
file_counts: VectorStoreFileCounts # File counts for the vector store
status: Literal["expired", "in_progress", "completed"] # Status of the vector store
expires_after: Optional[VectorStoreExpirationPolicy] # Expiration policy
expires_at: Optional[int] # Unix timestamp of when the vector store expires
last_active_at: Optional[int] # Unix timestamp of when the vector store was last active
metadata: Optional[Dict[str, str]] # Metadata associated with the vector store
+25 -2
View File
@@ -78,7 +78,9 @@ from litellm.constants import (
)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.vector_stores.base_vector_store import BaseVectorStore
from litellm.integrations.vector_store_integrations.base_vector_store import (
BaseVectorStore,
)
from litellm.litellm_core_utils.core_helpers import (
map_finish_reason,
process_response_headers,
@@ -242,6 +244,7 @@ from litellm.llms.base_llm.image_variations.transformation import (
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from ._logging import _is_debugging_on, verbose_logger
from .caching.caching import (
@@ -6902,13 +6905,33 @@ class ProviderConfigManager:
def get_provider_vector_store_config(
provider: LlmProviders,
) -> Optional[CustomLogger]:
from litellm.integrations.vector_stores.bedrock_vector_store import (
from litellm.integrations.vector_store_integrations.bedrock_vector_store import (
BedrockVectorStore,
)
if LlmProviders.BEDROCK == provider:
return BedrockVectorStore.get_initialized_custom_logger()
return None
@staticmethod
def get_provider_vector_stores_config(
provider: LlmProviders,
) -> Optional[BaseVectorStoreConfig]:
"""
v2 vector store config, use this for new vector store integrations
"""
if litellm.LlmProviders.OPENAI == provider:
from litellm.llms.openai.vector_stores.transformation import (
OpenAIVectorStoreConfig,
)
return OpenAIVectorStoreConfig()
elif litellm.LlmProviders.AZURE == provider:
from litellm.llms.azure.vector_stores.transformation import (
AzureOpenAIVectorStoreConfig,
)
return AzureOpenAIVectorStoreConfig()
return None
@staticmethod
def get_provider_image_generation_config(
+4
View File
@@ -0,0 +1,4 @@
from .main import acreate, asearch, create, search
from .vector_store_registry import VectorStoreRegistry
__all__ = ["search", "asearch", "create", "acreate", "VectorStoreRegistry"]
+434
View File
@@ -0,0 +1,434 @@
"""
LiteLLM SDK Functions for Creating and Searching Vector Stores
"""
import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, List, Optional, Union
import httpx
import litellm
from litellm.constants import request_timeout
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
VectorStoreCreateOptionalRequestParams,
VectorStoreCreateResponse,
VectorStoreFileCounts,
VectorStoreResultContent,
VectorStoreSearchOptionalRequestParams,
VectorStoreSearchResponse,
VectorStoreSearchResult,
)
from litellm.utils import ProviderConfigManager, client
from litellm.vector_stores.utils import VectorStoreRequestUtils
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################
def mock_vector_store_search_response(
mock_results: Optional[List[VectorStoreSearchResult]] = None,
):
"""Mock response for vector store search"""
if mock_results is None:
mock_results = [
VectorStoreSearchResult(
score=0.95,
content=[
VectorStoreResultContent(
text="This is a sample search result from the vector store.",
type="text"
)
]
)
]
return VectorStoreSearchResponse(
object="vector_store.search_results.page",
search_query="sample query",
data=mock_results,
)
def mock_vector_store_create_response(
mock_response: Optional[VectorStoreCreateResponse] = None,
):
"""Mock response for vector store create"""
if mock_response is None:
mock_response = VectorStoreCreateResponse(
id="vs_mock123",
object="vector_store",
created_at=1699061776,
name="Mock Vector Store",
bytes=0,
file_counts=VectorStoreFileCounts(
in_progress=0,
completed=0,
failed=0,
cancelled=0,
total=0,
),
status="completed",
expires_after=None,
expires_at=None,
last_active_at=None,
metadata=None,
)
return mock_response
@client
async def acreate(
name: Optional[str] = None,
file_ids: Optional[List[str]] = None,
expires_after: Optional[Dict] = None,
chunking_strategy: Optional[Dict] = None,
metadata: Optional[Dict[str, str]] = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params,
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> VectorStoreCreateResponse:
"""
Async: Create a vector store.
"""
local_vars = locals()
try:
loop = asyncio.get_event_loop()
kwargs["acreate"] = True
# get custom llm provider so we can use this for mapping exceptions
if custom_llm_provider is None:
custom_llm_provider = "openai" # Default to OpenAI for vector stores
func = partial(
create,
name=name,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,
metadata=metadata,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
**kwargs,
)
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise litellm.exception_type(
model=None,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
@client
def create(
name: Optional[str] = None,
file_ids: Optional[List[str]] = None,
expires_after: Optional[Dict] = None,
chunking_strategy: Optional[Dict] = None,
metadata: Optional[Dict[str, str]] = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params,
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
"""
Create a vector store.
Args:
name: The name of the vector store.
file_ids: A list of File IDs that the vector store should use.
expires_after: The expiration policy for the vector store.
chunking_strategy: The chunking strategy used to chunk the file(s).
metadata: Set of 16 key-value pairs that can be attached to an object.
Returns:
VectorStoreCreateResponse containing the created vector store details.
"""
local_vars = locals()
try:
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
_is_async = kwargs.pop("acreate", False) is True
# get llm provider logic
litellm_params = GenericLiteLLMParams(**kwargs)
## MOCK RESPONSE LOGIC
if litellm_params.mock_response and isinstance(
litellm_params.mock_response, dict
):
return mock_vector_store_create_response(
mock_response=VectorStoreCreateResponse(**litellm_params.mock_response)
)
# Default to OpenAI for vector stores
if custom_llm_provider is None:
custom_llm_provider = "openai"
# get provider config - using vector store custom logger for now
vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
provider=litellm.LlmProviders(custom_llm_provider),
)
if vector_store_provider_config is None:
raise ValueError(
f"Vector store create is not supported for {custom_llm_provider}"
)
local_vars.update(kwargs)
# Get VectorStoreCreateOptionalRequestParams with only valid parameters
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams = (
VectorStoreRequestUtils.get_requested_vector_store_create_optional_param(
local_vars
)
)
# Pre Call logging
litellm_logging_obj.update_environment_variables(
model=None,
optional_params={
"name": name,
**vector_store_create_optional_params,
},
litellm_params={
"litellm_call_id": litellm_call_id,
},
custom_llm_provider=custom_llm_provider,
)
response = base_llm_http_handler.vector_store_create_handler(
vector_store_create_optional_params=vector_store_create_optional_params,
vector_store_provider_config=vector_store_provider_config,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
logging_obj=litellm_logging_obj,
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout or request_timeout,
_is_async=_is_async,
client=kwargs.get("client"),
)
return response
except Exception as e:
raise litellm.exception_type(
model=None,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
@client
async def asearch(
vector_store_id: str,
query: Union[str, List[str]],
filters: Optional[Dict] = None,
max_num_results: Optional[int] = None,
ranking_options: Optional[Dict] = None,
rewrite_query: Optional[bool] = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params,
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> VectorStoreSearchResponse:
"""
Async: Search a vector store for relevant chunks based on a query and file attributes filter.
"""
local_vars = locals()
try:
loop = asyncio.get_event_loop()
kwargs["asearch"] = True
# get custom llm provider so we can use this for mapping exceptions
if custom_llm_provider is None:
custom_llm_provider = "openai" # Default to OpenAI for vector stores
func = partial(
search,
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
**kwargs,
)
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise litellm.exception_type(
model=None,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
@client
def search(
vector_store_id: str,
query: Union[str, List[str]],
filters: Optional[Dict] = None,
max_num_results: Optional[int] = None,
ranking_options: Optional[Dict] = None,
rewrite_query: Optional[bool] = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Optional[Dict[str, Any]] = None,
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params,
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
"""
Search a vector store for relevant chunks based on a query and file attributes filter.
Args:
vector_store_id: The ID of the vector store to search.
query: A query string or array for the search.
filters: Optional filter to apply based on file attributes.
max_num_results: Maximum number of results to return (1-50, default 10).
ranking_options: Optional ranking options for search.
rewrite_query: Whether to rewrite the natural language query for vector search.
Returns:
VectorStoreSearchResponse containing the search results.
"""
local_vars = locals()
try:
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
_is_async = kwargs.pop("asearch", False) is True
# get llm provider logic
litellm_params = GenericLiteLLMParams(**kwargs)
## MOCK RESPONSE LOGIC
if litellm_params.mock_response and isinstance(
litellm_params.mock_response, (str, list)
):
mock_results = None
if isinstance(litellm_params.mock_response, list):
mock_results = litellm_params.mock_response
return mock_vector_store_search_response(mock_results=mock_results)
# Default to OpenAI for vector stores
if custom_llm_provider is None:
custom_llm_provider = "openai"
# get provider config - using vector store custom logger for now
vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
provider=litellm.LlmProviders(custom_llm_provider),
)
if vector_store_provider_config is None:
raise ValueError(
f"Vector store search is not supported for {custom_llm_provider}"
)
local_vars.update(kwargs)
# Get VectorStoreSearchOptionalRequestParams with only valid parameters
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams = (
VectorStoreRequestUtils.get_requested_vector_store_search_optional_param(
local_vars
)
)
# Pre Call logging
litellm_logging_obj.update_environment_variables(
model=None,
optional_params={
"vector_store_id": vector_store_id,
"query": query,
**vector_store_search_optional_params,
},
litellm_params={
"litellm_call_id": litellm_call_id,
"vector_store_id": vector_store_id,
},
custom_llm_provider=custom_llm_provider,
)
response = base_llm_http_handler.vector_store_search_handler(
vector_store_id=vector_store_id,
query=query,
vector_store_search_optional_params=vector_store_search_optional_params,
vector_store_provider_config=vector_store_provider_config,
custom_llm_provider=custom_llm_provider,
litellm_params=litellm_params,
logging_obj=litellm_logging_obj,
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout or request_timeout,
_is_async=_is_async,
client=kwargs.get("client"),
)
return response
except Exception as e:
raise litellm.exception_type(
model=None,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
+51
View File
@@ -0,0 +1,51 @@
from typing import Any, Dict, cast, get_type_hints
from litellm.types.vector_stores import (
VectorStoreCreateOptionalRequestParams,
VectorStoreSearchOptionalRequestParams,
)
class VectorStoreRequestUtils:
"""Helper utils for constructing Vector Store search requests"""
@staticmethod
def get_requested_vector_store_search_optional_param(
params: Dict[str, Any],
) -> VectorStoreSearchOptionalRequestParams:
"""
Filter parameters to only include those defined in VectorStoreSearchOptionalRequestParams.
Args:
params: Dictionary of parameters to filter
Returns:
VectorStoreSearchOptionalRequestParams instance with only the valid parameters
"""
valid_keys = get_type_hints(VectorStoreSearchOptionalRequestParams).keys()
filtered_params = {
k: v for k, v in params.items() if k in valid_keys and v is not None
}
return cast(VectorStoreSearchOptionalRequestParams, filtered_params)
@staticmethod
def get_requested_vector_store_create_optional_param(
params: Dict[str, Any],
) -> VectorStoreCreateOptionalRequestParams:
"""
Filter parameters to only include those defined in VectorStoreCreateOptionalRequestParams.
Args:
params: Dictionary of parameters to filter
Returns:
VectorStoreCreateOptionalRequestParams instance with only the valid parameters
"""
valid_keys = get_type_hints(VectorStoreCreateOptionalRequestParams).keys()
filtered_params = {
k: v for k, v in params.items() if k in valid_keys and v is not None
}
return cast(VectorStoreCreateOptionalRequestParams, filtered_params)
+52 -26
View File
@@ -4046,7 +4046,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-small": {
"max_tokens": 8191,
@@ -4058,7 +4059,8 @@
"supports_function_calling": true,
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-small-latest": {
"max_tokens": 8191,
@@ -4070,7 +4072,8 @@
"supports_function_calling": true,
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium": {
"max_tokens": 8191,
@@ -4081,7 +4084,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium-latest": {
"max_tokens": 8191,
@@ -4093,7 +4097,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium-2505": {
"max_tokens": 8191,
@@ -4105,7 +4110,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-medium-2312": {
"max_tokens": 8191,
@@ -4116,7 +4122,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-latest": {
"max_tokens": 128000,
@@ -4128,7 +4135,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-2411": {
"max_tokens": 128000,
@@ -4140,7 +4148,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-2402": {
"max_tokens": 8191,
@@ -4152,7 +4161,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/mistral-large-2407": {
"max_tokens": 128000,
@@ -4164,7 +4174,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/pixtral-large-latest": {
"max_tokens": 128000,
@@ -4177,7 +4188,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_vision": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/pixtral-large-2411": {
"max_tokens": 128000,
@@ -4190,7 +4202,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_vision": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/pixtral-12b-2409": {
"max_tokens": 128000,
@@ -4203,7 +4216,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_vision": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mistral-7b": {
"max_tokens": 8191,
@@ -4214,7 +4228,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mixtral-8x7b": {
"max_tokens": 8191,
@@ -4226,7 +4241,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mixtral-8x22b": {
"max_tokens": 8191,
@@ -4238,7 +4254,8 @@
"mode": "chat",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/codestral-latest": {
"max_tokens": 8191,
@@ -4249,7 +4266,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/codestral-2405": {
"max_tokens": 8191,
@@ -4260,7 +4278,8 @@
"litellm_provider": "mistral",
"mode": "chat",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mistral-nemo": {
"max_tokens": 128000,
@@ -4272,7 +4291,8 @@
"mode": "chat",
"source": "https://mistral.ai/technology/",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-mistral-nemo-2407": {
"max_tokens": 128000,
@@ -4284,7 +4304,8 @@
"mode": "chat",
"source": "https://mistral.ai/technology/",
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/open-codestral-mamba": {
"max_tokens": 256000,
@@ -4321,7 +4342,8 @@
"source": "https://mistral.ai/news/devstral",
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_response_schema": true
},
"mistral/magistral-medium-latest": {
"max_tokens": 40000,
@@ -4335,7 +4357,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/magistral-medium-2506": {
"max_tokens": 40000,
@@ -4349,7 +4372,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/magistral-small-latest": {
"max_tokens": 40000,
@@ -4363,7 +4387,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/magistral-small-2506": {
"max_tokens": 40000,
@@ -4377,7 +4402,8 @@
"supports_function_calling": true,
"supports_assistant_prefill": true,
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_response_schema": true
},
"mistral/mistral-embed": {
"max_tokens": 8192,
+8 -12
View File
@@ -1,24 +1,20 @@
import os
import json
gemini_model_cost_map = json.load(open("model_prices_and_context_window.json"))
mistral_model_cost_map = json.load(open("model_prices_and_context_window.json"))
for model, model_info in gemini_model_cost_map.items():
for model, model_info in mistral_model_cost_map.items():
if (
(
model_info.get("litellm_provider") == "gemini"
or model_info.get("litellm_provider") == "vertex_ai-language-models"
)
(model_info.get("litellm_provider") == "mistral")
and model_info.get("mode") == "chat"
and ("gemini-2.5" in model and "tts" not in model)
and model_info.get("supports_pdf_input") is None
and ("codestral-mamba" not in model)
):
"""
Update all gemini chat models to support pdf input
Update all mistral models to supports_response_schema
"""
model_info["supports_pdf_input"] = True
print(f"Updated {model} to support pdf input")
model_info["supports_response_schema"] = True
print(f"Updated {model} to support response schema")
json.dump(
gemini_model_cost_map, open("model_prices_and_context_window.json", "w"), indent=4
mistral_model_cost_map, open("model_prices_and_context_window.json", "w"), indent=4
)
+28 -1
View File
@@ -594,6 +594,7 @@ async def test_azure_embedding_max_retries_0(
def test_azure_safety_result():
"""Bubble up safety result from Azure OpenAI"""
from litellm import completion
litellm._turn_on_debug()
response = completion(
@@ -602,4 +603,30 @@ def test_azure_safety_result():
)
print(f"response: {response}")
assert response.choices[0].message.content is not None
assert response.choices[0].provider_specific_fields is not None
assert response.choices[0].provider_specific_fields is not None
def test_azure_openai_responses_bridge():
from litellm import completion
import litellm
litellm._turn_on_debug()
with patch.object(litellm, "responses") as mock_responses:
try:
response = completion(
model="azure/responses/test-azure-computer-use-preview",
messages=[{"role": "user", "content": "Hello world"}],
api_base=os.getenv("AZURE_COMPUTER_USE_API_BASE"),
api_version="2025-04-01-preview",
api_key=os.getenv("AZURE_COMPUTER_USE_API_KEY"),
)
except Exception as e:
print(e)
mock_responses.assert_called_once()
assert (
mock_responses.call_args.kwargs["model"]
== "test-azure-computer-use-preview"
)
assert mock_responses.call_args.kwargs["custom_llm_provider"] == "azure"
@@ -19,7 +19,7 @@ import pytest
import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
from litellm.integrations.vector_store_integrations.bedrock_vector_store import BedrockVectorStore
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload, StandardLoggingVectorStoreRequest
@@ -446,6 +446,40 @@ async def test_deployment_callback_on_failure(model_list):
)
def test_deployment_callback_respects_cooldown_time(model_list):
"""Ensure per-model cooldown_time is honored even when exception headers are present."""
import httpx
import time
from unittest.mock import patch
router = Router(model_list=model_list)
class FakeException(Exception):
def __init__(self):
self.status_code = 429
self.headers = httpx.Headers({"x-test": "1"})
kwargs = {
"exception": FakeException(),
"litellm_params": {
"metadata": {"model_group": "gpt-3.5-turbo"},
"model_info": {"id": 100},
"cooldown_time": 0,
},
}
with patch("litellm.router._set_cooldown_deployments") as mock_set:
router.deployment_callback_on_failure(
kwargs=kwargs,
completion_response=None,
start_time=time.time(),
end_time=time.time(),
)
mock_set.assert_called_once()
assert mock_set.call_args.kwargs["time_to_cooldown"] == 0
def test_log_retry(model_list):
"""Test if the '_log_retry' function is working correctly"""
import time
@@ -34,3 +34,33 @@ def test_anthropic_experimental_pass_through_messages_handler():
print(f"Error: {e}")
mock_completion.assert_called_once()
mock_completion.call_args.kwargs["api_key"] == "test-api-key"
def test_anthropic_experimental_pass_through_messages_handler_custom_llm_provider():
"""
Test that litellm.completion is called when a custom LLM provider is given
"""
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
anthropic_messages_handler,
)
with patch("litellm.completion", return_value="test-response") as mock_completion:
try:
anthropic_messages_handler(
max_tokens=100,
messages=[{"role": "user", "content": "Hello, how are you?"}],
model="my-custom-model",
custom_llm_provider="my-custom-llm",
api_key="test-api-key",
)
except Exception as e:
print(f"Error: {e}")
# Assert that litellm.completion was called when using a custom LLM provider
mock_completion.assert_called_once()
# Verify that the custom provider was passed through
call_kwargs = mock_completion.call_args.kwargs
assert call_kwargs["custom_llm_provider"] == "my-custom-llm"
assert call_kwargs["model"] == "my-custom-llm/my-custom-model"
assert call_kwargs["api_key"] == "test-api-key"
@@ -54,9 +54,9 @@ def test_validate_environment_azure_key_within_litellm():
def test_validate_environment_azure_openai_api_key_within_secret_str():
azure_openai_responses_apiconfig = AzureOpenAIResponsesAPIConfig()
with patch(
"litellm.llms.azure.responses.transformation.get_secret_str"
) as mock_get_secret_str:
with patch("litellm.api_key", None), \
patch("litellm.azure_key", None), \
patch("litellm.llms.azure.common_utils.get_secret_str") as mock_get_secret_str:
# Configure the mock to return "test-api-key" when called with "AZURE_OPENAI_API_KEY"
mock_get_secret_str.side_effect = (
lambda key: "test-api-key" if key == "AZURE_OPENAI_API_KEY" else None
@@ -74,13 +74,19 @@ def test_validate_environment_azure_openai_api_key_within_secret_str():
def test_validate_environment_azure_api_key_within_secret_str():
azure_openai_responses_apiconfig = AzureOpenAIResponsesAPIConfig()
with patch(
"litellm.llms.azure.responses.transformation.get_secret_str"
) as mock_get_secret_str:
# Configure the mock to return "test-api-key" when called with "AZURE_API_KEY"
mock_get_secret_str.side_effect = (
lambda key: "test-api-key" if key == "AZURE_API_KEY" else None
)
with patch("litellm.api_key", None), \
patch("litellm.azure_key", None), \
patch("litellm.llms.azure.common_utils.get_secret_str") as mock_get_secret_str:
# Configure the mock to return None for "AZURE_OPENAI_API_KEY" and "test-api-key" for "AZURE_API_KEY"
def mock_side_effect(key):
if key == "AZURE_OPENAI_API_KEY":
return None
elif key == "AZURE_API_KEY":
return "test-api-key"
else:
return None
mock_get_secret_str.side_effect = mock_side_effect
litellm_params = GenericLiteLLMParams()
result = azure_openai_responses_apiconfig.validate_environment(
@@ -95,10 +95,6 @@ class TestMistralReasoningSupport:
def test_get_mistral_reasoning_system_prompt(self):
"""Test that the reasoning system prompt is properly formatted."""
prompt = MistralConfig._get_mistral_reasoning_system_prompt()
assert "<think>" in prompt
assert "</think>" in prompt
assert "step-by-step" in prompt
assert isinstance(prompt, str)
assert len(prompt) > 50 # Ensure it's not empty
@@ -1,8 +1,9 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from litellm.types.utils import EmbeddingResponse
from litellm.llms.ollama.completion.handler import ollama_embeddings, ollama_aembeddings
import pytest
from litellm.llms.ollama.completion.handler import ollama_aembeddings, ollama_embeddings
from litellm.types.utils import EmbeddingResponse
@pytest.fixture
@@ -26,8 +27,9 @@ def mock_encoding():
def test_ollama_embeddings(mock_response_data, mock_embedding_response, mock_encoding):
with patch("litellm.module_level_client.post") as mock_post, \
patch("litellm.OllamaConfig.get_config", return_value={"truncate": 512}):
with patch("litellm.module_level_client.post") as mock_post, patch(
"litellm.OllamaConfig.get_config", return_value={"truncate": 512}
):
mock_response = MagicMock()
mock_response.json.return_value = mock_response_data
@@ -50,11 +52,17 @@ def test_ollama_embeddings(mock_response_data, mock_embedding_response, mock_enc
@pytest.mark.asyncio
async def test_ollama_aembeddings(mock_response_data, mock_embedding_response, mock_encoding):
with patch("litellm.module_level_aclient.post", new_callable=AsyncMock) as mock_post, \
patch("litellm.OllamaConfig.get_config", return_value={"truncate": 512}):
mock_post.return_value.json.return_value = mock_response_data
async def test_ollama_aembeddings(
mock_response_data, mock_embedding_response, mock_encoding
):
mock_response = AsyncMock()
# Make json() a regular synchronous method, not async
mock_response.json = MagicMock(return_value=mock_response_data)
with patch(
"litellm.module_level_aclient.post", return_value=mock_response
) as mock_post, patch(
"litellm.OllamaConfig.get_config", return_value={"truncate": 512}
):
response = await ollama_aembeddings(
api_base="http://localhost:11434",
@@ -78,9 +86,10 @@ def test_prompt_eval_fallback_when_missing(mock_embedding_response, mock_encodin
# No "prompt_eval_count"
}
with patch("litellm.module_level_client.post") as mock_post, \
patch("litellm.OllamaConfig.get_config", return_value={}):
with patch("litellm.module_level_client.post") as mock_post, patch(
"litellm.OllamaConfig.get_config", return_value={}
):
mock_response = MagicMock()
mock_response.json.return_value = response_data
mock_post.return_value = mock_response
@@ -99,4 +108,4 @@ def test_prompt_eval_fallback_when_missing(mock_embedding_response, mock_encodin
assert response.usage.prompt_tokens == 5
assert response.usage.total_tokens == 5
assert response.usage.completion_tokens == 0
assert response.data[0]['embedding'] == [0.1, 0.2, 0.3]
assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
+47 -2
View File
@@ -1,5 +1,6 @@
import os
import sys
from unittest.mock import MagicMock, patch
from pydantic import BaseModel
@@ -23,7 +24,9 @@ class TestVolcEngineConfig:
)
assert mapped_params == {
"thinking": {"type": "disabled"},
"extra_body": {
"thinking": {"type": "disabled"},
}
}
e2e_mapped_params = get_optional_params(
@@ -33,6 +36,48 @@ class TestVolcEngineConfig:
drop_params=False,
)
assert "thinking" in e2e_mapped_params and e2e_mapped_params["thinking"] == {
assert "thinking" in e2e_mapped_params["extra_body"] and e2e_mapped_params[
"extra_body"
]["thinking"] == {
"type": "enabled",
}
def test_e2e_completion(self):
from openai import OpenAI
from litellm import completion
from litellm.types.utils import ModelResponse
client = OpenAI(api_key="test_api_key")
mock_raw_response = MagicMock()
mock_raw_response.headers = {
"x-request-id": "123",
"openai-organization": "org-123",
"x-ratelimit-limit-requests": "100",
"x-ratelimit-remaining-requests": "99",
}
mock_raw_response.parse.return_value = ModelResponse()
with patch.object(
client.chat.completions.with_raw_response, "create", mock_raw_response
) as mock_create:
completion(
model="volcengine/doubao-seed-1.6",
messages=[
{
"role": "system",
"content": "**Tell me your model detail information.**",
}
],
user="guest",
stream=True,
thinking={"type": "disabled"},
client=client,
)
mock_create.assert_called_once()
print(mock_create.call_args.kwargs)
assert mock_create.call_args.kwargs["extra_body"] == {
"thinking": {"type": "disabled"},
}
@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock
import pytest
from fastapi import HTTPException
from litellm.proxy._types import NewUserRequest, ProxyException
from litellm.proxy._types import LitellmUserRoles, NewUserRequest, ProxyException
from litellm.proxy.management_endpoints.scim.scim_v2 import (
UserProvisionerHelpers,
_handle_team_membership_changes,
@@ -58,6 +58,94 @@ async def test_create_user_existing_user_conflict(mocker):
mocked_new_user.assert_not_called()
@pytest.mark.asyncio
async def test_create_user_defaults_to_viewer(mocker, monkeypatch):
"""If no role provided, new user should default to viewer"""
scim_user = SCIMUser(
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
userName="new-user",
name=SCIMUserName(familyName="User", givenName="New"),
emails=[SCIMUserEmail(value="new@example.com")],
)
mock_prisma_client = mocker.MagicMock()
mock_prisma_client.db = mocker.MagicMock()
mock_prisma_client.db.litellm_usertable = mocker.MagicMock()
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
mock_prisma_client.db.litellm_usertable.find_first = AsyncMock(return_value=None)
monkeypatch.setattr(
"litellm.default_internal_user_params", None, raising=False
)
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2._get_prisma_client_or_raise_exception",
AsyncMock(return_value=mock_prisma_client),
)
new_user_mock = mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2.new_user",
AsyncMock(return_value=NewUserRequest(user_id="new-user")),
)
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2.ScimTransformations.transform_litellm_user_to_scim_user",
AsyncMock(return_value=scim_user),
)
await create_user(user=scim_user)
called_args = new_user_mock.call_args.kwargs["data"]
assert called_args.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
@pytest.mark.asyncio
async def test_create_user_uses_default_internal_user_params_role(mocker, monkeypatch):
"""If role is set in default_internal_user_params, new user should use that role"""
scim_user = SCIMUser(
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
userName="new-user",
name=SCIMUserName(familyName="User", givenName="New"),
emails=[SCIMUserEmail(value="new@example.com")],
)
mock_prisma_client = mocker.MagicMock()
mock_prisma_client.db = mocker.MagicMock()
mock_prisma_client.db.litellm_usertable = mocker.MagicMock()
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
mock_prisma_client.db.litellm_usertable.find_first = AsyncMock(return_value=None)
# Set default_internal_user_params with a specific role
default_params = {
"user_role": LitellmUserRoles.PROXY_ADMIN,
}
monkeypatch.setattr(
"litellm.default_internal_user_params", default_params, raising=False
)
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2._get_prisma_client_or_raise_exception",
AsyncMock(return_value=mock_prisma_client),
)
new_user_mock = mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2.new_user",
AsyncMock(return_value=NewUserRequest(user_id="new-user")),
)
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2.ScimTransformations.transform_litellm_user_to_scim_user",
AsyncMock(return_value=scim_user),
)
await create_user(user=scim_user)
called_args = new_user_mock.call_args.kwargs["data"]
assert called_args.user_role == LitellmUserRoles.PROXY_ADMIN
@pytest.mark.asyncio
async def test_handle_existing_user_by_email_no_email(mocker):
"""Should return None when new_user_request has no email"""
@@ -696,11 +696,12 @@ async def test_get_generic_sso_response_with_additional_headers():
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
) as mock_create_provider:
# Act
result = await get_generic_sso_response(
result, received_response = await get_generic_sso_response(
request=mock_request,
jwt_handler=mock_jwt_handler,
generic_client_id=generic_client_id,
redirect_url=redirect_url,
sso_jwt_handler=None,
)
# Assert
@@ -756,11 +757,12 @@ async def test_get_generic_sso_response_with_empty_headers():
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
) as mock_create_provider:
# Act
result = await get_generic_sso_response(
result, received_response = await get_generic_sso_response(
request=mock_request,
jwt_handler=mock_jwt_handler,
generic_client_id=generic_client_id,
redirect_url=redirect_url,
sso_jwt_handler=None,
)
# Assert
@@ -102,3 +102,25 @@ async def test_route_request_no_model_required_with_router_settings():
# Reset the mock for the next route
llm_router.reset_mock()
@pytest.mark.asyncio
async def test_route_request_no_model_required_with_router_settings_and_no_router():
"""Test route types that don't require model parameter with router settings and no router"""
from unittest.mock import patch
import litellm
from litellm.proxy.route_llm_request import route_request
data = {
"model": "my-model-id",
"api_key": "my-api-key",
"messages": [{"role": "user", "content": "what llm are you"}],
}
with patch.object(
litellm, "acompletion", return_value="fake_response"
) as mock_completion:
response = await route_request(data, None, "gpt-3.5-turbo", "acompletion")
mock_completion.assert_called_once_with(**data)
@@ -0,0 +1,255 @@
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
import uuid
import time
import base64
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from abc import ABC, abstractmethod
from litellm.integrations.custom_logger import CustomLogger
import json
from litellm.types.utils import StandardLoggingPayload
class BaseVectorStoreTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_base_request_args(self) -> dict:
"""Must return the base request args"""
pass
@abstractmethod
def get_base_create_vector_store_args(self) -> dict:
"""Must return the base create vector store args"""
pass
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_basic_search_vector_store(self, sync_mode):
litellm._turn_on_debug()
litellm.set_verbose = True
base_request_args = self.get_base_request_args()
try:
if sync_mode:
response = litellm.vector_stores.search(
query="Basic ping",
**base_request_args
)
else:
response = await litellm.vector_stores.asearch(
query="Basic ping",
**base_request_args
)
except litellm.InternalServerError:
pytest.skip("Skipping test due to litellm.InternalServerError")
print("litellm response=", json.dumps(response, indent=4, default=str))
# Validate response structure
self._validate_vector_store_response(response)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_basic_create_vector_store(self, sync_mode):
litellm._turn_on_debug()
litellm.set_verbose = True
base_request_args = self.get_base_create_vector_store_args()
# Extract custom_llm_provider from base args if present
create_args = base_request_args
try:
if sync_mode:
response = litellm.vector_stores.create(
name="Test Vector Store",
**create_args
)
else:
response = await litellm.vector_stores.acreate(
name="Test Vector Store",
**create_args
)
except litellm.InternalServerError:
pytest.skip("Skipping test due to litellm.InternalServerError")
except Exception as e:
# If this is an authentication or permission error, skip the test
if "authentication" in str(e).lower() or "permission" in str(e).lower() or "unauthorized" in str(e).lower():
pytest.skip(f"Skipping test due to authentication/permission error: {e}")
raise
print("litellm create response=", json.dumps(response, indent=4, default=str))
# Validate response structure
self._validate_vector_store_create_response(response)
def _validate_vector_store_response(self, response):
"""Validate the structure and content of a vector store search response"""
# Check that response is a dictionary
assert isinstance(response, dict), f"Response should be a dict, got {type(response)}"
# Check required top-level fields
required_fields = ['object', 'search_query', 'data']
for field in required_fields:
assert field in response, f"Missing required field '{field}' in response"
# Validate object field
assert response['object'] == 'vector_store.search_results.page', \
f"Expected object to be 'vector_store.search_results.page', got '{response['object']}'"
# Validate search_query field
assert isinstance(response['search_query'], list), \
f"search_query should be a list, got {type(response['search_query'])}"
assert len(response['search_query']) > 0, "search_query should not be empty"
assert all(isinstance(query, str) for query in response['search_query']), \
"All items in search_query should be strings"
# Validate data field
assert isinstance(response['data'], list), \
f"data should be a list, got {type(response['data'])}"
# Validate each result in data
for i, result in enumerate(response['data']):
self._validate_search_result(result, i)
print(f"✅ Response validation passed: Found {len(response['data'])} search results")
def _validate_vector_store_create_response(self, response):
"""Validate the structure and content of a vector store create response"""
# Check that response is a dictionary
assert isinstance(response, dict), f"Response should be a dict, got {type(response)}"
# Check required top-level fields for create response
required_fields = ['id', 'object', 'created_at']
for field in required_fields:
assert field in response, f"Missing required field '{field}' in create response"
# Validate object field
assert response['object'] == 'vector_store', \
f"Expected object to be 'vector_store', got '{response['object']}'"
# Validate id field
assert isinstance(response['id'], str), \
f"id should be a string, got {type(response['id'])}"
assert len(response['id']) > 0, "id should not be empty"
assert response['id'].startswith('vs_'), \
f"id should start with 'vs_', got '{response['id']}'"
# Validate created_at field
assert isinstance(response['created_at'], int), \
f"created_at should be an integer, got {type(response['created_at'])}"
assert response['created_at'] > 0, "created_at should be a positive timestamp"
# Validate optional fields if present
if 'name' in response:
assert isinstance(response['name'], str), \
f"name should be a string, got {type(response['name'])}"
if 'bytes' in response:
assert isinstance(response['bytes'], int), \
f"bytes should be an integer, got {type(response['bytes'])}"
assert response['bytes'] >= 0, "bytes should be non-negative"
if 'file_counts' in response:
self._validate_file_counts(response['file_counts'])
if 'status' in response:
valid_statuses = ['expired', 'in_progress', 'completed']
assert response['status'] in valid_statuses, \
f"status should be one of {valid_statuses}, got '{response['status']}'"
if 'expires_at' in response and response['expires_at'] is not None:
assert isinstance(response['expires_at'], int), \
f"expires_at should be an integer, got {type(response['expires_at'])}"
if 'last_active_at' in response and response['last_active_at'] is not None:
assert isinstance(response['last_active_at'], int), \
f"last_active_at should be an integer, got {type(response['last_active_at'])}"
if 'metadata' in response and response['metadata'] is not None:
assert isinstance(response['metadata'], dict), \
f"metadata should be a dict, got {type(response['metadata'])}"
print(f"✅ Create response validation passed: Vector store '{response['id']}' created successfully")
def _validate_file_counts(self, file_counts):
"""Validate file_counts structure"""
assert isinstance(file_counts, dict), \
f"file_counts should be a dict, got {type(file_counts)}"
required_count_fields = ['in_progress', 'completed', 'failed', 'cancelled', 'total']
for field in required_count_fields:
assert field in file_counts, f"Missing required field '{field}' in file_counts"
assert isinstance(file_counts[field], int), \
f"{field} should be an integer, got {type(file_counts[field])}"
assert file_counts[field] >= 0, f"{field} should be non-negative"
# Validate that total equals sum of other counts
calculated_total = (
file_counts['in_progress'] +
file_counts['completed'] +
file_counts['failed'] +
file_counts['cancelled']
)
assert file_counts['total'] == calculated_total, \
f"total should equal sum of other counts ({calculated_total}), got {file_counts['total']}"
def _validate_search_result(self, result, index):
"""Validate an individual search result"""
# Check that result is a dictionary
assert isinstance(result, dict), f"Result {index} should be a dict, got {type(result)}"
# Check required fields in each result
required_result_fields = ['file_id', 'filename', 'score', 'attributes', 'content']
for field in required_result_fields:
assert field in result, f"Missing required field '{field}' in result {index}"
# Validate file_id
assert isinstance(result['file_id'], str), \
f"file_id should be a string, got {type(result['file_id'])} in result {index}"
assert len(result['file_id']) > 0, f"file_id should not be empty in result {index}"
# Validate filename
assert isinstance(result['filename'], str), \
f"filename should be a string, got {type(result['filename'])} in result {index}"
assert len(result['filename']) > 0, f"filename should not be empty in result {index}"
# Validate score
assert isinstance(result['score'], (int, float)), \
f"score should be a number, got {type(result['score'])} in result {index}"
assert 0.0 <= result['score'] <= 1.0, \
f"score should be between 0.0 and 1.0, got {result['score']} in result {index}"
# Validate attributes
assert isinstance(result['attributes'], dict), \
f"attributes should be a dict, got {type(result['attributes'])} in result {index}"
# Validate content
assert isinstance(result['content'], list), \
f"content should be a list, got {type(result['content'])} in result {index}"
assert len(result['content']) > 0, f"content should not be empty in result {index}"
# Validate each content item
for j, content_item in enumerate(result['content']):
assert isinstance(content_item, dict), \
f"Content item {j} in result {index} should be a dict, got {type(content_item)}"
assert 'type' in content_item, \
f"Content item {j} in result {index} missing 'type' field"
assert 'text' in content_item, \
f"Content item {j} in result {index} missing 'text' field"
assert isinstance(content_item['text'], str), \
f"Content text should be a string in item {j} of result {index}"
assert len(content_item['text']) > 0, \
f"Content text should not be empty in item {j} of result {index}"
print(f"✅ Result {index} validation passed: {result['filename']} (score: {result['score']:.4f})")
+63
View File
@@ -0,0 +1,63 @@
# conftest.py
import importlib
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
"""
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
"""
curr_dir = os.getcwd() # Get the current working directory
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
import litellm
from litellm import Router
importlib.reload(litellm)
try:
if hasattr(litellm, "proxy") and hasattr(litellm.proxy, "proxy_server"):
import litellm.proxy.proxy_server
importlib.reload(litellm.proxy.proxy_server)
except Exception as e:
print(f"Error reloading litellm.proxy.proxy_server: {e}")
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
asyncio.set_event_loop(loop)
print(litellm)
# from litellm import Router, completion, aembedding, acompletion, embedding
yield
# Teardown code (executes after the yield point)
loop.close() # Close the loop created earlier
asyncio.set_event_loop(None) # Remove the reference to the loop
def pytest_collection_modifyitems(config, items):
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
custom_logger_tests = [
item for item in items if "custom_logger" in item.parent.name
]
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
# Sort tests based on their names
custom_logger_tests.sort(key=lambda x: x.name)
other_tests.sort(key=lambda x: x.name)
# Reorder the items list
items[:] = custom_logger_tests + other_tests
@@ -0,0 +1,25 @@
from base_vector_store_test import BaseVectorStoreTest
import os
import pytest
class TestAzureOpenAIVectorStore(BaseVectorStoreTest):
def get_base_request_args(self) -> dict:
"""Must return the base request args"""
return {}
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_basic_search_vector_store(self, sync_mode):
pass
def get_base_create_vector_store_args(self) -> dict:
"""
This is a real vector store on Azure
"""
return {
"custom_llm_provider": "azure",
"api_base": os.getenv("AZURE_RESPONSES_OPENAI_ENDPOINT"),
"api_key": os.getenv("AZURE_RESPONSES_OPENAI_API_KEY"),
"api_version": "2025-04-01-preview",
}
@@ -0,0 +1,20 @@
from base_vector_store_test import BaseVectorStoreTest
class TestOpenAIVectorStore(BaseVectorStoreTest):
def get_base_request_args(self) -> dict:
"""
This is a real vector store on OpenAI
"""
return {
"vector_store_id": "vs_685b14b1a1b88191bc27e04f1917fddd",
"custom_llm_provider": "openai",
}
def get_base_create_vector_store_args(self) -> dict:
"""
This is a real vector store on OpenAI
"""
return {
"custom_llm_provider": "openai",
}
@@ -132,7 +132,7 @@ const SSOModals: React.FC<SSOModalsProps> = ({
}
}
// Set form values with existing data
// Set form values with existing data (excluding UI access control fields)
const formValues = {
sso_provider: selectedProvider,
proxy_base_url: ssoData.values.proxy_base_url,
@@ -0,0 +1,148 @@
import React, { useEffect, useState } from "react";
import { Form, Button as Button2, Select, message } from "antd";
import { Text, TextInput } from "@tremor/react";
import { getSSOSettings, updateSSOSettings } from "./networking";
interface UIAccessControlFormProps {
accessToken: string | null;
onSuccess: () => void;
}
// Separate UI Access Control Form Component
const UIAccessControlForm: React.FC<UIAccessControlFormProps> = ({ accessToken, onSuccess }) => {
const [form] = Form.useForm();
const [loading, setLoading] = useState(false);
// Load existing UI access control settings
useEffect(() => {
const loadUIAccessSettings = async () => {
if (accessToken) {
try {
const ssoData = await getSSOSettings(accessToken);
if (ssoData && ssoData.values) {
// Handle nested ui_access_mode structure
const uiAccessMode = ssoData.values.ui_access_mode;
let formValues = {};
if (uiAccessMode && typeof uiAccessMode === 'object') {
formValues = {
ui_access_mode_type: uiAccessMode.type,
restricted_sso_group: uiAccessMode.restricted_sso_group,
sso_group_jwt_field: uiAccessMode.sso_group_jwt_field,
};
} else if (typeof uiAccessMode === 'string') {
// Handle legacy flat structure
formValues = {
ui_access_mode_type: uiAccessMode,
restricted_sso_group: ssoData.values.restricted_sso_group,
sso_group_jwt_field: ssoData.values.team_ids_jwt_field || ssoData.values.sso_group_jwt_field,
};
}
form.setFieldsValue(formValues);
}
} catch (error) {
console.error("Failed to load UI access settings:", error);
}
}
};
loadUIAccessSettings();
}, [accessToken, form]);
const handleUIAccessSubmit = async (formValues: Record<string, any>) => {
if (!accessToken) {
message.error("No access token available");
return;
}
setLoading(true);
try {
// Transform form data to match API expected structure
const apiPayload = {
ui_access_mode: {
type: formValues.ui_access_mode_type,
restricted_sso_group: formValues.restricted_sso_group,
sso_group_jwt_field: formValues.sso_group_jwt_field,
}
};
await updateSSOSettings(accessToken, apiPayload);
onSuccess();
} catch (error) {
console.error("Failed to save UI access settings:", error);
message.error("Failed to save UI access settings");
} finally {
setLoading(false);
}
};
return (
<div style={{ padding: '16px' }}>
<div style={{ marginBottom: '16px' }}>
<Text style={{ fontSize: '14px', color: '#6b7280' }}>
Configure who can access the UI interface and how group information is extracted from JWT tokens.
</Text>
</div>
<Form
form={form}
onFinish={handleUIAccessSubmit}
layout="vertical"
>
<Form.Item
label="UI Access Mode"
name="ui_access_mode_type"
tooltip="Controls who can access the UI interface"
>
<Select placeholder="Select access mode">
<Select.Option value="all_authenticated_users">All Authenticated Users</Select.Option>
<Select.Option value="restricted_sso_group">Restricted SSO Group</Select.Option>
</Select>
</Form.Item>
<Form.Item
noStyle
shouldUpdate={(prevValues, currentValues) => prevValues.ui_access_mode_type !== currentValues.ui_access_mode_type}
>
{({ getFieldValue }) => {
const uiAccessModeType = getFieldValue('ui_access_mode_type');
return uiAccessModeType === 'restricted_sso_group' ? (
<Form.Item
label="Restricted SSO Group"
name="restricted_sso_group"
rules={[{ required: true, message: "Please enter the restricted SSO group" }]}
>
<TextInput placeholder="ui-access-group" />
</Form.Item>
) : null;
}}
</Form.Item>
<Form.Item
label="SSO Group JWT Field"
name="sso_group_jwt_field"
tooltip="JWT field name that contains team/group information. Use dot notation to access nested fields."
>
<TextInput placeholder="groups" />
</Form.Item>
<div style={{ textAlign: "right", marginTop: "16px" }}>
<Button2
type="primary"
htmlType="submit"
loading={loading}
style={{
backgroundColor: '#6366f1',
borderColor: '#6366f1'
}}
>
Update UI Access Control
</Button2>
</div>
</Form>
</div>
);
};
export default UIAccessControlForm;
@@ -44,6 +44,7 @@ import { InvitationLink } from "./onboarding_link";
import SSOModals from "./SSOModals";
import { ssoProviderConfigs } from './SSOModals';
import SCIMConfig from "./SCIM";
import UIAccessControlForm from "./UIAccessControlForm";
interface AdminPanelProps {
searchParams: any;
@@ -97,6 +98,7 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
const [isAllowedIPModalVisible, setIsAllowedIPModalVisible] = useState(false);
const [isAddIPModalVisible, setIsAddIPModalVisible] = useState(false);
const [isDeleteIPModalVisible, setIsDeleteIPModalVisible] = useState(false);
const [isUIAccessControlModalVisible, setIsUIAccessControlModalVisible] = useState(false);
const [allowedIPs, setAllowedIPs] = useState<string[]>([]);
const [ipToDelete, setIPToDelete] = useState<string | null>(null);
const [ssoConfigured, setSsoConfigured] = useState<boolean>(false);
@@ -532,6 +534,14 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
}
};
const handleUIAccessControlOk = () => {
setIsUIAccessControlModalVisible(false);
};
const handleUIAccessControlCancel = () => {
setIsUIAccessControlModalVisible(false);
};
console.log(`admins: ${admins?.length}`);
return (
<div className="w-full m-2 mt-2 p-8">
@@ -563,6 +573,14 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
Allowed IPs
</Button>
</div>
<div>
<Button
style={{ width: '150px' }}
onClick={() => premiumUser === true ? setIsUIAccessControlModalVisible(true) : message.error("Only premium users can configure UI access control")}
>
UI Access Control
</Button>
</div>
</div>
</Card>
@@ -654,6 +672,24 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
>
<p>Are you sure you want to delete the IP address: {ipToDelete}?</p>
</Modal>
{/* UI Access Control Modal */}
<Modal
title="UI Access Control Settings"
visible={isUIAccessControlModalVisible}
width={600}
footer={null}
onOk={handleUIAccessControlOk}
onCancel={handleUIAccessControlCancel}
>
<UIAccessControlForm
accessToken={accessToken}
onSuccess={() => {
handleUIAccessControlOk();
message.success("UI Access Control settings updated successfully");
}}
/>
</Modal>
</div>
<Callout title="Login without SSO" color="teal">
If you need to login without sso, you can access{" "}
@@ -260,6 +260,10 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
updateData.team_member_budget = Number(values.team_member_budget);
}
if (values.team_member_key_duration !== undefined) {
updateData.team_member_key_duration = values.team_member_key_duration;
}
// Handle object_permission updates
if (values.vector_stores !== undefined || values.mcp_servers !== undefined) {
updateData.object_permission = {
@@ -453,6 +457,15 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
<NumericalInput step={0.01} precision={2} style={{ width: "100%" }} />
</Form.Item>
<Form.Item label="Team Member Key Duration" name="team_member_key_duration" tooltip="Set a limit to the duration of a team member's key.">
<Select placeholder="n/a">
<Select.Option value="1d">1 day</Select.Option>
<Select.Option value="1w">1 week</Select.Option>
<Select.Option value="1mo">1 month</Select.Option>
</Select>
</Form.Item>
<Form.Item label="Reset Budget" name="budget_duration">
<Select placeholder="n/a">
<Select.Option value="24h">daily</Select.Option>
@@ -561,9 +574,19 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
<div>RPM: {info.rpm_limit || 'Unlimited'}</div>
</div>
<div>
<Text className="font-medium">Budget</Text>
<div>Max: {info.max_budget !== null ? `$${info.max_budget}` : 'No Limit'}</div>
<div>Reset: {info.budget_duration || 'Never'}</div>
<Text className="font-medium">Team Budget</Text>
<div>Max Budget: {info.max_budget !== null ? `$${info.max_budget}` : 'No Limit'}</div>
<div>Budget Reset: {info.budget_duration || 'Never'}</div>
</div>
<div>
<Text className="font-medium">
Team Member Settings{' '}
<Tooltip title="These are limits on individual team members">
<InfoCircleOutlined style={{ marginLeft: '4px' }} />
</Tooltip>
</Text>
<div>Max Budget: {info.team_member_budget_table?.max_budget || 'No Limit'}</div>
<div>Key Duration: {info.metadata?.team_member_key_duration || 'No Limit'}</div>
</div>
<div>
<Text className="font-medium">Organization ID</Text>
@@ -1060,6 +1060,17 @@ const Teams: React.FC<TeamProps> = ({
>
<NumericalInput step={0.01} precision={2} width={200} />
</Form.Item>
<Form.Item
label="Team Member Key Duration"
name="team_member_key_duration"
tooltip="Set a limit to the duration of a team member's key."
>
<Select2 defaultValue={null} placeholder="n/a">
<Select2.Option value="1d">1 day</Select2.Option>
<Select2.Option value="1w">1 week</Select2.Option>
<Select2.Option value="1mo">1 month</Select2.Option>
</Select2>
</Form.Item>
<Form.Item label="Metadata" name="metadata" help="Additional team metadata. Enter metadata as JSON object.">
<Input.TextArea rows={4} />
</Form.Item>