mirror of
https://github.com/tiennm99/litellm.git
synced 2026-06-26 03:08:43 +00:00
Merge branch 'BerriAI:main' into fix-today-selector-date-mutation-bug
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Installation
|
||||
- `make install-dev` - Install core development dependencies
|
||||
- `make install-proxy-dev` - Install proxy development dependencies with full feature set
|
||||
- `make install-test-deps` - Install all test dependencies
|
||||
|
||||
### Testing
|
||||
- `make test` - Run all tests
|
||||
- `make test-unit` - Run unit tests (tests/test_litellm) with 4 parallel workers
|
||||
- `make test-integration` - Run integration tests (excludes unit tests)
|
||||
- `pytest tests/` - Direct pytest execution
|
||||
|
||||
### Code Quality
|
||||
- `make lint` - Run all linting (Ruff, MyPy, Black, circular imports, import safety)
|
||||
- `make format` - Apply Black code formatting
|
||||
- `make lint-ruff` - Run Ruff linting only
|
||||
- `make lint-mypy` - Run MyPy type checking only
|
||||
|
||||
### Single Test Files
|
||||
- `poetry run pytest tests/path/to/test_file.py -v` - Run specific test file
|
||||
- `poetry run pytest tests/path/to/test_file.py::test_function -v` - Run specific test
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
LiteLLM is a unified interface for 100+ LLM providers with two main components:
|
||||
|
||||
### Core Library (`litellm/`)
|
||||
- **Main entry point**: `litellm/main.py` - Contains core completion() function
|
||||
- **Provider implementations**: `litellm/llms/` - Each provider has its own subdirectory
|
||||
- **Router system**: `litellm/router.py` + `litellm/router_utils/` - Load balancing and fallback logic
|
||||
- **Type definitions**: `litellm/types/` - Pydantic models and type hints
|
||||
- **Integrations**: `litellm/integrations/` - Third-party observability, caching, logging
|
||||
- **Caching**: `litellm/caching/` - Multiple cache backends (Redis, in-memory, S3, etc.)
|
||||
|
||||
### Proxy Server (`litellm/proxy/`)
|
||||
- **Main server**: `proxy_server.py` - FastAPI application
|
||||
- **Authentication**: `auth/` - API key management, JWT, OAuth2
|
||||
- **Database**: `db/` - Prisma ORM with PostgreSQL/SQLite support
|
||||
- **Management endpoints**: `management_endpoints/` - Admin APIs for keys, teams, models
|
||||
- **Pass-through endpoints**: `pass_through_endpoints/` - Provider-specific API forwarding
|
||||
- **Guardrails**: `guardrails/` - Safety and content filtering hooks
|
||||
- **UI Dashboard**: Served from `_experimental/out/` (Next.js build)
|
||||
|
||||
## Key Patterns
|
||||
|
||||
### Provider Implementation
|
||||
- Providers inherit from base classes in `litellm/llms/base.py`
|
||||
- Each provider has transformation functions for input/output formatting
|
||||
- Support both sync and async operations
|
||||
- Handle streaming responses and function calling
|
||||
|
||||
### Error Handling
|
||||
- Provider-specific exceptions mapped to OpenAI-compatible errors
|
||||
- Fallback logic handled by Router system
|
||||
- Comprehensive logging through `litellm/_logging.py`
|
||||
|
||||
### Configuration
|
||||
- YAML config files for proxy server (see `proxy/example_config_yaml/`)
|
||||
- Environment variables for API keys and settings
|
||||
- Database schema managed via Prisma (`proxy/schema.prisma`)
|
||||
|
||||
## Development Notes
|
||||
|
||||
### Code Style
|
||||
- Uses Black formatter, Ruff linter, MyPy type checker
|
||||
- Pydantic v2 for data validation
|
||||
- Async/await patterns throughout
|
||||
- Type hints required for all public APIs
|
||||
|
||||
### Testing Strategy
|
||||
- Unit tests in `tests/test_litellm/`
|
||||
- Integration tests for each provider in `tests/llm_translation/`
|
||||
- Proxy tests in `tests/proxy_unit_tests/`
|
||||
- Load tests in `tests/load_tests/`
|
||||
|
||||
### Database Migrations
|
||||
- Prisma handles schema migrations
|
||||
- Migration files auto-generated with `prisma migrate dev`
|
||||
- Always test migrations against both PostgreSQL and SQLite
|
||||
|
||||
### Enterprise Features
|
||||
- Enterprise-specific code in `enterprise/` directory
|
||||
- Optional features enabled via environment variables
|
||||
- Separate licensing and authentication for enterprise features
|
||||
@@ -17,6 +17,8 @@ LiteLLM integrates with vector stores, allowing your models to access your organ
|
||||
|
||||
## Supported Vector Stores
|
||||
- [Bedrock Knowledge Bases](https://aws.amazon.com/bedrock/knowledge-bases/)
|
||||
- [OpenAI Vector Stores](https://platform.openai.com/docs/api-reference/vector-stores/search)
|
||||
- [Azure Vector Stores](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/file-search?tabs=python#vector-stores
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
||||
@@ -233,3 +233,63 @@ for event in response:
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## Calling via `/chat/completions`
|
||||
|
||||
You can also call the Azure Responses API via the `/chat/completions` endpoint.
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="litellm-sdk" label="LiteLLM SDK">
|
||||
|
||||
```python showLineNumbers
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["AZURE_API_BASE"] = "https://my-endpoint-sweden-berri992.openai.azure.com/"
|
||||
os.environ["AZURE_API_VERSION"] = "2023-03-15-preview"
|
||||
os.environ["AZURE_API_KEY"] = "my-api-key"
|
||||
|
||||
response = completion(
|
||||
model="azure/responses/my-custom-o1-pro",
|
||||
messages=[{"role": "user", "content": "Hello world"}],
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="OpenAI SDK with LiteLLM Proxy">
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml showLineNumbers
|
||||
model_list:
|
||||
- model_name: my-custom-o1-pro
|
||||
litellm_params:
|
||||
model: azure/responses/my-custom-o1-pro
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: https://my-endpoint-sweden-berri992.openai.azure.com/
|
||||
api_version: 2023-03-15-preview
|
||||
```
|
||||
|
||||
2. Start LiteLLM proxy
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-d '{
|
||||
"model": "my-custom-o1-pro",
|
||||
"messages": [{"role": "user", "content": "Hello world"}]
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
@@ -13,11 +13,12 @@ Call your custom torch-serve / internal LLM APIs via LiteLLM
|
||||
:::
|
||||
|
||||
Supported Routes:
|
||||
- `/v1/chat/completions` -> `litellm.completion`
|
||||
- `/v1/completions` -> `litellm.text_completion`
|
||||
- `/v1/embeddings` -> `litellm.embedding`
|
||||
- `/v1/images/generations` -> `litellm.image_generation`
|
||||
- `/v1/chat/completions` -> `litellm.acompletion`
|
||||
- `/v1/completions` -> `litellm.atext_completion`
|
||||
- `/v1/embeddings` -> `litellm.aembedding`
|
||||
- `/v1/images/generations` -> `litellm.aimage_generation`
|
||||
|
||||
- `/v1/messages` -> `litellm.acompletion`
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -262,6 +263,102 @@ Expected Response
|
||||
}
|
||||
```
|
||||
|
||||
## Anthropic `/v1/messages`
|
||||
|
||||
- Write the integration for .acompletion
|
||||
- litellm will transform it to /v1/messages
|
||||
|
||||
1. Setup your `custom_handler.py` file
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm import CustomLLM, completion, get_llm_provider
|
||||
|
||||
|
||||
class MyCustomLLM(CustomLLM):
|
||||
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
|
||||
return litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello world"}],
|
||||
mock_response="Hi!",
|
||||
) # type: ignore
|
||||
|
||||
|
||||
my_custom_llm = MyCustomLLM()
|
||||
```
|
||||
|
||||
2. Add to `config.yaml`
|
||||
|
||||
In the config below, we pass
|
||||
|
||||
python_filename: `custom_handler.py`
|
||||
custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1
|
||||
|
||||
custom_handler: `custom_handler.my_custom_llm`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "test-model"
|
||||
litellm_params:
|
||||
model: "openai/text-embedding-ada-002"
|
||||
- model_name: "my-custom-model"
|
||||
litellm_params:
|
||||
model: "my-custom-llm/my-model"
|
||||
|
||||
litellm_settings:
|
||||
custom_provider_map:
|
||||
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
|
||||
```
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
|
||||
-H 'anthropic-version: 2023-06-01' \
|
||||
-H 'content-type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "my-custom-model",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the key findings in this document 12?"
|
||||
}]
|
||||
}]
|
||||
}'
|
||||
```
|
||||
|
||||
Expected Response
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-Bm4qEp4h4vCe7Zi4Gud1MAxTWgibO",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"stop_sequence": null,
|
||||
"usage": {
|
||||
"input_tokens": 18,
|
||||
"output_tokens": 44
|
||||
},
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Without the specific document being provided, it is not possible to determine the key findings within it. If you can provide the content or a summary of document 12, I would be happy to help identify the key findings."
|
||||
}
|
||||
],
|
||||
"stop_reason": "end_turn"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Additional Parameters
|
||||
|
||||
Additional parameters are passed inside `optional_params` key in the `completion` or `image_generation` function.
|
||||
|
||||
@@ -292,13 +292,18 @@ Let's also set the default models to `no-default-models`. This means a user can
|
||||
</TabItem>
|
||||
<TabItem value="yaml" label="YAML">
|
||||
|
||||
:::info
|
||||
Team must be created before setting it as the default team.
|
||||
:::
|
||||
|
||||
```yaml
|
||||
default_internal_user_params: # Default Params used when a new user signs in Via SSO
|
||||
user_role: "internal_user" # one of "internal_user", "internal_user_viewer",
|
||||
models: ["no-default-models"] # Optional[List[str]], optional): models to be used by the user
|
||||
teams: # Optional[List[NewUserRequestTeam]], optional): teams to be used by the user
|
||||
- team_id: "team_id_1" # Required[str]: team_id to be used by the user
|
||||
user_role: "user" # Optional[str], optional): Default role in the team. Values: "user" or "admin". Defaults to "user"
|
||||
litellm_settings:
|
||||
default_internal_user_params: # Default Params used when a new user signs in Via SSO
|
||||
user_role: "internal_user" # one of "internal_user", "internal_user_viewer",
|
||||
models: ["no-default-models"] # Optional[List[str]], optional): models to be used by the user
|
||||
teams: # Optional[List[NewUserRequestTeam]], optional): teams to be used by the user
|
||||
- team_id: "team_id_1" # Required[str]: team_id to be used by the user
|
||||
user_role: "user" # Optional[str], optional): Default role in the team. Values: "user" or "admin". Defaults to "user"
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.proxy._types import GenerateKeyRequest, LiteLLM_TeamTable
|
||||
|
||||
|
||||
def add_team_member_key_duration(
|
||||
team_table: Optional[LiteLLM_TeamTable],
|
||||
data: GenerateKeyRequest,
|
||||
) -> GenerateKeyRequest:
|
||||
if team_table is None:
|
||||
return data
|
||||
|
||||
if data.user_id is None: # only apply for team member keys, not service accounts
|
||||
return data
|
||||
|
||||
if (
|
||||
team_table.metadata is not None
|
||||
and team_table.metadata.get("team_member_key_duration") is not None
|
||||
):
|
||||
data.duration = team_table.metadata["team_member_key_duration"]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def apply_enterprise_key_management_params(
|
||||
data: GenerateKeyRequest,
|
||||
team_table: Optional[LiteLLM_TeamTable],
|
||||
) -> GenerateKeyRequest:
|
||||
data = add_team_member_key_duration(team_table, data)
|
||||
return data
|
||||
@@ -1141,6 +1141,7 @@ from .router import Router
|
||||
from .assistants.main import *
|
||||
from .batches.main import *
|
||||
from .images.main import *
|
||||
from .vector_stores import *
|
||||
from .batch_completion.main import * # type: ignore
|
||||
from .rerank_api.main import *
|
||||
from .llms.anthropic.experimental_pass_through.messages.handler import *
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Add Bedrock Knowledge Base Context to your LLM calls
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.vector_store_integrations.base_vector_store import (
|
||||
BaseVectorStore,
|
||||
)
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.rag.bedrock_knowledgebase import (
|
||||
BedrockKBContent,
|
||||
BedrockKBGuardrailConfiguration,
|
||||
BedrockKBRequest,
|
||||
BedrockKBResponse,
|
||||
BedrockKBRetrievalConfiguration,
|
||||
BedrockKBRetrievalQuery,
|
||||
BedrockKBRetrievalResult,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
from litellm.types.utils import StandardLoggingVectorStoreRequest
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams
|
||||
else:
|
||||
StandardCallbackDynamicParams = Any
|
||||
|
||||
|
||||
class BedrockVectorStore(BaseVectorStore, BaseAWSLLM):
|
||||
CONTENT_PREFIX_STRING = "Context: \n\n"
|
||||
CUSTOM_LLM_PROVIDER = "bedrock"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
|
||||
# store kwargs as optional_params
|
||||
self.optional_params = kwargs
|
||||
|
||||
super().__init__(**kwargs)
|
||||
BaseAWSLLM.__init__(self)
|
||||
|
||||
async def async_get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
prompt_label: Optional[str] = None,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Retrieves the context from the Bedrock Knowledge Base and appends it to the messages.
|
||||
"""
|
||||
if litellm.vector_store_registry is None:
|
||||
return model, messages, non_default_params
|
||||
|
||||
vector_store_ids = litellm.vector_store_registry.pop_vector_store_ids_to_run(
|
||||
non_default_params=non_default_params, tools=tools
|
||||
)
|
||||
vector_store_request_metadata: List[StandardLoggingVectorStoreRequest] = []
|
||||
if vector_store_ids:
|
||||
for vector_store_id in vector_store_ids:
|
||||
start_time = datetime.now()
|
||||
query = self._get_kb_query_from_messages(messages)
|
||||
bedrock_kb_response = await self.make_bedrock_kb_retrieve_request(
|
||||
knowledge_base_id=vector_store_id,
|
||||
query=query,
|
||||
non_default_params=non_default_params,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Bedrock Knowledge Base Response: {bedrock_kb_response}"
|
||||
)
|
||||
|
||||
(
|
||||
context_message,
|
||||
context_string,
|
||||
) = self.get_chat_completion_message_from_bedrock_kb_response(
|
||||
bedrock_kb_response
|
||||
)
|
||||
if context_message is not None:
|
||||
messages.append(context_message)
|
||||
|
||||
#################################################################################################
|
||||
########## LOGGING for Standard Logging Payload, Langfuse, s3, LiteLLM DB etc. ##################
|
||||
#################################################################################################
|
||||
vector_store_search_response: VectorStoreSearchResponse = (
|
||||
self.transform_bedrock_kb_response_to_vector_store_search_response(
|
||||
bedrock_kb_response=bedrock_kb_response, query=query
|
||||
)
|
||||
)
|
||||
vector_store_request_metadata.append(
|
||||
StandardLoggingVectorStoreRequest(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
vector_store_search_response=vector_store_search_response,
|
||||
custom_llm_provider=self.CUSTOM_LLM_PROVIDER,
|
||||
start_time=start_time.timestamp(),
|
||||
end_time=datetime.now().timestamp(),
|
||||
)
|
||||
)
|
||||
|
||||
litellm_logging_obj.model_call_details[
|
||||
"vector_store_request_metadata"
|
||||
] = vector_store_request_metadata
|
||||
|
||||
return model, messages, non_default_params
|
||||
|
||||
def transform_bedrock_kb_response_to_vector_store_search_response(
|
||||
self,
|
||||
bedrock_kb_response: BedrockKBResponse,
|
||||
query: str,
|
||||
) -> VectorStoreSearchResponse:
|
||||
"""
|
||||
Transform a BedrockKBResponse to a VectorStoreSearchResponse
|
||||
"""
|
||||
retrieval_results: Optional[
|
||||
List[BedrockKBRetrievalResult]
|
||||
] = bedrock_kb_response.get("retrievalResults", None)
|
||||
vector_store_search_response: VectorStoreSearchResponse = (
|
||||
VectorStoreSearchResponse(search_query=query, data=[])
|
||||
)
|
||||
if retrieval_results is None:
|
||||
return vector_store_search_response
|
||||
|
||||
vector_search_response_data: List[VectorStoreSearchResult] = []
|
||||
for retrieval_result in retrieval_results:
|
||||
content: Optional[BedrockKBContent] = retrieval_result.get("content", None)
|
||||
if content is None:
|
||||
continue
|
||||
content_text: Optional[str] = content.get("text", None)
|
||||
if content_text is None:
|
||||
continue
|
||||
vector_store_search_result: VectorStoreSearchResult = (
|
||||
VectorStoreSearchResult(
|
||||
score=retrieval_result.get("score", None),
|
||||
content=[VectorStoreResultContent(text=content_text, type="text")],
|
||||
)
|
||||
)
|
||||
vector_search_response_data.append(vector_store_search_result)
|
||||
vector_store_search_response["data"] = vector_search_response_data
|
||||
return vector_store_search_response
|
||||
|
||||
def _get_kb_query_from_messages(self, messages: List[AllMessageValues]) -> str:
|
||||
"""
|
||||
Uses the text `content` field of the last message in the list of messages
|
||||
"""
|
||||
if len(messages) == 0:
|
||||
return ""
|
||||
last_message = messages[-1]
|
||||
last_message_content = last_message.get("content", None)
|
||||
if last_message_content is None:
|
||||
return ""
|
||||
if isinstance(last_message_content, str):
|
||||
return last_message_content
|
||||
elif isinstance(last_message_content, list):
|
||||
return "\n".join([item.get("text", "") for item in last_message_content])
|
||||
return ""
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
credentials: Any,
|
||||
data: BedrockKBRequest,
|
||||
optional_params: dict,
|
||||
aws_region_name: str,
|
||||
api_base: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Prepare a signed AWS request.
|
||||
|
||||
Args:
|
||||
credentials: AWS credentials
|
||||
data: Request data
|
||||
optional_params: Additional parameters
|
||||
aws_region_name: AWS region name
|
||||
api_base: Base API URL
|
||||
extra_headers: Additional headers
|
||||
|
||||
Returns:
|
||||
AWSRequest: A signed AWS request
|
||||
"""
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
encoded_data = json.dumps(data).encode("utf-8")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
|
||||
request = AWSRequest(
|
||||
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if extra_headers is not None and "Authorization" in extra_headers:
|
||||
# prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
|
||||
return request.prepare()
|
||||
|
||||
async def make_bedrock_kb_retrieve_request(
|
||||
self,
|
||||
knowledge_base_id: str,
|
||||
query: str,
|
||||
guardrail_id: Optional[str] = None,
|
||||
guardrail_version: Optional[str] = None,
|
||||
next_token: Optional[str] = None,
|
||||
retrieval_configuration: Optional[BedrockKBRetrievalConfiguration] = None,
|
||||
non_default_params: Optional[dict] = None,
|
||||
) -> BedrockKBResponse:
|
||||
"""
|
||||
Make a Bedrock Knowledge Base retrieve request.
|
||||
|
||||
Args:
|
||||
knowledge_base_id (str): The unique identifier of the knowledge base to query
|
||||
query (str): The query text to search for
|
||||
guardrail_id (Optional[str]): The guardrail ID to apply
|
||||
guardrail_version (Optional[str]): The version of the guardrail to apply
|
||||
next_token (Optional[str]): Token for pagination
|
||||
retrieval_configuration (Optional[BedrockKBRetrievalConfiguration]): Configuration for the retrieval process
|
||||
|
||||
Returns:
|
||||
BedrockKBRetrievalResponse: A typed response object containing the retrieval results
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
non_default_params = non_default_params or {}
|
||||
credentials_dict: Dict[str, Any] = {}
|
||||
if litellm.vector_store_registry is not None:
|
||||
credentials_dict = (
|
||||
litellm.vector_store_registry.get_credentials_for_vector_store(
|
||||
knowledge_base_id
|
||||
)
|
||||
)
|
||||
|
||||
credentials = self.get_credentials(
|
||||
aws_access_key_id=credentials_dict.get(
|
||||
"aws_access_key_id", non_default_params.get("aws_access_key_id", None)
|
||||
),
|
||||
aws_secret_access_key=credentials_dict.get(
|
||||
"aws_secret_access_key",
|
||||
non_default_params.get("aws_secret_access_key", None),
|
||||
),
|
||||
aws_session_token=credentials_dict.get(
|
||||
"aws_session_token", non_default_params.get("aws_session_token", None)
|
||||
),
|
||||
aws_region_name=credentials_dict.get(
|
||||
"aws_region_name", non_default_params.get("aws_region_name", None)
|
||||
),
|
||||
aws_session_name=credentials_dict.get(
|
||||
"aws_session_name", non_default_params.get("aws_session_name", None)
|
||||
),
|
||||
aws_profile_name=credentials_dict.get(
|
||||
"aws_profile_name", non_default_params.get("aws_profile_name", None)
|
||||
),
|
||||
aws_role_name=credentials_dict.get(
|
||||
"aws_role_name", non_default_params.get("aws_role_name", None)
|
||||
),
|
||||
aws_web_identity_token=credentials_dict.get(
|
||||
"aws_web_identity_token",
|
||||
non_default_params.get("aws_web_identity_token", None),
|
||||
),
|
||||
aws_sts_endpoint=credentials_dict.get(
|
||||
"aws_sts_endpoint", non_default_params.get("aws_sts_endpoint", None)
|
||||
),
|
||||
)
|
||||
aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
|
||||
aws_region_name=credentials_dict.get(
|
||||
"aws_region_name", non_default_params.get("aws_region_name", None)
|
||||
),
|
||||
)
|
||||
|
||||
# Prepare request data
|
||||
request_data: BedrockKBRequest = BedrockKBRequest(
|
||||
retrievalQuery=BedrockKBRetrievalQuery(text=query),
|
||||
)
|
||||
if next_token:
|
||||
request_data["nextToken"] = next_token
|
||||
if retrieval_configuration:
|
||||
request_data["retrievalConfiguration"] = retrieval_configuration
|
||||
if guardrail_id and guardrail_version:
|
||||
request_data["guardrailConfiguration"] = BedrockKBGuardrailConfiguration(
|
||||
guardrailId=guardrail_id, guardrailVersion=guardrail_version
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Request Data: {json.dumps(request_data, indent=4, default=str)}"
|
||||
)
|
||||
|
||||
# Prepare the request
|
||||
api_base = f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com/knowledgebases/{knowledge_base_id}/retrieve"
|
||||
|
||||
prepared_request = self._prepare_request(
|
||||
credentials=credentials,
|
||||
data=request_data,
|
||||
optional_params=self.optional_params,
|
||||
aws_region_name=aws_region_name,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Bedrock Knowledge Base request body: %s, url %s, headers: %s",
|
||||
request_data,
|
||||
prepared_request.url,
|
||||
prepared_request.headers,
|
||||
)
|
||||
|
||||
response = await self.async_handler.post(
|
||||
url=prepared_request.url,
|
||||
data=prepared_request.body, # type: ignore
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("Bedrock Knowledge Base response: %s", response.text)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
return BedrockKBResponse(**response_data)
|
||||
else:
|
||||
verbose_proxy_logger.error(
|
||||
"Bedrock Knowledge Base: error in response. Status code: %s, response: %s",
|
||||
response.status_code,
|
||||
response.text,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail={
|
||||
"error": "Error calling Bedrock Knowledge Base",
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_initialized_custom_logger() -> Optional[CustomLogger]:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
_init_custom_logger_compatible_class,
|
||||
)
|
||||
|
||||
return _init_custom_logger_compatible_class(
|
||||
logging_integration="bedrock_vector_store",
|
||||
internal_usage_cache=None,
|
||||
llm_router=None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_chat_completion_message_from_bedrock_kb_response(
|
||||
response: BedrockKBResponse,
|
||||
) -> Tuple[Optional[ChatCompletionUserMessage], str]:
|
||||
"""
|
||||
Retrieves the context from the Bedrock Knowledge Base response and returns a ChatCompletionUserMessage object.
|
||||
"""
|
||||
retrieval_results: Optional[List[BedrockKBRetrievalResult]] = response.get(
|
||||
"retrievalResults", None
|
||||
)
|
||||
if retrieval_results is None:
|
||||
return None, ""
|
||||
|
||||
# string to combine the context from the knowledge base
|
||||
context_string: str = BedrockVectorStore.CONTENT_PREFIX_STRING
|
||||
for retrieval_result in retrieval_results:
|
||||
retrieval_result_content: Optional[BedrockKBContent] = (
|
||||
retrieval_result.get("content", None) or {}
|
||||
)
|
||||
if retrieval_result_content is None:
|
||||
continue
|
||||
retrieval_result_text: Optional[str] = retrieval_result_content.get(
|
||||
"text", None
|
||||
)
|
||||
if retrieval_result_text is None:
|
||||
continue
|
||||
context_string += retrieval_result_text
|
||||
message = ChatCompletionUserMessage(
|
||||
role="user",
|
||||
content=context_string,
|
||||
)
|
||||
return message, context_string
|
||||
@@ -12,7 +12,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.vector_stores.base_vector_store import BaseVectorStore
|
||||
from litellm.integrations.vector_store_integrations.base_vector_store import (
|
||||
BaseVectorStore,
|
||||
)
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
|
||||
@@ -32,7 +32,9 @@ from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
from litellm.integrations.opik.opik import OpikLogger
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
from litellm.integrations.s3_v2 import S3Logger
|
||||
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
|
||||
from litellm.integrations.vector_store_integrations.bedrock_vector_store import (
|
||||
BedrockVectorStore,
|
||||
)
|
||||
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler
|
||||
|
||||
|
||||
|
||||
@@ -54,7 +54,9 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.deepeval.deepeval import DeepEvalLogger
|
||||
from litellm.integrations.mlflow import MlflowLogger
|
||||
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
|
||||
from litellm.integrations.vector_store_integrations.bedrock_vector_store import (
|
||||
BedrockVectorStore,
|
||||
)
|
||||
from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
|
||||
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
||||
StandardBuiltInToolCostTracking,
|
||||
|
||||
@@ -16,7 +16,6 @@ from typing import (
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
Set,
|
||||
)
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
@@ -508,6 +507,7 @@ def unpack_defs(schema: dict, defs: dict) -> None:
|
||||
"""
|
||||
|
||||
import copy
|
||||
from collections import deque
|
||||
|
||||
# Combine the defs handed down by the caller with defs/definitions found on
|
||||
# the current node. Local keys shadow parent keys to match JSON-schema
|
||||
@@ -518,14 +518,17 @@ def unpack_defs(schema: dict, defs: dict) -> None:
|
||||
**schema.get("definitions", {}),
|
||||
}
|
||||
|
||||
def _walk_and_resolve(node: Any, active_defs: dict, seen: Set[int]): # type: ignore[name-defined]
|
||||
"""Depth-first resolver that replaces ``{"$ref": "#/defs/Foo"}`` with
|
||||
the *actual* ``Foo`` schema.
|
||||
"""
|
||||
|
||||
# Avoid infinite recursion on self-referential schemas
|
||||
# Use iterative approach with queue to avoid recursion
|
||||
# Each item in queue is (node, parent_container, key/index, active_defs, seen_ids)
|
||||
queue: deque[tuple[Any, Union[dict, list, None], Union[str, int, None], dict, set]] = deque([(schema, None, None, root_defs, set())])
|
||||
|
||||
while queue:
|
||||
node, parent, key, active_defs, seen = queue.popleft()
|
||||
|
||||
# Avoid infinite loops on self-referential schemas
|
||||
if id(node) in seen:
|
||||
return node
|
||||
continue
|
||||
seen = seen.copy() # Create new set for this branch
|
||||
seen.add(id(node))
|
||||
|
||||
# ----------------------------- dict -----------------------------
|
||||
@@ -536,7 +539,7 @@ def unpack_defs(schema: dict, defs: dict) -> None:
|
||||
target_schema = active_defs.get(ref_name)
|
||||
# Unknown reference – leave untouched
|
||||
if target_schema is None:
|
||||
return node
|
||||
continue
|
||||
|
||||
# Merge defs from the target to capture nested definitions
|
||||
child_defs = {
|
||||
@@ -545,12 +548,24 @@ def unpack_defs(schema: dict, defs: dict) -> None:
|
||||
**target_schema.get("definitions", {}),
|
||||
}
|
||||
|
||||
# Recursively resolve the target *copy* to avoid mutating the
|
||||
# shared definition map.
|
||||
resolved = _walk_and_resolve(copy.deepcopy(target_schema), child_defs, seen)
|
||||
return resolved
|
||||
# Replace the reference with resolved copy
|
||||
resolved = copy.deepcopy(target_schema)
|
||||
if parent is not None and key is not None:
|
||||
if isinstance(parent, dict) and isinstance(key, str):
|
||||
parent[key] = resolved
|
||||
elif isinstance(parent, list) and isinstance(key, int):
|
||||
parent[key] = resolved
|
||||
else:
|
||||
# This is the root schema itself
|
||||
schema.clear()
|
||||
schema.update(resolved)
|
||||
resolved = schema
|
||||
|
||||
# Add resolved node to queue for further processing
|
||||
queue.append((resolved, parent, key, child_defs, seen))
|
||||
continue
|
||||
|
||||
# --- Case 2: regular dict – recurse into its values ---
|
||||
# --- Case 2: regular dict – process its values ---
|
||||
# Update defs with any nested $defs/definitions present *here*.
|
||||
current_defs = {
|
||||
**active_defs,
|
||||
@@ -558,32 +573,15 @@ def unpack_defs(schema: dict, defs: dict) -> None:
|
||||
**node.get("definitions", {}),
|
||||
}
|
||||
|
||||
for key, val in list(node.items()):
|
||||
node[key] = _walk_and_resolve(val, current_defs, seen)
|
||||
return node
|
||||
# Add all dict values to queue
|
||||
for k, v in node.items():
|
||||
queue.append((v, node, k, current_defs, seen))
|
||||
|
||||
# ---------------------------- list ------------------------------
|
||||
if isinstance(node, list):
|
||||
elif isinstance(node, list):
|
||||
# Add all list items to queue
|
||||
for idx, item in enumerate(node):
|
||||
node[idx] = _walk_and_resolve(item, active_defs, seen)
|
||||
return node
|
||||
|
||||
# -------------------------- primitive ---------------------------
|
||||
return node
|
||||
|
||||
# Kick off traversal
|
||||
resolved_root = _walk_and_resolve(schema, root_defs, set())
|
||||
# If the resolver returned a *different* dict (e.g., the root itself was a
|
||||
# $ref), mirror the changes back into the original object so that callers
|
||||
# holding a reference to ``schema`` see the updated structure.
|
||||
if resolved_root is not schema:
|
||||
schema.clear()
|
||||
if isinstance(resolved_root, dict):
|
||||
schema.update(resolved_root)
|
||||
else:
|
||||
# In the very unlikely case the root was resolved to a non-dict
|
||||
# (e.g., a primitive), replace in-place via a sentinel key.
|
||||
schema["__resolved_value__"] = resolved_root # type: ignore
|
||||
queue.append((item, node, idx, active_defs, seen))
|
||||
|
||||
|
||||
def _get_image_mime_type_from_url(url: str) -> Optional[str]:
|
||||
|
||||
@@ -119,6 +119,8 @@ def anthropic_messages_handler(
|
||||
"""
|
||||
Makes Anthropic `/v1/messages` API calls In the Anthropic API Spec
|
||||
"""
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
local_vars = locals()
|
||||
is_async = kwargs.pop("is_async", False)
|
||||
# Use provided client or create a new one
|
||||
@@ -141,12 +143,17 @@ def anthropic_messages_handler(
|
||||
api_key=litellm_params.api_key,
|
||||
)
|
||||
|
||||
anthropic_messages_provider_config: Optional[
|
||||
BaseAnthropicMessagesConfig
|
||||
] = ProviderConfigManager.get_provider_anthropic_messages_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
)
|
||||
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = None
|
||||
|
||||
if custom_llm_provider is not None and custom_llm_provider in [
|
||||
provider.value for provider in LlmProviders
|
||||
]:
|
||||
anthropic_messages_provider_config = (
|
||||
ProviderConfigManager.get_provider_anthropic_messages_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
)
|
||||
)
|
||||
if anthropic_messages_provider_config is None:
|
||||
# Handle non-Anthropic models using the adapter
|
||||
return (
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
@@ -15,6 +14,8 @@ from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
azure_ad_cache = DualCache()
|
||||
|
||||
@@ -613,3 +614,78 @@ class BaseAzureLLM(BaseOpenAILLM):
|
||||
else:
|
||||
client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
return client
|
||||
|
||||
@staticmethod
|
||||
def _base_validate_azure_environment(
|
||||
headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
if api_key:
|
||||
headers["api-key"] = api_key
|
||||
return headers
|
||||
|
||||
### Fallback to Azure AD token-based authentication if no API key is available
|
||||
### Retrieves Azure AD token and adds it to the Authorization header
|
||||
azure_ad_token = get_azure_ad_token(litellm_params)
|
||||
if azure_ad_token:
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _get_base_azure_url(
|
||||
api_base: Optional[str],
|
||||
litellm_params: Optional[Union[GenericLiteLLMParams, Dict[str, Any]]],
|
||||
route: Literal["/openai/responses", "/openai/vector_stores"]
|
||||
) -> str:
|
||||
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
|
||||
)
|
||||
original_url = httpx.URL(api_base)
|
||||
|
||||
# Extract api_version or use default
|
||||
litellm_params = litellm_params or {}
|
||||
api_version = cast(Optional[str], litellm_params.get("api_version"))
|
||||
|
||||
# Create a new dictionary with existing params
|
||||
query_params = dict(original_url.params)
|
||||
|
||||
# Add api_version if needed
|
||||
if "api-version" not in query_params and api_version:
|
||||
query_params["api-version"] = api_version
|
||||
|
||||
# Add the path to the base URL
|
||||
if route not in api_base:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path=route
|
||||
)
|
||||
else:
|
||||
new_url = api_base
|
||||
|
||||
if BaseAzureLLM._is_azure_v1_api_version(api_version):
|
||||
# ensure the request go to /openai/v1 and not just /openai
|
||||
if "/openai/v1" not in new_url:
|
||||
parsed_url = httpx.URL(new_url)
|
||||
new_url = str(parsed_url.copy_with(path=parsed_url.path.replace("/openai", "/openai/v1")))
|
||||
|
||||
|
||||
# Use the new query_params dictionary
|
||||
final_url = httpx.URL(new_url).copy_with(params=query_params)
|
||||
|
||||
return str(final_url)
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_v1_api_version(api_version: Optional[str]) -> bool:
|
||||
if api_version is None:
|
||||
return False
|
||||
return api_version == "preview" or api_version == "latest"
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure.common_utils import get_azure_ad_token
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import *
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
@@ -24,27 +19,11 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
def validate_environment(
|
||||
self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
return BaseAzureLLM._base_validate_azure_environment(
|
||||
headers=headers,
|
||||
litellm_params=litellm_params
|
||||
)
|
||||
|
||||
if api_key:
|
||||
headers["api-key"] = api_key
|
||||
return headers
|
||||
|
||||
### Fallback to Azure AD token-based authentication if no API key is available
|
||||
### Retrieves Azure AD token and adds it to the Authorization header
|
||||
azure_ad_token = get_azure_ad_token(litellm_params)
|
||||
if azure_ad_token:
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
@@ -66,47 +45,12 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
- A complete URL string, e.g.,
|
||||
"https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
|
||||
"""
|
||||
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
|
||||
)
|
||||
original_url = httpx.URL(api_base)
|
||||
|
||||
# Extract api_version or use default
|
||||
api_version = cast(Optional[str], litellm_params.get("api_version"))
|
||||
|
||||
# Create a new dictionary with existing params
|
||||
query_params = dict(original_url.params)
|
||||
|
||||
# Add api_version if needed
|
||||
if "api-version" not in query_params and api_version:
|
||||
query_params["api-version"] = api_version
|
||||
|
||||
# Add the path to the base URL
|
||||
if "/openai/responses" not in api_base:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/openai/responses"
|
||||
)
|
||||
else:
|
||||
new_url = api_base
|
||||
|
||||
if self._is_azure_v1_api_version(api_version):
|
||||
# ensure the request go to /openai/v1 and not just /openai
|
||||
if "/openai/v1" not in new_url:
|
||||
parsed_url = httpx.URL(new_url)
|
||||
new_url = str(parsed_url.copy_with(path=parsed_url.path.replace("/openai", "/openai/v1")))
|
||||
|
||||
|
||||
# Use the new query_params dictionary
|
||||
final_url = httpx.URL(new_url).copy_with(params=query_params)
|
||||
|
||||
return str(final_url)
|
||||
return BaseAzureLLM._get_base_azure_url(
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
route="/openai/responses"
|
||||
)
|
||||
|
||||
def _is_azure_v1_api_version(self, api_version: Optional[str]) -> bool:
|
||||
if api_version is None:
|
||||
return False
|
||||
return api_version == "preview" or api_version == "latest"
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.vector_stores.transformation import OpenAIVectorStoreConfig
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
|
||||
class AzureOpenAIVectorStoreConfig(OpenAIVectorStoreConfig):
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
return BaseAzureLLM._get_base_azure_url(
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
route="/openai/vector_stores"
|
||||
)
|
||||
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
return BaseAzureLLM._base_validate_azure_environment(
|
||||
headers=headers,
|
||||
litellm_params=litellm_params
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
class BaseVectorStoreConfig:
|
||||
@abstractmethod
|
||||
def transform_search_vector_store_request(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
|
||||
api_base: str,
|
||||
) -> Tuple[str, Dict]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_search_vector_store_response(self, response: httpx.Response) -> VectorStoreSearchResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_create_vector_store_request(
|
||||
self,
|
||||
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
|
||||
api_base: str,
|
||||
) -> Tuple[str, Dict]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_create_vector_store_response(self, response: httpx.Response) -> VectorStoreCreateResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required")
|
||||
return api_base
|
||||
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
from ..chat.transformation import BaseLLMException
|
||||
|
||||
raise BaseLLMException(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
|
||||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
@@ -60,6 +61,12 @@ from litellm.types.rerank import OptionalRerankParams, RerankResponse
|
||||
from litellm.types.responses.main import DeleteResponseResult
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
ImageResponse,
|
||||
@@ -2342,7 +2349,7 @@ class BaseLLMHTTPHandler:
|
||||
self,
|
||||
e: Exception,
|
||||
provider_config: Union[
|
||||
BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig, BaseImageEditConfig
|
||||
BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig, BaseImageEditConfig, BaseVectorStoreConfig
|
||||
],
|
||||
):
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
@@ -2613,3 +2620,271 @@ class BaseLLMHTTPHandler:
|
||||
raw_response=response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
###### VECTOR STORE HANDLER ######
|
||||
async def async_vector_store_search_handler(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
|
||||
vector_store_provider_config: BaseVectorStoreConfig,
|
||||
custom_llm_provider: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> VectorStoreSearchResponse:
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders(custom_llm_provider),
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
|
||||
)
|
||||
else:
|
||||
async_httpx_client = client
|
||||
|
||||
headers = vector_store_provider_config.validate_environment(
|
||||
headers=extra_headers or {},
|
||||
litellm_params=litellm_params
|
||||
)
|
||||
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
api_base = vector_store_provider_config.get_complete_url(
|
||||
api_base=litellm_params.api_base,
|
||||
litellm_params=dict(litellm_params),
|
||||
)
|
||||
|
||||
url, request_body = vector_store_provider_config.transform_search_vector_store_request(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
vector_store_search_optional_params=vector_store_search_optional_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": request_body,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await async_httpx_client.post(url=url, headers=headers, json=request_body, timeout=timeout)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
|
||||
|
||||
return vector_store_provider_config.transform_search_vector_store_response(
|
||||
response=response,
|
||||
)
|
||||
|
||||
def vector_store_search_handler(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
|
||||
vector_store_provider_config: BaseVectorStoreConfig,
|
||||
custom_llm_provider: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
|
||||
if _is_async:
|
||||
return self.async_vector_store_search_handler(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
vector_store_search_optional_params=vector_store_search_optional_params,
|
||||
vector_store_provider_config=vector_store_provider_config,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
sync_httpx_client = _get_httpx_client(
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
|
||||
)
|
||||
else:
|
||||
sync_httpx_client = client
|
||||
|
||||
headers = vector_store_provider_config.validate_environment(
|
||||
headers=extra_headers or {},
|
||||
litellm_params=litellm_params
|
||||
)
|
||||
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
api_base = vector_store_provider_config.get_complete_url(
|
||||
api_base=litellm_params.api_base,
|
||||
litellm_params=dict(litellm_params),
|
||||
)
|
||||
|
||||
url, request_body = vector_store_provider_config.transform_search_vector_store_request(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
vector_store_search_optional_params=vector_store_search_optional_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": request_body,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = sync_httpx_client.post(url=url, headers=headers, json=request_body)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
|
||||
|
||||
return vector_store_provider_config.transform_search_vector_store_response(
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def async_vector_store_create_handler(
|
||||
self,
|
||||
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
|
||||
vector_store_provider_config: BaseVectorStoreConfig,
|
||||
custom_llm_provider: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> VectorStoreCreateResponse:
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders(custom_llm_provider),
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
|
||||
)
|
||||
else:
|
||||
async_httpx_client = client
|
||||
|
||||
headers = vector_store_provider_config.validate_environment(
|
||||
headers=extra_headers or {},
|
||||
litellm_params=litellm_params
|
||||
)
|
||||
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
api_base = vector_store_provider_config.get_complete_url(
|
||||
api_base=litellm_params.api_base,
|
||||
litellm_params=dict(litellm_params),
|
||||
)
|
||||
|
||||
url, request_body = vector_store_provider_config.transform_create_vector_store_request(
|
||||
vector_store_create_optional_params=vector_store_create_optional_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": request_body,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await async_httpx_client.post(url=url, headers=headers, json=request_body, timeout=timeout)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
|
||||
|
||||
return vector_store_provider_config.transform_create_vector_store_response(
|
||||
response=response,
|
||||
)
|
||||
|
||||
def vector_store_create_handler(
|
||||
self,
|
||||
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
|
||||
vector_store_provider_config: BaseVectorStoreConfig,
|
||||
custom_llm_provider: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
|
||||
if _is_async:
|
||||
return self.async_vector_store_create_handler(
|
||||
vector_store_create_optional_params=vector_store_create_optional_params,
|
||||
vector_store_provider_config=vector_store_provider_config,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
sync_httpx_client = _get_httpx_client(
|
||||
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
|
||||
)
|
||||
else:
|
||||
sync_httpx_client = client
|
||||
|
||||
headers = vector_store_provider_config.validate_environment(
|
||||
headers=extra_headers or {},
|
||||
litellm_params=litellm_params
|
||||
)
|
||||
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
api_base = vector_store_provider_config.get_complete_url(
|
||||
api_base=litellm_params.api_base,
|
||||
litellm_params=dict(litellm_params),
|
||||
)
|
||||
|
||||
url, request_body = vector_store_provider_config.transform_create_vector_store_request(
|
||||
vector_store_create_optional_params=vector_store_create_optional_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": request_body,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = sync_httpx_client.post(url=url, headers=headers, json=request_body)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=vector_store_provider_config)
|
||||
|
||||
return vector_store_provider_config.transform_create_vector_store_response(
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ Why separate file? Make it easy to see how transformation works
|
||||
Docs - https://docs.mistral.ai/api/
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, overload, cast
|
||||
from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, cast, overload
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
handle_messages_with_content_list_to_str_conversion,
|
||||
@@ -107,15 +107,28 @@ class MistralConfig(OpenAIGPTConfig):
|
||||
def _get_mistral_reasoning_system_prompt() -> str:
|
||||
"""
|
||||
Returns the system prompt for Mistral reasoning models.
|
||||
Based on Mistral's documentation: https://docs.mistral.ai/capabilities/reasoning/
|
||||
Based on Mistral's documentation: https://huggingface.co/mistralai/Magistral-Small-2506
|
||||
|
||||
Mistral recommends the following system prompt for reasoning:
|
||||
"""
|
||||
return """When solving problems, think step-by-step in <think> tags before providing your final answer. Use the following format:
|
||||
return """
|
||||
<s>[SYSTEM_PROMPT]system_prompt
|
||||
A user will ask you to solve a task. You should first draft your thinking process (inner monologue) until you have derived the final answer. Afterwards, write a self-contained summary of your thoughts (i.e. your summary should be succinct but contain all the critical steps you needed to reach the conclusion). You should use Markdown to format your response. Write both your thoughts and summary in the same language as the task posed by the user. NEVER use \boxed{} in your response.
|
||||
|
||||
<think>
|
||||
Your step-by-step reasoning process. Be thorough and work through the problem carefully.
|
||||
</think>
|
||||
Your thinking process must follow the template below:
|
||||
<think>
|
||||
Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate a correct answer.
|
||||
</think>
|
||||
|
||||
Then provide a clear, concise answer based on your reasoning."""
|
||||
Here, provide a concise summary that reflects your reasoning and presents a clear final answer to the user. Don't mention that this is a summary.
|
||||
|
||||
Problem:
|
||||
|
||||
[/SYSTEM_PROMPT][INST]user_message[/INST]<think>
|
||||
reasoning_traces
|
||||
</think>
|
||||
assistant_response</s>[INST]user_message[/INST]
|
||||
"""
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
|
||||
@@ -5,15 +5,15 @@ Ollama /chat/completion calls handled in llm_http_handler.py
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
|
||||
def _prepare_ollama_embedding_payload(
|
||||
model: str,
|
||||
prompts: List[str],
|
||||
optional_params: Dict[str, Any]
|
||||
model: str, prompts: List[str], optional_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
|
||||
data: Dict[str, Any] = {"model": model, "input": prompts}
|
||||
special_optional_params = ["truncate", "options", "keep_alive"]
|
||||
|
||||
@@ -26,13 +26,14 @@ def _prepare_ollama_embedding_payload(
|
||||
data["options"].update({k: v})
|
||||
return data
|
||||
|
||||
|
||||
def _process_ollama_embedding_response(
|
||||
response_json: dict,
|
||||
prompts: List[str],
|
||||
model: str,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: Any,
|
||||
encoding: Any
|
||||
encoding: Any,
|
||||
) -> EmbeddingResponse:
|
||||
output_data = []
|
||||
embeddings: List[List[float]] = response_json["embeddings"]
|
||||
@@ -46,11 +47,15 @@ def _process_ollama_embedding_response(
|
||||
if encoding is not None:
|
||||
input_tokens = len(encoding.encode("".join(prompts)))
|
||||
if logging_obj:
|
||||
logging_obj.debug("Ollama response missing prompt_eval_count; estimated with encoding.")
|
||||
logging_obj.debug(
|
||||
"Ollama response missing prompt_eval_count; estimated with encoding."
|
||||
)
|
||||
else:
|
||||
input_tokens = 0
|
||||
if logging_obj:
|
||||
logging_obj.warning("Missing prompt_eval_count and no encoding provided; defaulted to 0.")
|
||||
logging_obj.warning(
|
||||
"Missing prompt_eval_count and no encoding provided; defaulted to 0."
|
||||
)
|
||||
|
||||
model_response.object = "list"
|
||||
model_response.data = output_data
|
||||
@@ -64,6 +69,7 @@ def _process_ollama_embedding_response(
|
||||
)
|
||||
return model_response
|
||||
|
||||
|
||||
async def ollama_aembeddings(
|
||||
api_base: str,
|
||||
model: str,
|
||||
@@ -79,7 +85,7 @@ async def ollama_aembeddings(
|
||||
data = _prepare_ollama_embedding_payload(model, prompts, optional_params)
|
||||
|
||||
response = await litellm.module_level_aclient.post(url=api_base, json=data)
|
||||
response_json = await response.json()
|
||||
response_json = response.json()
|
||||
|
||||
return _process_ollama_embedding_response(
|
||||
response_json=response_json,
|
||||
@@ -87,9 +93,10 @@ async def ollama_aembeddings(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
encoding=encoding
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
|
||||
def ollama_embeddings(
|
||||
api_base: str,
|
||||
model: str,
|
||||
@@ -113,5 +120,5 @@ def ollama_embeddings(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
encoding=encoding
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
from typing import Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateRequest,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchRequest,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIVectorStoreConfig(BaseVectorStoreConfig):
|
||||
ASSISTANTS_HEADER_KEY = "OpenAI-Beta"
|
||||
ASSISTANTS_HEADER_VALUE = "assistants=v2"
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# Ensure OpenAI Assistants header is includes
|
||||
#########################################################
|
||||
if self.ASSISTANTS_HEADER_KEY not in headers:
|
||||
headers.update(
|
||||
{
|
||||
self.ASSISTANTS_HEADER_KEY: self.ASSISTANTS_HEADER_VALUE,
|
||||
}
|
||||
)
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the Base endpoint for OpenAI Vector Stores API
|
||||
"""
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("OPENAI_BASE_URL")
|
||||
or get_secret_str("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
|
||||
# Remove trailing slashes
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
return f"{api_base}/vector_stores"
|
||||
|
||||
|
||||
def transform_search_vector_store_request(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
|
||||
api_base: str,
|
||||
) -> Tuple[str, Dict]:
|
||||
url = f"{api_base}/{vector_store_id}/search"
|
||||
typed_request_body = VectorStoreSearchRequest(
|
||||
query=query,
|
||||
filters=vector_store_search_optional_params.get("filters", None),
|
||||
max_num_results=vector_store_search_optional_params.get("max_num_results", None),
|
||||
ranking_options=vector_store_search_optional_params.get("ranking_options", None),
|
||||
rewrite_query=vector_store_search_optional_params.get("rewrite_query", None),
|
||||
)
|
||||
|
||||
dict_request_body = cast(dict, typed_request_body)
|
||||
return url, dict_request_body
|
||||
|
||||
|
||||
|
||||
def transform_search_vector_store_response(self, response: httpx.Response) -> VectorStoreSearchResponse:
|
||||
try:
|
||||
response_json = response.json()
|
||||
return VectorStoreSearchResponse(
|
||||
**response_json
|
||||
)
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=str(e),
|
||||
status_code=response.status_code,
|
||||
headers=response.headers
|
||||
)
|
||||
|
||||
def transform_create_vector_store_request(
|
||||
self,
|
||||
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
|
||||
api_base: str,
|
||||
) -> Tuple[str, Dict]:
|
||||
url = api_base # Base URL for creating vector stores
|
||||
typed_request_body = VectorStoreCreateRequest(
|
||||
name=vector_store_create_optional_params.get("name", None),
|
||||
file_ids=vector_store_create_optional_params.get("file_ids", None),
|
||||
expires_after=vector_store_create_optional_params.get("expires_after", None),
|
||||
chunking_strategy=vector_store_create_optional_params.get("chunking_strategy", None),
|
||||
metadata=vector_store_create_optional_params.get("metadata", None),
|
||||
)
|
||||
|
||||
dict_request_body = cast(dict, typed_request_body)
|
||||
return url, dict_request_body
|
||||
|
||||
def transform_create_vector_store_response(self, response: httpx.Response) -> VectorStoreCreateResponse:
|
||||
try:
|
||||
response_json = response.json()
|
||||
return VectorStoreCreateResponse(
|
||||
**response_json
|
||||
)
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=str(e),
|
||||
status_code=response.status_code,
|
||||
headers=response.headers
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -63,3 +63,26 @@ class VolcEngineConfig(OpenAILikeChatConfig):
|
||||
"extra_headers",
|
||||
"thinking",
|
||||
] # works across all models
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
replace_max_completion_tokens_with_max_tokens: bool = True,
|
||||
) -> dict:
|
||||
optional_params = super().map_openai_params(
|
||||
non_default_params,
|
||||
optional_params,
|
||||
model,
|
||||
drop_params,
|
||||
replace_max_completion_tokens_with_max_tokens,
|
||||
)
|
||||
|
||||
if "thinking" in optional_params:
|
||||
optional_params.setdefault("extra_body", {})["thinking"] = (
|
||||
optional_params.pop("thinking")
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
@@ -1297,6 +1297,12 @@ def completion( # type: ignore # noqa: PLR0915
|
||||
except Exception as e:
|
||||
verbose_logger.debug("Error getting model info: {}".format(e))
|
||||
model_info = {}
|
||||
if model.startswith(
|
||||
"responses/"
|
||||
): # handle azure models - `azure/responses/<deployment-name>`
|
||||
model = model.split("/")[1]
|
||||
mode = "responses"
|
||||
model_info["mode"] = mode
|
||||
|
||||
if model_info.get("mode") == "responses":
|
||||
from litellm.completion_extras import responses_api_bridge
|
||||
|
||||
@@ -4046,7 +4046,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-small": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4058,7 +4059,8 @@
|
||||
"supports_function_calling": true,
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-small-latest": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4070,7 +4072,8 @@
|
||||
"supports_function_calling": true,
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4081,7 +4084,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium-latest": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4093,7 +4097,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium-2505": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4105,7 +4110,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium-2312": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4116,7 +4122,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-latest": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4128,7 +4135,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-2411": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4140,7 +4148,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-2402": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4152,7 +4161,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-2407": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4164,7 +4174,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/pixtral-large-latest": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4177,7 +4188,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_vision": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/pixtral-large-2411": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4190,7 +4202,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_vision": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/pixtral-12b-2409": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4203,7 +4216,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_vision": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mistral-7b": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4214,7 +4228,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mixtral-8x7b": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4226,7 +4241,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mixtral-8x22b": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4238,7 +4254,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/codestral-latest": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4249,7 +4266,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/codestral-2405": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4260,7 +4278,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mistral-nemo": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4272,7 +4291,8 @@
|
||||
"mode": "chat",
|
||||
"source": "https://mistral.ai/technology/",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mistral-nemo-2407": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4284,7 +4304,8 @@
|
||||
"mode": "chat",
|
||||
"source": "https://mistral.ai/technology/",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-codestral-mamba": {
|
||||
"max_tokens": 256000,
|
||||
@@ -4321,7 +4342,8 @@
|
||||
"source": "https://mistral.ai/news/devstral",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-medium-latest": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4335,7 +4357,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-medium-2506": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4349,7 +4372,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-small-latest": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4363,7 +4387,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-small-2506": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4377,7 +4402,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-embed": {
|
||||
"max_tokens": 8192,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
model_list:
|
||||
- model_name: gemini-2.5-pro
|
||||
litellm_params:
|
||||
model: gemini/gemini-2.5-pro
|
||||
model: gemini/gemini-2.5-pro
|
||||
|
||||
@@ -1115,6 +1115,7 @@ class NewTeamRequest(TeamBase):
|
||||
team_member_budget: Optional[float] = (
|
||||
None # allow user to set a budget for all team members
|
||||
)
|
||||
team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m"
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@@ -1157,6 +1158,7 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase):
|
||||
guardrails: Optional[List[str]] = None
|
||||
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
|
||||
team_member_budget: Optional[float] = None
|
||||
team_member_key_duration: Optional[str] = None
|
||||
|
||||
|
||||
class ResetTeamBudgetRequest(LiteLLMPydanticObjectBase):
|
||||
@@ -2792,6 +2794,7 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
|
||||
LiteLLM_ManagementEndpoint_MetadataFields_Premium = [
|
||||
"guardrails",
|
||||
"tags",
|
||||
"team_member_key_duration",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -164,7 +164,9 @@ class JWTHandler:
|
||||
self.litellm_jwtauth.team_ids_jwt_field is not None
|
||||
and token.get(self.litellm_jwtauth.team_ids_jwt_field) is not None
|
||||
):
|
||||
|
||||
return token[self.litellm_jwtauth.team_ids_jwt_field]
|
||||
|
||||
return []
|
||||
|
||||
def get_end_user_id(
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
[92m18:10:09 - LiteLLM Router:INFO[0m: router.py:660 - Routing strategy: simple-shuffle
|
||||
[92m18:10:11 - LiteLLM Proxy:INFO[0m: utils.py:1317 - All necessary views exist!
|
||||
[92m18:10:11 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:10:11 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:10:23 - LiteLLM Proxy:INFO[0m: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback
|
||||
[92m18:10:27 - LiteLLM Proxy:INFO[0m: ui_sso.py:495 - Starting SSO callback
|
||||
[92m18:10:27 - LiteLLM Proxy:INFO[0m: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback
|
||||
[92m18:10:28 - LiteLLM Proxy:INFO[0m: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[]
|
||||
[92m18:10:28 - LiteLLM Proxy:INFO[0m: ui_sso.py:671 - user_defined_values for creating ui key: {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': 'proxy_admin', 'budget_duration': None}
|
||||
[92m18:10:28 - LiteLLM Proxy:INFO[0m: utils.py:1856 - Data Inserted into Keys Table
|
||||
[92m18:10:28 - LiteLLM Proxy:INFO[0m: ui_sso.py:761 - user_id: krrishd; jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoia3JyaXNoZCIsImtleSI6InNrLTVvOXVVc0ZaaTVBRFBiWERoanhCZlEiLCJ1c2VyX2VtYWlsIjoia3JyaXNoZGhvbGFraWFAZ21haWwuY29tIiwidXNlcl9yb2xlIjoicHJveHlfYWRtaW4iLCJsb2dpbl9tZXRob2QiOiJzc28iLCJwcmVtaXVtX3VzZXIiOnRydWUsImF1dGhfaGVhZGVyX25hbWUiOiJBdXRob3JpemF0aW9uIiwiZGlzYWJsZWRfbm9uX2FkbWluX3BlcnNvbmFsX2tleV9jcmVhdGlvbiI6ZmFsc2UsInNlcnZlcl9yb290X3BhdGgiOiIvIn0.OiZdFjZ2wiMhFbMCwu2cZYXh7oV5BB8Vta-Ysk5JBQU
|
||||
[92m18:10:28 - LiteLLM Proxy:INFO[0m: ui_sso.py:764 - Redirecting to http://localhost:4000/ui/?login=success
|
||||
[92m18:10:30 - LiteLLM Proxy:ERROR[0m: key_management_endpoints.py:2275 - Error in list_keys: Server disconnected without sending a response.
|
||||
Traceback (most recent call last):
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
|
||||
yield
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
|
||||
resp = await self._pool.handle_async_request(req)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
|
||||
raise exc from None
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
|
||||
response = await connection.handle_async_request(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 101, in handle_async_request
|
||||
return await self._connection.handle_async_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 143, in handle_async_request
|
||||
raise exc
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 113, in handle_async_request
|
||||
) = await self._receive_response_headers(**kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 186, in _receive_response_headers
|
||||
event = await self._receive_event(timeout=timeout)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 238, in _receive_event
|
||||
raise RemoteProtocolError(msg)
|
||||
httpcore.RemoteProtocolError: Server disconnected without sending a response.
|
||||
|
||||
The above exception was the direct cause of the following exception:
|
||||
|
||||
Traceback (most recent call last):
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/management_endpoints/key_management_endpoints.py", line 2255, in list_keys
|
||||
response = await _list_key_helper(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/management_endpoints/key_management_endpoints.py", line 2434, in _list_key_helper
|
||||
total_count = await prisma_client.db.litellm_verificationtoken.count(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 10157, in count
|
||||
resp = await self._client._execute(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
|
||||
return await self._engine.query(builder.build(), tx_id=self._tx_id)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
|
||||
return await self.request(
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
|
||||
response = await self.session.request(method, url, **kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
|
||||
return Response(await self.session.request(method, url, **kwargs))
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
|
||||
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
|
||||
response = await self._send_handling_auth(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
|
||||
response = await self._send_handling_redirects(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
|
||||
response = await self._send_single_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
|
||||
response = await transport.handle_async_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
|
||||
with map_httpcore_exceptions():
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
|
||||
self.gen.throw(typ, value, traceback)
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
|
||||
raise mapped_exc(message) from exc
|
||||
httpx.RemoteProtocolError: Server disconnected without sending a response.
|
||||
[92m18:10:30 - LiteLLM Proxy:ERROR[0m: proxy_server.py:2730 - litellm.proxy_server.py::add_deployment() - Error getting new models from DB - All connection attempts failed
|
||||
Traceback (most recent call last):
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
|
||||
yield
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
|
||||
resp = await self._pool.handle_async_request(req)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
|
||||
raise exc from None
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
|
||||
response = await connection.handle_async_request(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 99, in handle_async_request
|
||||
raise exc
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 76, in handle_async_request
|
||||
stream = await self._connect(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 122, in _connect
|
||||
stream = await self._network_backend.connect_tcp(**kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/auto.py", line 30, in connect_tcp
|
||||
return await self._backend.connect_tcp(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 112, in connect_tcp
|
||||
with map_exceptions(exc_map):
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
|
||||
self.gen.throw(typ, value, traceback)
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_exceptions.py", line 14, in map_exceptions
|
||||
raise to_exc(exc) from exc
|
||||
httpcore.ConnectError: All connection attempts failed
|
||||
|
||||
The above exception was the direct cause of the following exception:
|
||||
|
||||
Traceback (most recent call last):
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2728, in _get_models_from_db
|
||||
new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 2540, in find_many
|
||||
resp = await self._client._execute(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
|
||||
return await self._engine.query(builder.build(), tx_id=self._tx_id)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
|
||||
return await self.request(
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
|
||||
response = await self.session.request(method, url, **kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
|
||||
return Response(await self.session.request(method, url, **kwargs))
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
|
||||
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
|
||||
response = await self._send_handling_auth(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
|
||||
response = await self._send_handling_redirects(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
|
||||
response = await self._send_single_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
|
||||
response = await transport.handle_async_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
|
||||
with map_httpcore_exceptions():
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
|
||||
self.gen.throw(typ, value, traceback)
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
|
||||
raise mapped_exc(message) from exc
|
||||
httpx.ConnectError: All connection attempts failed
|
||||
[92m18:10:30 - LiteLLM Proxy:ERROR[0m: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
|
||||
[92m18:10:30 - LiteLLM Proxy:ERROR[0m: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
|
||||
[92m18:10:30 - LiteLLM Proxy:ERROR[0m: proxy_server.py:2778 - litellm.proxy.proxy_server.py::ProxyConfig:add_deployment - All connection attempts failed
|
||||
Traceback (most recent call last):
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
|
||||
yield
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
|
||||
resp = await self._pool.handle_async_request(req)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
|
||||
raise exc from None
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
|
||||
response = await connection.handle_async_request(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 99, in handle_async_request
|
||||
raise exc
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 76, in handle_async_request
|
||||
stream = await self._connect(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 122, in _connect
|
||||
stream = await self._network_backend.connect_tcp(**kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/auto.py", line 30, in connect_tcp
|
||||
return await self._backend.connect_tcp(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 112, in connect_tcp
|
||||
with map_exceptions(exc_map):
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
|
||||
self.gen.throw(typ, value, traceback)
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_exceptions.py", line 14, in map_exceptions
|
||||
raise to_exc(exc) from exc
|
||||
httpcore.ConnectError: All connection attempts failed
|
||||
|
||||
The above exception was the direct cause of the following exception:
|
||||
|
||||
Traceback (most recent call last):
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2760, in add_deployment
|
||||
await self._update_llm_router(
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2418, in _update_llm_router
|
||||
config_data = await proxy_config.get_config()
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 1584, in get_config
|
||||
config = await self._update_config_from_db(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2706, in _update_config_from_db
|
||||
responses = await asyncio.gather(*_tasks)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/db/log_db_metrics.py", line 99, in wrapper
|
||||
raise e
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/db/log_db_metrics.py", line 42, in wrapper
|
||||
result = await func(*args, **kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/backoff/_async.py", line 151, in retry
|
||||
ret = await target(*args, **kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/utils.py", line 1418, in get_generic_data
|
||||
raise e
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/utils.py", line 1392, in get_generic_data
|
||||
response = await self.db.litellm_config.find_first( # type: ignore
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 11822, in find_first
|
||||
resp = await self._client._execute(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
|
||||
return await self._engine.query(builder.build(), tx_id=self._tx_id)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
|
||||
return await self.request(
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
|
||||
response = await self.session.request(method, url, **kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
|
||||
return Response(await self.session.request(method, url, **kwargs))
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
|
||||
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
|
||||
response = await self._send_handling_auth(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
|
||||
response = await self._send_handling_redirects(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
|
||||
response = await self._send_single_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
|
||||
response = await transport.handle_async_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
|
||||
with map_httpcore_exceptions():
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
|
||||
self.gen.throw(typ, value, traceback)
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
|
||||
raise mapped_exc(message) from exc
|
||||
httpx.ConnectError: All connection attempts failed
|
||||
[92m18:10:30 - LiteLLM Proxy:ERROR[0m: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
|
||||
[92m18:10:30 - LiteLLM Proxy:ERROR[0m: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed
|
||||
[92m18:10:30 - LiteLLM Proxy:INFO[0m: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
|
||||
[92m18:11:47 - LiteLLM Router:INFO[0m: router.py:660 - Routing strategy: simple-shuffle
|
||||
[92m18:11:49 - LiteLLM Proxy:INFO[0m: utils.py:1317 - All necessary views exist!
|
||||
[92m18:11:50 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:11:50 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:12:00 - LiteLLM Proxy:ERROR[0m: proxy_server.py:2925 - litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - Server disconnected without sending a response.
|
||||
Traceback (most recent call last):
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions
|
||||
yield
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request
|
||||
resp = await self._pool.handle_async_request(req)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request
|
||||
raise exc from None
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request
|
||||
response = await connection.handle_async_request(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 101, in handle_async_request
|
||||
return await self._connection.handle_async_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 143, in handle_async_request
|
||||
raise exc
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 113, in handle_async_request
|
||||
) = await self._receive_response_headers(**kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 186, in _receive_response_headers
|
||||
event = await self._receive_event(timeout=timeout)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 238, in _receive_event
|
||||
raise RemoteProtocolError(msg)
|
||||
httpcore.RemoteProtocolError: Server disconnected without sending a response.
|
||||
|
||||
The above exception was the direct cause of the following exception:
|
||||
|
||||
Traceback (most recent call last):
|
||||
File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2916, in get_credentials
|
||||
credentials = await prisma_client.db.litellm_credentialstable.find_many()
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 1502, in find_many
|
||||
resp = await self._client._execute(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute
|
||||
return await self._engine.query(builder.build(), tx_id=self._tx_id)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query
|
||||
return await self.request(
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request
|
||||
response = await self.session.request(method, url, **kwargs)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request
|
||||
return Response(await self.session.request(method, url, **kwargs))
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request
|
||||
return await self.send(request, auth=auth, follow_redirects=follow_redirects)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send
|
||||
response = await self._send_handling_auth(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth
|
||||
response = await self._send_handling_redirects(
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects
|
||||
response = await self._send_single_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request
|
||||
response = await transport.handle_async_request(request)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request
|
||||
with map_httpcore_exceptions():
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__
|
||||
self.gen.throw(typ, value, traceback)
|
||||
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions
|
||||
raise mapped_exc(message) from exc
|
||||
httpx.RemoteProtocolError: Server disconnected without sending a response.
|
||||
[92m18:12:01 - LiteLLM Proxy:INFO[0m: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
|
||||
[92m18:12:14 - LiteLLM Router:INFO[0m: router.py:660 - Routing strategy: simple-shuffle
|
||||
[92m18:12:16 - LiteLLM Proxy:INFO[0m: utils.py:1317 - All necessary views exist!
|
||||
[92m18:12:16 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:12:16 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:12:21 - LiteLLM Proxy:INFO[0m: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback
|
||||
[92m18:12:26 - LiteLLM Proxy:INFO[0m: ui_sso.py:495 - Starting SSO callback
|
||||
[92m18:12:26 - LiteLLM Proxy:INFO[0m: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback
|
||||
[92m18:12:26 - LiteLLM Proxy:INFO[0m: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[]
|
||||
[92m18:12:27 - LiteLLM Proxy:INFO[0m: ui_sso.py:672 - user_defined_values for creating ui key: {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': 'proxy_admin', 'budget_duration': None}
|
||||
[92m18:12:27 - LiteLLM Proxy:INFO[0m: utils.py:1856 - Data Inserted into Keys Table
|
||||
[92m18:12:27 - LiteLLM Proxy:INFO[0m: ui_sso.py:762 - user_id: krrishd; jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoia3JyaXNoZCIsImtleSI6InNrLUQzMEFpdW9lckU3YlMyakFXWVFLd1EiLCJ1c2VyX2VtYWlsIjoia3JyaXNoZGhvbGFraWFAZ21haWwuY29tIiwidXNlcl9yb2xlIjoicHJveHlfYWRtaW4iLCJsb2dpbl9tZXRob2QiOiJzc28iLCJwcmVtaXVtX3VzZXIiOnRydWUsImF1dGhfaGVhZGVyX25hbWUiOiJBdXRob3JpemF0aW9uIiwiZGlzYWJsZWRfbm9uX2FkbWluX3BlcnNvbmFsX2tleV9jcmVhdGlvbiI6ZmFsc2UsInNlcnZlcl9yb290X3BhdGgiOiIvIn0.EzYP86hw12J4WHLe6ZZz4YgVNGPnxM_PHqLjINH2_-U
|
||||
[92m18:12:27 - LiteLLM Proxy:INFO[0m: ui_sso.py:765 - Redirecting to http://localhost:4000/ui/?login=success
|
||||
[92m18:12:31 - LiteLLM Proxy:INFO[0m: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
|
||||
[92m18:15:07 - LiteLLM Router:INFO[0m: router.py:660 - Routing strategy: simple-shuffle
|
||||
[92m18:15:09 - LiteLLM Proxy:INFO[0m: utils.py:1317 - All necessary views exist!
|
||||
[92m18:15:09 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:15:09 - LiteLLM Router:WARNING[0m: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments.
|
||||
[92m18:15:17 - LiteLLM Proxy:INFO[0m: utils.py:1916 - Data Inserted into Config Table
|
||||
[92m18:15:28 - LiteLLM Proxy:INFO[0m: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback
|
||||
[92m18:15:32 - LiteLLM Proxy:INFO[0m: ui_sso.py:495 - Starting SSO callback
|
||||
[92m18:15:32 - LiteLLM Proxy:INFO[0m: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback
|
||||
[92m18:15:32 - LiteLLM Proxy:INFO[0m: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[]
|
||||
[92m18:15:37 - LiteLLM Proxy:INFO[0m: proxy_server.py:490 - Shutting down LiteLLM Proxy Server
|
||||
@@ -512,6 +512,20 @@ async def generate_key_fn( # noqa: PLR0915
|
||||
},
|
||||
)
|
||||
|
||||
# APPLY ENTERPRISE KEY MANAGEMENT PARAMS
|
||||
try:
|
||||
from litellm_enterprise.proxy.management_endpoints.key_management_endpoints import (
|
||||
apply_enterprise_key_management_params,
|
||||
)
|
||||
|
||||
data = apply_enterprise_key_management_params(data, team_table)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.info(
|
||||
"litellm.proxy.proxy_server.generate_key_fn(): Enterprise key management params not applied - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
|
||||
_budget_id = data.budget_id
|
||||
if prisma_client is not None and data.soft_budget is not None:
|
||||
@@ -536,7 +550,7 @@ async def generate_key_fn( # noqa: PLR0915
|
||||
# ADD METADATA FIELDS
|
||||
# Set Management Endpoint Metadata Fields
|
||||
for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium:
|
||||
if getattr(data, field) is not None:
|
||||
if getattr(data, field, None) is not None:
|
||||
_set_object_metadata_field(
|
||||
object_data=data,
|
||||
field_name=field,
|
||||
@@ -589,9 +603,9 @@ async def generate_key_fn( # noqa: PLR0915
|
||||
request_type="key", **data_json, table_name="key"
|
||||
)
|
||||
|
||||
response[
|
||||
"soft_budget"
|
||||
] = data.soft_budget # include the user-input soft budget in the response
|
||||
response["soft_budget"] = (
|
||||
data.soft_budget
|
||||
) # include the user-input soft budget in the response
|
||||
|
||||
response = GenerateKeyResponse(**response)
|
||||
|
||||
@@ -667,9 +681,9 @@ async def _set_object_permission(
|
||||
data=data_json["object_permission"],
|
||||
)
|
||||
)
|
||||
data_json[
|
||||
"object_permission_id"
|
||||
] = created_object_permission.object_permission_id
|
||||
data_json["object_permission_id"] = (
|
||||
created_object_permission.object_permission_id
|
||||
)
|
||||
|
||||
# delete the object_permission from the data_json
|
||||
data_json.pop("object_permission")
|
||||
@@ -1652,10 +1666,10 @@ async def delete_verification_tokens(
|
||||
try:
|
||||
if prisma_client:
|
||||
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
||||
_keys_being_deleted: List[
|
||||
LiteLLM_VerificationToken
|
||||
] = await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"token": {"in": tokens}}
|
||||
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
|
||||
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"token": {"in": tokens}}
|
||||
)
|
||||
)
|
||||
|
||||
if len(_keys_being_deleted) == 0:
|
||||
@@ -1763,9 +1777,9 @@ async def _rotate_master_key(
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
|
||||
try:
|
||||
models: Optional[
|
||||
List
|
||||
] = await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
models: Optional[List] = (
|
||||
await prisma_client.db.litellm_proxymodeltable.find_many()
|
||||
)
|
||||
except Exception:
|
||||
models = None
|
||||
# 2. process model table
|
||||
@@ -2057,11 +2071,11 @@ async def validate_key_list_check(
|
||||
param="user_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
complete_user_info_db_obj: Optional[
|
||||
BaseModel
|
||||
] = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
complete_user_info_db_obj: Optional[BaseModel] = (
|
||||
await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_api_key_dict.user_id},
|
||||
include={"organization_memberships": True},
|
||||
)
|
||||
)
|
||||
|
||||
if complete_user_info_db_obj is None:
|
||||
@@ -2147,10 +2161,10 @@ async def get_admin_team_ids(
|
||||
if complete_user_info is None:
|
||||
return []
|
||||
# Get all teams that user is an admin of
|
||||
teams: Optional[
|
||||
List[BaseModel]
|
||||
] = await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": complete_user_info.teams}}
|
||||
teams: Optional[List[BaseModel]] = (
|
||||
await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": complete_user_info.teams}}
|
||||
)
|
||||
)
|
||||
if teams is None:
|
||||
return []
|
||||
@@ -2403,12 +2417,14 @@ async def _list_key_helper(
|
||||
where=where, # type: ignore
|
||||
skip=skip, # type: ignore
|
||||
take=size, # type: ignore
|
||||
order=order_by
|
||||
if order_by
|
||||
else [
|
||||
{"created_at": "desc"},
|
||||
{"token": "desc"}, # fallback sort
|
||||
],
|
||||
order=(
|
||||
order_by
|
||||
if order_by
|
||||
else [
|
||||
{"created_at": "desc"},
|
||||
{"token": "desc"}, # fallback sort
|
||||
]
|
||||
),
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from fastapi import (
|
||||
Response,
|
||||
)
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import (
|
||||
@@ -361,6 +362,18 @@ async def create_user(
|
||||
# Create user in database
|
||||
user_id = user.userName or str(uuid.uuid4())
|
||||
metadata = _build_scim_metadata(user_data["given_name"], user_data["family_name"])
|
||||
|
||||
default_role: Optional[
|
||||
Literal[
|
||||
LitellmUserRoles.PROXY_ADMIN,
|
||||
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
|
||||
LitellmUserRoles.INTERNAL_USER,
|
||||
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
||||
]
|
||||
] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
if litellm.default_internal_user_params:
|
||||
default_role = litellm.default_internal_user_params.get("user_role")
|
||||
|
||||
new_user_request = NewUserRequest(
|
||||
user_id=user_id,
|
||||
user_email=user_data["user_email"],
|
||||
@@ -368,6 +381,7 @@ async def create_user(
|
||||
teams=user_data["teams"],
|
||||
metadata=metadata,
|
||||
auto_create_key=False,
|
||||
user_role=default_role,
|
||||
)
|
||||
|
||||
# Check if user with email already exists and update if found
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
|
||||
def check_is_admin_only_access(ui_access_mode: str) -> bool:
|
||||
def check_is_admin_only_access(ui_access_mode: Union[str, Dict]) -> bool:
|
||||
"""Checks ui access mode is admin_only"""
|
||||
return ui_access_mode == "admin_only"
|
||||
if isinstance(ui_access_mode, str):
|
||||
return ui_access_mode == "admin_only"
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def has_admin_ui_access(user_role: str) -> bool:
|
||||
|
||||
@@ -262,6 +262,7 @@ async def new_team( # noqa: PLR0915
|
||||
- guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails)
|
||||
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission.
|
||||
- team_member_budget: Optional[float] - The maximum budget allocated to an individual team member.
|
||||
- team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo"
|
||||
|
||||
Returns:
|
||||
- team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id.
|
||||
@@ -688,6 +689,7 @@ async def update_team(
|
||||
- guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails)
|
||||
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission.
|
||||
- team_member_budget: Optional[float] - The maximum budget allocated to an individual team member.
|
||||
- team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo"
|
||||
Example - update team TPM Limit
|
||||
|
||||
```
|
||||
|
||||
@@ -145,7 +145,11 @@ async def google_login(request: Request): # noqa: PLR0915
|
||||
return HTMLResponse(content=html_form, status_code=200)
|
||||
|
||||
|
||||
def generic_response_convertor(response, jwt_handler: JWTHandler):
|
||||
def generic_response_convertor(
|
||||
response,
|
||||
jwt_handler: JWTHandler,
|
||||
sso_jwt_handler: Optional[JWTHandler] = None,
|
||||
):
|
||||
generic_user_id_attribute_name = os.getenv(
|
||||
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
||||
)
|
||||
@@ -171,6 +175,13 @@ def generic_response_convertor(response, jwt_handler: JWTHandler):
|
||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}"
|
||||
)
|
||||
|
||||
all_teams = []
|
||||
if sso_jwt_handler is not None:
|
||||
team_ids = sso_jwt_handler.get_team_ids_from_jwt(cast(dict, response))
|
||||
all_teams.extend(team_ids)
|
||||
|
||||
team_ids = jwt_handler.get_team_ids_from_jwt(cast(dict, response))
|
||||
all_teams.extend(team_ids)
|
||||
return CustomOpenID(
|
||||
id=response.get(generic_user_id_attribute_name),
|
||||
display_name=response.get(generic_user_display_name_attribute_name),
|
||||
@@ -178,20 +189,24 @@ def generic_response_convertor(response, jwt_handler: JWTHandler):
|
||||
first_name=response.get(generic_user_first_name_attribute_name),
|
||||
last_name=response.get(generic_user_last_name_attribute_name),
|
||||
provider=response.get(generic_provider_attribute_name),
|
||||
team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)),
|
||||
team_ids=all_teams,
|
||||
)
|
||||
|
||||
|
||||
async def get_generic_sso_response(
|
||||
request: Request,
|
||||
jwt_handler: JWTHandler,
|
||||
sso_jwt_handler: Optional[
|
||||
JWTHandler
|
||||
], # sso specific jwt handler - used for restricted sso group access control
|
||||
generic_client_id: str,
|
||||
redirect_url: str,
|
||||
) -> Union[OpenID, dict]:
|
||||
) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
received_response: Optional[dict] = None
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None)
|
||||
@@ -242,9 +257,12 @@ async def get_generic_sso_response(
|
||||
)
|
||||
|
||||
def response_convertor(response, client):
|
||||
nonlocal received_response # return for user debugging
|
||||
received_response = response
|
||||
return generic_response_convertor(
|
||||
response=response,
|
||||
jwt_handler=jwt_handler,
|
||||
sso_jwt_handler=sso_jwt_handler,
|
||||
)
|
||||
|
||||
SSOProvider = create_provider(
|
||||
@@ -284,7 +302,7 @@ async def get_generic_sso_response(
|
||||
)
|
||||
raise e
|
||||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
return result or {}
|
||||
return result or {}, received_response
|
||||
|
||||
|
||||
async def create_team_member_add_task(team_id, user_info):
|
||||
@@ -480,6 +498,8 @@ async def check_and_update_if_proxy_admin_id(
|
||||
async def auth_callback(request: Request): # noqa: PLR0915
|
||||
"""Verify login"""
|
||||
verbose_proxy_logger.info("Starting SSO callback")
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
@@ -490,7 +510,6 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
premium_user,
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
ui_access_mode,
|
||||
user_api_key_cache,
|
||||
user_custom_sso,
|
||||
)
|
||||
@@ -502,9 +521,25 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
sso_jwt_handler: Optional[JWTHandler] = None
|
||||
ui_access_mode = general_settings.get("ui_access_mode", None)
|
||||
if ui_access_mode is not None and isinstance(ui_access_mode, dict):
|
||||
sso_jwt_handler = JWTHandler()
|
||||
sso_jwt_handler.update_environment(
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
litellm_jwtauth=LiteLLM_JWTAuth(
|
||||
team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
|
||||
"sso_group_jwt_field", None
|
||||
),
|
||||
),
|
||||
leeway=0,
|
||||
)
|
||||
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
received_response: Optional[dict] = None
|
||||
# get url from request
|
||||
if master_key is None:
|
||||
raise ProxyException(
|
||||
@@ -532,11 +567,12 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
elif generic_client_id is not None:
|
||||
result = await get_generic_sso_response(
|
||||
result, received_response = await get_generic_sso_response(
|
||||
request=request,
|
||||
jwt_handler=jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
sso_jwt_handler=sso_jwt_handler,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
@@ -547,6 +583,7 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
|
||||
# User is Authe'd in - generate key for the UI to access Proxy
|
||||
verbose_proxy_logger.info(f"SSO callback result: {result}")
|
||||
|
||||
user_email: Optional[str] = getattr(result, "email", None)
|
||||
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
||||
|
||||
@@ -612,6 +649,13 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
budget_duration=internal_user_budget_duration,
|
||||
)
|
||||
|
||||
# (IF SET) Verify user is in restricted SSO group
|
||||
SSOAuthenticationHandler.verify_user_in_restricted_sso_group(
|
||||
general_settings=general_settings,
|
||||
result=result,
|
||||
received_response=received_response,
|
||||
)
|
||||
|
||||
user_info = await get_user_info_from_db(
|
||||
result=result,
|
||||
prisma_client=prisma_client,
|
||||
@@ -1055,6 +1099,44 @@ class SSOAuthenticationHandler:
|
||||
sso_teams = getattr(result, "team_ids", [])
|
||||
await add_missing_team_member(user_info=user_info, sso_teams=sso_teams)
|
||||
|
||||
@staticmethod
|
||||
def verify_user_in_restricted_sso_group(
|
||||
general_settings: Dict,
|
||||
result: Optional[Union[CustomOpenID, OpenID, dict]],
|
||||
received_response: Optional[dict],
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
when ui_access_mode.type == "restricted_sso_group":
|
||||
|
||||
- result.team_ids should contain the restricted_sso_group
|
||||
- if not, raise a ProxyException
|
||||
- if so, return True
|
||||
- if result.team_ids is None, return False
|
||||
- if result.team_ids is an empty list, return False
|
||||
- if result.team_ids is a list, return True if the restricted_sso_group is in the list, otherwise return False
|
||||
"""
|
||||
|
||||
ui_access_mode = cast(
|
||||
Optional[Union[Dict, str]], general_settings.get("ui_access_mode")
|
||||
)
|
||||
|
||||
if ui_access_mode is None:
|
||||
return True
|
||||
if isinstance(ui_access_mode, str):
|
||||
return True
|
||||
team_ids = getattr(result, "team_ids", [])
|
||||
|
||||
if ui_access_mode.get("type") == "restricted_sso_group":
|
||||
restricted_sso_group = ui_access_mode.get("restricted_sso_group")
|
||||
if restricted_sso_group not in team_ids:
|
||||
raise ProxyException(
|
||||
message=f"User is not in the restricted SSO group: {restricted_sso_group}. User groups: {team_ids}. Received SSO response: {received_response}",
|
||||
type=ProxyErrorTypes.auth_error,
|
||||
param="restricted_sso_group",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def create_litellm_team_from_sso_group(
|
||||
litellm_team_id: str,
|
||||
@@ -1551,7 +1633,29 @@ async def debug_sso_callback(request: Request):
|
||||
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from litellm.proxy.proxy_server import jwt_handler
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
jwt_handler,
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
sso_jwt_handler: Optional[JWTHandler] = None
|
||||
ui_access_mode = general_settings.get("ui_access_mode", None)
|
||||
if ui_access_mode is not None and isinstance(ui_access_mode, dict):
|
||||
sso_jwt_handler = JWTHandler()
|
||||
sso_jwt_handler.update_environment(
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
litellm_jwtauth=LiteLLM_JWTAuth(
|
||||
team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get(
|
||||
"sso_group_jwt_field", None
|
||||
),
|
||||
),
|
||||
leeway=0,
|
||||
)
|
||||
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
@@ -1580,11 +1684,12 @@ async def debug_sso_callback(request: Request):
|
||||
)
|
||||
|
||||
elif generic_client_id is not None:
|
||||
result = await get_generic_sso_response(
|
||||
result, _ = await get_generic_sso_response(
|
||||
request=request,
|
||||
jwt_handler=jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
sso_jwt_handler=sso_jwt_handler,
|
||||
)
|
||||
|
||||
# If result is None, return a basic error message
|
||||
|
||||
@@ -905,7 +905,7 @@ health_check_results: Dict[str, Union[int, List[Dict[str, Any]]]] = {}
|
||||
queue: List = []
|
||||
litellm_proxy_budget_name = "litellm-proxy-budget"
|
||||
litellm_proxy_admin_name = LITELLM_PROXY_ADMIN_NAME
|
||||
ui_access_mode: Literal["admin", "all"] = "all"
|
||||
ui_access_mode: Union[Literal["admin", "all"], Dict] = "all"
|
||||
proxy_budget_rescheduler_min_time = PROXY_BUDGET_RESCHEDULER_MIN_TIME
|
||||
proxy_budget_rescheduler_max_time = PROXY_BUDGET_RESCHEDULER_MAX_TIME
|
||||
proxy_batch_write_at = PROXY_BATCH_WRITE_AT
|
||||
@@ -1435,11 +1435,13 @@ class ProxyConfig:
|
||||
- Do not write restricted params like 'api_key' to the database
|
||||
- if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
|
||||
"""
|
||||
|
||||
if prisma_client is not None and (
|
||||
general_settings.get("store_model_in_db", False) is True
|
||||
or store_model_in_db
|
||||
):
|
||||
# if using - db for config - models are in ModelTable
|
||||
|
||||
new_config.pop("model_list", None)
|
||||
await prisma_client.insert_data(data=new_config, table_name="config")
|
||||
else:
|
||||
@@ -2625,6 +2627,10 @@ class ProxyConfig:
|
||||
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
||||
)
|
||||
|
||||
## UI ACCESS MODE ##
|
||||
if "ui_access_mode" in _general_settings:
|
||||
general_settings["ui_access_mode"] = _general_settings["ui_access_mode"]
|
||||
|
||||
def _update_config_fields(
|
||||
self,
|
||||
current_config: dict,
|
||||
|
||||
@@ -81,7 +81,10 @@ async def route_request(
|
||||
team_id = get_team_id_from_data(data)
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
if "api_key" in data or "api_base" in data:
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
if llm_router is not None:
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
else:
|
||||
return getattr(litellm, f"{route_type}")(**data)
|
||||
|
||||
elif "user_config" in data:
|
||||
router_config = data.pop("user_config")
|
||||
|
||||
@@ -7,7 +7,10 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import DefaultTeamSSOParams, SSOConfig
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import (
|
||||
DefaultTeamSSOParams,
|
||||
SSOConfig,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -18,26 +21,29 @@ class IPAddress(BaseModel):
|
||||
|
||||
class SettingsResponse(BaseModel):
|
||||
"""Base response model for settings with values and schema information"""
|
||||
|
||||
|
||||
values: Dict[str, Any]
|
||||
"""The current configuration values"""
|
||||
|
||||
|
||||
field_schema: Dict[str, Any]
|
||||
"""Schema information including descriptions and property types for UI display"""
|
||||
|
||||
|
||||
class SSOSettingsResponse(SettingsResponse):
|
||||
"""Response model for SSO settings"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InternalUserSettingsResponse(SettingsResponse):
|
||||
"""Response model for internal user settings"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DefaultTeamSettingsResponse(SettingsResponse):
|
||||
"""Response model for default team settings"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -166,7 +172,10 @@ async def _get_settings_with_schema(
|
||||
# Add descriptions to the response
|
||||
result = {
|
||||
"values": settings_dict,
|
||||
"field_schema": {"description": schema.get("description", ""), "properties": {}},
|
||||
"field_schema": {
|
||||
"description": schema.get("description", ""),
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
# Add property descriptions
|
||||
@@ -322,20 +331,21 @@ async def get_sso_settings():
|
||||
Returns a structured object with values and descriptions for UI display.
|
||||
"""
|
||||
import os
|
||||
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
|
||||
|
||||
# Load existing config to get both environment variables and general settings
|
||||
config = await proxy_config.get_config()
|
||||
general_settings = config.get("general_settings", {}) or {}
|
||||
environment_variables = config.get("environment_variables", {}) or {}
|
||||
|
||||
|
||||
# Get user_email from general_settings
|
||||
proxy_admin_email = general_settings.get("proxy_admin_email", None)
|
||||
|
||||
|
||||
# Helper function to get env var value (first from config, then from environment)
|
||||
def get_env_value(env_var_name: str):
|
||||
return environment_variables.get(env_var_name) or os.getenv(env_var_name)
|
||||
|
||||
|
||||
# Get current environment variables for SSO
|
||||
sso_config = SSOConfig(
|
||||
google_client_id=get_env_value("GOOGLE_CLIENT_ID"),
|
||||
@@ -351,27 +361,31 @@ async def get_sso_settings():
|
||||
proxy_base_url=get_env_value("PROXY_BASE_URL"),
|
||||
user_email=proxy_admin_email, # Get from config instead of environment
|
||||
)
|
||||
|
||||
|
||||
# Get the schema for UI display
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
schema = TypeAdapter(SSOConfig).json_schema(by_alias=True)
|
||||
|
||||
|
||||
# Convert to dict for response
|
||||
sso_dict = sso_config.model_dump()
|
||||
|
||||
|
||||
# Add descriptions to the response
|
||||
result = {
|
||||
"values": sso_dict,
|
||||
"field_schema": {"description": schema.get("description", ""), "properties": {}},
|
||||
"field_schema": {
|
||||
"description": schema.get("description", ""),
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Add property descriptions
|
||||
for field_name, field_info in schema["properties"].items():
|
||||
result["field_schema"]["properties"][field_name] = {
|
||||
"description": field_info.get("description", ""),
|
||||
"type": field_info.get("type", "string"),
|
||||
}
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -384,51 +398,56 @@ async def update_sso_settings(sso_config: SSOConfig):
|
||||
"""
|
||||
Update SSO configuration by saving to both environment variables and config file.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
import os
|
||||
|
||||
|
||||
from litellm.proxy.proxy_server import proxy_config
|
||||
|
||||
# Update environment variables
|
||||
env_var_mapping = {
|
||||
'google_client_id': 'GOOGLE_CLIENT_ID',
|
||||
'google_client_secret': 'GOOGLE_CLIENT_SECRET',
|
||||
'microsoft_client_id': 'MICROSOFT_CLIENT_ID',
|
||||
'microsoft_client_secret': 'MICROSOFT_CLIENT_SECRET',
|
||||
'microsoft_tenant': 'MICROSOFT_TENANT',
|
||||
'generic_client_id': 'GENERIC_CLIENT_ID',
|
||||
'generic_client_secret': 'GENERIC_CLIENT_SECRET',
|
||||
'generic_authorization_endpoint': 'GENERIC_AUTHORIZATION_ENDPOINT',
|
||||
'generic_token_endpoint': 'GENERIC_TOKEN_ENDPOINT',
|
||||
'generic_userinfo_endpoint': 'GENERIC_USERINFO_ENDPOINT',
|
||||
'proxy_base_url': 'PROXY_BASE_URL',
|
||||
"google_client_id": "GOOGLE_CLIENT_ID",
|
||||
"google_client_secret": "GOOGLE_CLIENT_SECRET",
|
||||
"microsoft_client_id": "MICROSOFT_CLIENT_ID",
|
||||
"microsoft_client_secret": "MICROSOFT_CLIENT_SECRET",
|
||||
"microsoft_tenant": "MICROSOFT_TENANT",
|
||||
"generic_client_id": "GENERIC_CLIENT_ID",
|
||||
"generic_client_secret": "GENERIC_CLIENT_SECRET",
|
||||
"generic_authorization_endpoint": "GENERIC_AUTHORIZATION_ENDPOINT",
|
||||
"generic_token_endpoint": "GENERIC_TOKEN_ENDPOINT",
|
||||
"generic_userinfo_endpoint": "GENERIC_USERINFO_ENDPOINT",
|
||||
"proxy_base_url": "PROXY_BASE_URL",
|
||||
}
|
||||
|
||||
|
||||
# Load existing config
|
||||
config = await proxy_config.get_config()
|
||||
|
||||
|
||||
# Update config with new environment variables
|
||||
if "environment_variables" not in config:
|
||||
config["environment_variables"] = {}
|
||||
|
||||
|
||||
# Update general_settings for user_email (admin email)
|
||||
if "general_settings" not in config:
|
||||
config["general_settings"] = {}
|
||||
|
||||
|
||||
# Update environment variables in config and in memory
|
||||
sso_data = sso_config.model_dump(exclude_none=True)
|
||||
for field_name, value in sso_data.items():
|
||||
if field_name == 'user_email' and value is not None:
|
||||
|
||||
if field_name == "user_email" and value is not None:
|
||||
# Store user_email in general_settings instead of environment variables
|
||||
config["general_settings"]["proxy_admin_email"] = value
|
||||
elif field_name == "ui_access_mode" and value is not None:
|
||||
|
||||
config["general_settings"]["ui_access_mode"] = value
|
||||
elif field_name in env_var_mapping and value is not None:
|
||||
env_var_name = env_var_mapping[field_name]
|
||||
# Update in config
|
||||
config["environment_variables"][env_var_name] = value
|
||||
# Update in runtime environment
|
||||
os.environ[env_var_name] = value
|
||||
|
||||
|
||||
# Save the updated config
|
||||
await proxy_config.save_config(new_config=config)
|
||||
|
||||
|
||||
return {
|
||||
"message": "SSO settings updated successfully",
|
||||
"status": "success",
|
||||
|
||||
+18
-12
@@ -4108,20 +4108,26 @@ class Router:
|
||||
original_exception=exception
|
||||
)
|
||||
|
||||
_time_to_cooldown = kwargs.get("litellm_params", {}).get(
|
||||
"cooldown_time", self.cooldown_time
|
||||
)
|
||||
|
||||
# Determine cooldown time with priority: deployment config > response header > router default
|
||||
deployment_cooldown = kwargs.get("litellm_params", {}).get("cooldown_time", None)
|
||||
|
||||
header_cooldown = None
|
||||
if exception_headers is not None:
|
||||
_time_to_cooldown = (
|
||||
litellm.utils._get_retry_after_from_exception_header(
|
||||
response_headers=exception_headers
|
||||
)
|
||||
header_cooldown = litellm.utils._get_retry_after_from_exception_header(
|
||||
response_headers=exception_headers
|
||||
)
|
||||
|
||||
if _time_to_cooldown is None or _time_to_cooldown < 0:
|
||||
# if the response headers did not read it -> set to default cooldown time
|
||||
_time_to_cooldown = self.cooldown_time
|
||||
##############################################
|
||||
# Logic to determine cooldown time
|
||||
# 1. Check if a cooldown time is set in the deployment config
|
||||
# 2. Check if a cooldown time is set in the response header
|
||||
# 3. If no cooldown time is set, use the router default cooldown time
|
||||
##############################################
|
||||
if deployment_cooldown is not None and deployment_cooldown >= 0:
|
||||
_time_to_cooldown = deployment_cooldown
|
||||
elif header_cooldown is not None and header_cooldown >= 0:
|
||||
_time_to_cooldown = header_cooldown
|
||||
else:
|
||||
_time_to_cooldown = self.cooldown_time
|
||||
|
||||
if isinstance(_model_info, dict):
|
||||
deployment_id = _model_info.get("id", None)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Literal, Optional, TypedDict
|
||||
from typing import List, Literal, Optional, TypedDict, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -31,6 +31,14 @@ class MicrosoftServicePrincipalTeam(TypedDict, total=False):
|
||||
principalId: Optional[str]
|
||||
|
||||
|
||||
class AccessControl_UI_AccessMode(LiteLLMPydanticObjectBase):
|
||||
"""Model for Controlling UI Access Mode via SSO Groups"""
|
||||
|
||||
type: Literal["restricted_sso_group"]
|
||||
restricted_sso_group: str
|
||||
sso_group_jwt_field: str
|
||||
|
||||
|
||||
class SSOConfig(LiteLLMPydanticObjectBase):
|
||||
"""
|
||||
Configuration for SSO environment variables and settings
|
||||
@@ -45,7 +53,7 @@ class SSOConfig(LiteLLMPydanticObjectBase):
|
||||
default=None,
|
||||
description="Google OAuth Client Secret for SSO authentication",
|
||||
)
|
||||
|
||||
|
||||
# Microsoft SSO
|
||||
microsoft_client_id: Optional[str] = Field(
|
||||
default=None,
|
||||
@@ -59,7 +67,7 @@ class SSOConfig(LiteLLMPydanticObjectBase):
|
||||
default=None,
|
||||
description="Microsoft Azure Tenant ID for SSO authentication",
|
||||
)
|
||||
|
||||
|
||||
# Generic/Okta SSO
|
||||
generic_client_id: Optional[str] = Field(
|
||||
default=None,
|
||||
@@ -81,7 +89,7 @@ class SSOConfig(LiteLLMPydanticObjectBase):
|
||||
default=None,
|
||||
description="User info endpoint URL for generic OAuth provider",
|
||||
)
|
||||
|
||||
|
||||
# Common settings
|
||||
proxy_base_url: Optional[str] = Field(
|
||||
default=None,
|
||||
@@ -92,6 +100,12 @@ class SSOConfig(LiteLLMPydanticObjectBase):
|
||||
description="Email of the proxy admin user",
|
||||
)
|
||||
|
||||
# Access Mode
|
||||
ui_access_mode: Optional[Union[AccessControl_UI_AccessMode, str]] = Field(
|
||||
default=None,
|
||||
description="Access mode for the UI",
|
||||
)
|
||||
|
||||
|
||||
class DefaultTeamSSOParams(LiteLLMPydanticObjectBase):
|
||||
"""
|
||||
|
||||
@@ -85,3 +85,83 @@ class VectorStoreSearchResponse(TypedDict, total=False):
|
||||
] # Always "vector_store.search_results.page"
|
||||
search_query: Optional[str]
|
||||
data: Optional[List[VectorStoreSearchResult]]
|
||||
|
||||
class VectorStoreSearchOptionalRequestParams(TypedDict, total=False):
|
||||
"""TypedDict for Optional parameters supported by the vector store search API."""
|
||||
filters: Optional[Dict]
|
||||
max_num_results: Optional[int]
|
||||
ranking_options: Optional[Dict]
|
||||
rewrite_query: Optional[bool]
|
||||
|
||||
class VectorStoreSearchRequest(VectorStoreSearchOptionalRequestParams, total=False):
|
||||
"""Request body for searching a vector store"""
|
||||
query: Union[str, List[str]]
|
||||
|
||||
|
||||
# Vector Store Creation Types
|
||||
class VectorStoreExpirationPolicy(TypedDict, total=False):
|
||||
"""The expiration policy for a vector store"""
|
||||
anchor: Literal["last_active_at"] # Anchor timestamp after which the expiration policy applies
|
||||
days: int # Number of days after anchor time that the vector store will expire
|
||||
|
||||
|
||||
class VectorStoreAutoChunkingStrategy(TypedDict, total=False):
|
||||
"""Auto chunking strategy configuration"""
|
||||
type: Literal["auto"] # Always "auto"
|
||||
|
||||
|
||||
class VectorStoreStaticChunkingStrategyConfig(TypedDict, total=False):
|
||||
"""Static chunking strategy configuration"""
|
||||
max_chunk_size_tokens: int # Maximum number of tokens per chunk
|
||||
chunk_overlap_tokens: int # Number of tokens to overlap between chunks
|
||||
|
||||
|
||||
class VectorStoreStaticChunkingStrategy(TypedDict, total=False):
|
||||
"""Static chunking strategy"""
|
||||
type: Literal["static"] # Always "static"
|
||||
static: VectorStoreStaticChunkingStrategyConfig
|
||||
|
||||
|
||||
class VectorStoreChunkingStrategy(TypedDict, total=False):
|
||||
"""Union type for chunking strategies"""
|
||||
# This can be either auto or static
|
||||
type: Literal["auto", "static"]
|
||||
static: Optional[VectorStoreStaticChunkingStrategyConfig]
|
||||
|
||||
|
||||
class VectorStoreFileCounts(TypedDict, total=False):
|
||||
"""File counts for a vector store"""
|
||||
in_progress: int
|
||||
completed: int
|
||||
failed: int
|
||||
cancelled: int
|
||||
total: int
|
||||
|
||||
|
||||
class VectorStoreCreateOptionalRequestParams(TypedDict, total=False):
|
||||
"""TypedDict for Optional parameters supported by the vector store create API."""
|
||||
name: Optional[str] # Name of the vector store
|
||||
file_ids: Optional[List[str]] # List of File IDs that the vector store should use
|
||||
expires_after: Optional[VectorStoreExpirationPolicy] # Expiration policy for the vector store
|
||||
chunking_strategy: Optional[VectorStoreChunkingStrategy] # Chunking strategy for the files
|
||||
metadata: Optional[Dict[str, str]] # Set of key-value pairs for metadata
|
||||
|
||||
|
||||
class VectorStoreCreateRequest(VectorStoreCreateOptionalRequestParams, total=False):
|
||||
"""Request body for creating a vector store"""
|
||||
pass # All fields are optional for vector store creation
|
||||
|
||||
|
||||
class VectorStoreCreateResponse(TypedDict, total=False):
|
||||
"""Response after creating a vector store"""
|
||||
id: str # ID of the vector store
|
||||
object: Literal["vector_store"] # Always "vector_store"
|
||||
created_at: int # Unix timestamp of when the vector store was created
|
||||
name: Optional[str] # Name of the vector store
|
||||
bytes: int # Size of the vector store in bytes
|
||||
file_counts: VectorStoreFileCounts # File counts for the vector store
|
||||
status: Literal["expired", "in_progress", "completed"] # Status of the vector store
|
||||
expires_after: Optional[VectorStoreExpirationPolicy] # Expiration policy
|
||||
expires_at: Optional[int] # Unix timestamp of when the vector store expires
|
||||
last_active_at: Optional[int] # Unix timestamp of when the vector store was last active
|
||||
metadata: Optional[Dict[str, str]] # Metadata associated with the vector store
|
||||
+25
-2
@@ -78,7 +78,9 @@ from litellm.constants import (
|
||||
)
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.vector_stores.base_vector_store import BaseVectorStore
|
||||
from litellm.integrations.vector_store_integrations.base_vector_store import (
|
||||
BaseVectorStore,
|
||||
)
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
map_finish_reason,
|
||||
process_response_headers,
|
||||
@@ -242,6 +244,7 @@ from litellm.llms.base_llm.image_variations.transformation import (
|
||||
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
|
||||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
|
||||
from ._logging import _is_debugging_on, verbose_logger
|
||||
from .caching.caching import (
|
||||
@@ -6902,13 +6905,33 @@ class ProviderConfigManager:
|
||||
def get_provider_vector_store_config(
|
||||
provider: LlmProviders,
|
||||
) -> Optional[CustomLogger]:
|
||||
from litellm.integrations.vector_stores.bedrock_vector_store import (
|
||||
from litellm.integrations.vector_store_integrations.bedrock_vector_store import (
|
||||
BedrockVectorStore,
|
||||
)
|
||||
|
||||
if LlmProviders.BEDROCK == provider:
|
||||
return BedrockVectorStore.get_initialized_custom_logger()
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_provider_vector_stores_config(
|
||||
provider: LlmProviders,
|
||||
) -> Optional[BaseVectorStoreConfig]:
|
||||
"""
|
||||
v2 vector store config, use this for new vector store integrations
|
||||
"""
|
||||
if litellm.LlmProviders.OPENAI == provider:
|
||||
from litellm.llms.openai.vector_stores.transformation import (
|
||||
OpenAIVectorStoreConfig,
|
||||
)
|
||||
return OpenAIVectorStoreConfig()
|
||||
elif litellm.LlmProviders.AZURE == provider:
|
||||
from litellm.llms.azure.vector_stores.transformation import (
|
||||
AzureOpenAIVectorStoreConfig,
|
||||
)
|
||||
return AzureOpenAIVectorStoreConfig()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_provider_image_generation_config(
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from .main import acreate, asearch, create, search
|
||||
from .vector_store_registry import VectorStoreRegistry
|
||||
|
||||
__all__ = ["search", "asearch", "create", "acreate", "VectorStoreRegistry"]
|
||||
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
LiteLLM SDK Functions for Creating and Searching Vector Stores
|
||||
"""
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.constants import request_timeout
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreFileCounts,
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
from litellm.utils import ProviderConfigManager, client
|
||||
from litellm.vector_stores.utils import VectorStoreRequestUtils
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
#################################################
|
||||
|
||||
|
||||
def mock_vector_store_search_response(
|
||||
mock_results: Optional[List[VectorStoreSearchResult]] = None,
|
||||
):
|
||||
"""Mock response for vector store search"""
|
||||
if mock_results is None:
|
||||
mock_results = [
|
||||
VectorStoreSearchResult(
|
||||
score=0.95,
|
||||
content=[
|
||||
VectorStoreResultContent(
|
||||
text="This is a sample search result from the vector store.",
|
||||
type="text"
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
return VectorStoreSearchResponse(
|
||||
object="vector_store.search_results.page",
|
||||
search_query="sample query",
|
||||
data=mock_results,
|
||||
)
|
||||
|
||||
|
||||
def mock_vector_store_create_response(
|
||||
mock_response: Optional[VectorStoreCreateResponse] = None,
|
||||
):
|
||||
"""Mock response for vector store create"""
|
||||
if mock_response is None:
|
||||
mock_response = VectorStoreCreateResponse(
|
||||
id="vs_mock123",
|
||||
object="vector_store",
|
||||
created_at=1699061776,
|
||||
name="Mock Vector Store",
|
||||
bytes=0,
|
||||
file_counts=VectorStoreFileCounts(
|
||||
in_progress=0,
|
||||
completed=0,
|
||||
failed=0,
|
||||
cancelled=0,
|
||||
total=0,
|
||||
),
|
||||
status="completed",
|
||||
expires_after=None,
|
||||
expires_at=None,
|
||||
last_active_at=None,
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
return mock_response
|
||||
|
||||
|
||||
@client
|
||||
async def acreate(
|
||||
name: Optional[str] = None,
|
||||
file_ids: Optional[List[str]] = None,
|
||||
expires_after: Optional[Dict] = None,
|
||||
chunking_strategy: Optional[Dict] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
# LiteLLM specific params,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> VectorStoreCreateResponse:
|
||||
"""
|
||||
Async: Create a vector store.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["acreate"] = True
|
||||
|
||||
# get custom llm provider so we can use this for mapping exceptions
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = "openai" # Default to OpenAI for vector stores
|
||||
|
||||
func = partial(
|
||||
create,
|
||||
name=name,
|
||||
file_ids=file_ids,
|
||||
expires_after=expires_after,
|
||||
chunking_strategy=chunking_strategy,
|
||||
metadata=metadata,
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=None,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
def create(
|
||||
name: Optional[str] = None,
|
||||
file_ids: Optional[List[str]] = None,
|
||||
expires_after: Optional[Dict] = None,
|
||||
chunking_strategy: Optional[Dict] = None,
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
# LiteLLM specific params,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
|
||||
"""
|
||||
Create a vector store.
|
||||
|
||||
Args:
|
||||
name: The name of the vector store.
|
||||
file_ids: A list of File IDs that the vector store should use.
|
||||
expires_after: The expiration policy for the vector store.
|
||||
chunking_strategy: The chunking strategy used to chunk the file(s).
|
||||
metadata: Set of 16 key-value pairs that can be attached to an object.
|
||||
|
||||
Returns:
|
||||
VectorStoreCreateResponse containing the created vector store details.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
|
||||
_is_async = kwargs.pop("acreate", False) is True
|
||||
|
||||
# get llm provider logic
|
||||
litellm_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
## MOCK RESPONSE LOGIC
|
||||
if litellm_params.mock_response and isinstance(
|
||||
litellm_params.mock_response, dict
|
||||
):
|
||||
return mock_vector_store_create_response(
|
||||
mock_response=VectorStoreCreateResponse(**litellm_params.mock_response)
|
||||
)
|
||||
|
||||
# Default to OpenAI for vector stores
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = "openai"
|
||||
|
||||
# get provider config - using vector store custom logger for now
|
||||
vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
)
|
||||
|
||||
if vector_store_provider_config is None:
|
||||
raise ValueError(
|
||||
f"Vector store create is not supported for {custom_llm_provider}"
|
||||
)
|
||||
|
||||
local_vars.update(kwargs)
|
||||
|
||||
# Get VectorStoreCreateOptionalRequestParams with only valid parameters
|
||||
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams = (
|
||||
VectorStoreRequestUtils.get_requested_vector_store_create_optional_param(
|
||||
local_vars
|
||||
)
|
||||
)
|
||||
|
||||
# Pre Call logging
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
optional_params={
|
||||
"name": name,
|
||||
**vector_store_create_optional_params,
|
||||
},
|
||||
litellm_params={
|
||||
"litellm_call_id": litellm_call_id,
|
||||
},
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
response = base_llm_http_handler.vector_store_create_handler(
|
||||
vector_store_create_optional_params=vector_store_create_optional_params,
|
||||
vector_store_provider_config=vector_store_provider_config,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout or request_timeout,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client"),
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=None,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
async def asearch(
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
filters: Optional[Dict] = None,
|
||||
max_num_results: Optional[int] = None,
|
||||
ranking_options: Optional[Dict] = None,
|
||||
rewrite_query: Optional[bool] = None,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
# LiteLLM specific params,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> VectorStoreSearchResponse:
|
||||
"""
|
||||
Async: Search a vector store for relevant chunks based on a query and file attributes filter.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["asearch"] = True
|
||||
|
||||
# get custom llm provider so we can use this for mapping exceptions
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = "openai" # Default to OpenAI for vector stores
|
||||
|
||||
func = partial(
|
||||
search,
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=None,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
def search(
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
filters: Optional[Dict] = None,
|
||||
max_num_results: Optional[int] = None,
|
||||
ranking_options: Optional[Dict] = None,
|
||||
rewrite_query: Optional[bool] = None,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
# LiteLLM specific params,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
|
||||
"""
|
||||
Search a vector store for relevant chunks based on a query and file attributes filter.
|
||||
|
||||
Args:
|
||||
vector_store_id: The ID of the vector store to search.
|
||||
query: A query string or array for the search.
|
||||
filters: Optional filter to apply based on file attributes.
|
||||
max_num_results: Maximum number of results to return (1-50, default 10).
|
||||
ranking_options: Optional ranking options for search.
|
||||
rewrite_query: Whether to rewrite the natural language query for vector search.
|
||||
|
||||
Returns:
|
||||
VectorStoreSearchResponse containing the search results.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
|
||||
_is_async = kwargs.pop("asearch", False) is True
|
||||
|
||||
# get llm provider logic
|
||||
litellm_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
## MOCK RESPONSE LOGIC
|
||||
if litellm_params.mock_response and isinstance(
|
||||
litellm_params.mock_response, (str, list)
|
||||
):
|
||||
mock_results = None
|
||||
if isinstance(litellm_params.mock_response, list):
|
||||
mock_results = litellm_params.mock_response
|
||||
return mock_vector_store_search_response(mock_results=mock_results)
|
||||
|
||||
# Default to OpenAI for vector stores
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = "openai"
|
||||
|
||||
# get provider config - using vector store custom logger for now
|
||||
vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
)
|
||||
|
||||
if vector_store_provider_config is None:
|
||||
raise ValueError(
|
||||
f"Vector store search is not supported for {custom_llm_provider}"
|
||||
)
|
||||
|
||||
local_vars.update(kwargs)
|
||||
|
||||
# Get VectorStoreSearchOptionalRequestParams with only valid parameters
|
||||
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams = (
|
||||
VectorStoreRequestUtils.get_requested_vector_store_search_optional_param(
|
||||
local_vars
|
||||
)
|
||||
)
|
||||
|
||||
# Pre Call logging
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
optional_params={
|
||||
"vector_store_id": vector_store_id,
|
||||
"query": query,
|
||||
**vector_store_search_optional_params,
|
||||
},
|
||||
litellm_params={
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"vector_store_id": vector_store_id,
|
||||
},
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
response = base_llm_http_handler.vector_store_search_handler(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
vector_store_search_optional_params=vector_store_search_optional_params,
|
||||
vector_store_provider_config=vector_store_provider_config,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout or request_timeout,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client"),
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=None,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
from typing import Any, Dict, cast, get_type_hints
|
||||
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
)
|
||||
|
||||
|
||||
class VectorStoreRequestUtils:
|
||||
"""Helper utils for constructing Vector Store search requests"""
|
||||
|
||||
@staticmethod
|
||||
def get_requested_vector_store_search_optional_param(
|
||||
params: Dict[str, Any],
|
||||
) -> VectorStoreSearchOptionalRequestParams:
|
||||
"""
|
||||
Filter parameters to only include those defined in VectorStoreSearchOptionalRequestParams.
|
||||
|
||||
Args:
|
||||
params: Dictionary of parameters to filter
|
||||
|
||||
Returns:
|
||||
VectorStoreSearchOptionalRequestParams instance with only the valid parameters
|
||||
"""
|
||||
valid_keys = get_type_hints(VectorStoreSearchOptionalRequestParams).keys()
|
||||
filtered_params = {
|
||||
k: v for k, v in params.items() if k in valid_keys and v is not None
|
||||
}
|
||||
|
||||
return cast(VectorStoreSearchOptionalRequestParams, filtered_params)
|
||||
|
||||
@staticmethod
|
||||
def get_requested_vector_store_create_optional_param(
|
||||
params: Dict[str, Any],
|
||||
) -> VectorStoreCreateOptionalRequestParams:
|
||||
"""
|
||||
Filter parameters to only include those defined in VectorStoreCreateOptionalRequestParams.
|
||||
|
||||
Args:
|
||||
params: Dictionary of parameters to filter
|
||||
|
||||
Returns:
|
||||
VectorStoreCreateOptionalRequestParams instance with only the valid parameters
|
||||
"""
|
||||
valid_keys = get_type_hints(VectorStoreCreateOptionalRequestParams).keys()
|
||||
filtered_params = {
|
||||
k: v for k, v in params.items() if k in valid_keys and v is not None
|
||||
}
|
||||
|
||||
return cast(VectorStoreCreateOptionalRequestParams, filtered_params)
|
||||
|
||||
@@ -4046,7 +4046,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-small": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4058,7 +4059,8 @@
|
||||
"supports_function_calling": true,
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-small-latest": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4070,7 +4072,8 @@
|
||||
"supports_function_calling": true,
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4081,7 +4084,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium-latest": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4093,7 +4097,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium-2505": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4105,7 +4110,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-medium-2312": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4116,7 +4122,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-latest": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4128,7 +4135,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-2411": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4140,7 +4148,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-2402": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4152,7 +4161,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-large-2407": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4164,7 +4174,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/pixtral-large-latest": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4177,7 +4188,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_vision": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/pixtral-large-2411": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4190,7 +4202,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_vision": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/pixtral-12b-2409": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4203,7 +4216,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_vision": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mistral-7b": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4214,7 +4228,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mixtral-8x7b": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4226,7 +4241,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mixtral-8x22b": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4238,7 +4254,8 @@
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/codestral-latest": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4249,7 +4266,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/codestral-2405": {
|
||||
"max_tokens": 8191,
|
||||
@@ -4260,7 +4278,8 @@
|
||||
"litellm_provider": "mistral",
|
||||
"mode": "chat",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mistral-nemo": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4272,7 +4291,8 @@
|
||||
"mode": "chat",
|
||||
"source": "https://mistral.ai/technology/",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-mistral-nemo-2407": {
|
||||
"max_tokens": 128000,
|
||||
@@ -4284,7 +4304,8 @@
|
||||
"mode": "chat",
|
||||
"source": "https://mistral.ai/technology/",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/open-codestral-mamba": {
|
||||
"max_tokens": 256000,
|
||||
@@ -4321,7 +4342,8 @@
|
||||
"source": "https://mistral.ai/news/devstral",
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-medium-latest": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4335,7 +4357,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-medium-2506": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4349,7 +4372,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-small-latest": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4363,7 +4387,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/magistral-small-2506": {
|
||||
"max_tokens": 40000,
|
||||
@@ -4377,7 +4402,8 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"mistral/mistral-embed": {
|
||||
"max_tokens": 8192,
|
||||
|
||||
+8
-12
@@ -1,24 +1,20 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
gemini_model_cost_map = json.load(open("model_prices_and_context_window.json"))
|
||||
mistral_model_cost_map = json.load(open("model_prices_and_context_window.json"))
|
||||
|
||||
for model, model_info in gemini_model_cost_map.items():
|
||||
for model, model_info in mistral_model_cost_map.items():
|
||||
if (
|
||||
(
|
||||
model_info.get("litellm_provider") == "gemini"
|
||||
or model_info.get("litellm_provider") == "vertex_ai-language-models"
|
||||
)
|
||||
(model_info.get("litellm_provider") == "mistral")
|
||||
and model_info.get("mode") == "chat"
|
||||
and ("gemini-2.5" in model and "tts" not in model)
|
||||
and model_info.get("supports_pdf_input") is None
|
||||
and ("codestral-mamba" not in model)
|
||||
):
|
||||
"""
|
||||
Update all gemini chat models to support pdf input
|
||||
Update all mistral models to supports_response_schema
|
||||
"""
|
||||
model_info["supports_pdf_input"] = True
|
||||
print(f"Updated {model} to support pdf input")
|
||||
model_info["supports_response_schema"] = True
|
||||
print(f"Updated {model} to support response schema")
|
||||
|
||||
json.dump(
|
||||
gemini_model_cost_map, open("model_prices_and_context_window.json", "w"), indent=4
|
||||
mistral_model_cost_map, open("model_prices_and_context_window.json", "w"), indent=4
|
||||
)
|
||||
|
||||
@@ -594,6 +594,7 @@ async def test_azure_embedding_max_retries_0(
|
||||
def test_azure_safety_result():
|
||||
"""Bubble up safety result from Azure OpenAI"""
|
||||
from litellm import completion
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
response = completion(
|
||||
@@ -602,4 +603,30 @@ def test_azure_safety_result():
|
||||
)
|
||||
print(f"response: {response}")
|
||||
assert response.choices[0].message.content is not None
|
||||
assert response.choices[0].provider_specific_fields is not None
|
||||
assert response.choices[0].provider_specific_fields is not None
|
||||
|
||||
|
||||
def test_azure_openai_responses_bridge():
|
||||
from litellm import completion
|
||||
import litellm
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
with patch.object(litellm, "responses") as mock_responses:
|
||||
try:
|
||||
response = completion(
|
||||
model="azure/responses/test-azure-computer-use-preview",
|
||||
messages=[{"role": "user", "content": "Hello world"}],
|
||||
api_base=os.getenv("AZURE_COMPUTER_USE_API_BASE"),
|
||||
api_version="2025-04-01-preview",
|
||||
api_key=os.getenv("AZURE_COMPUTER_USE_API_KEY"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_responses.assert_called_once()
|
||||
assert (
|
||||
mock_responses.call_args.kwargs["model"]
|
||||
== "test-azure-computer-use-preview"
|
||||
)
|
||||
assert mock_responses.call_args.kwargs["custom_llm_provider"] == "azure"
|
||||
|
||||
@@ -19,7 +19,7 @@ import pytest
|
||||
import litellm
|
||||
from litellm import completion
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore
|
||||
from litellm.integrations.vector_store_integrations.bedrock_vector_store import BedrockVectorStore
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import StandardLoggingPayload, StandardLoggingVectorStoreRequest
|
||||
|
||||
@@ -446,6 +446,40 @@ async def test_deployment_callback_on_failure(model_list):
|
||||
)
|
||||
|
||||
|
||||
def test_deployment_callback_respects_cooldown_time(model_list):
|
||||
"""Ensure per-model cooldown_time is honored even when exception headers are present."""
|
||||
import httpx
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
router = Router(model_list=model_list)
|
||||
|
||||
class FakeException(Exception):
|
||||
def __init__(self):
|
||||
self.status_code = 429
|
||||
self.headers = httpx.Headers({"x-test": "1"})
|
||||
|
||||
kwargs = {
|
||||
"exception": FakeException(),
|
||||
"litellm_params": {
|
||||
"metadata": {"model_group": "gpt-3.5-turbo"},
|
||||
"model_info": {"id": 100},
|
||||
"cooldown_time": 0,
|
||||
},
|
||||
}
|
||||
|
||||
with patch("litellm.router._set_cooldown_deployments") as mock_set:
|
||||
router.deployment_callback_on_failure(
|
||||
kwargs=kwargs,
|
||||
completion_response=None,
|
||||
start_time=time.time(),
|
||||
end_time=time.time(),
|
||||
)
|
||||
|
||||
mock_set.assert_called_once()
|
||||
assert mock_set.call_args.kwargs["time_to_cooldown"] == 0
|
||||
|
||||
|
||||
def test_log_retry(model_list):
|
||||
"""Test if the '_log_retry' function is working correctly"""
|
||||
import time
|
||||
|
||||
+30
@@ -34,3 +34,33 @@ def test_anthropic_experimental_pass_through_messages_handler():
|
||||
print(f"Error: {e}")
|
||||
mock_completion.assert_called_once()
|
||||
mock_completion.call_args.kwargs["api_key"] == "test-api-key"
|
||||
|
||||
|
||||
def test_anthropic_experimental_pass_through_messages_handler_custom_llm_provider():
|
||||
"""
|
||||
Test that litellm.completion is called when a custom LLM provider is given
|
||||
"""
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
|
||||
anthropic_messages_handler,
|
||||
)
|
||||
|
||||
with patch("litellm.completion", return_value="test-response") as mock_completion:
|
||||
try:
|
||||
anthropic_messages_handler(
|
||||
max_tokens=100,
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
model="my-custom-model",
|
||||
custom_llm_provider="my-custom-llm",
|
||||
api_key="test-api-key",
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
# Assert that litellm.completion was called when using a custom LLM provider
|
||||
mock_completion.assert_called_once()
|
||||
|
||||
# Verify that the custom provider was passed through
|
||||
call_kwargs = mock_completion.call_args.kwargs
|
||||
assert call_kwargs["custom_llm_provider"] == "my-custom-llm"
|
||||
assert call_kwargs["model"] == "my-custom-llm/my-custom-model"
|
||||
assert call_kwargs["api_key"] == "test-api-key"
|
||||
|
||||
@@ -54,9 +54,9 @@ def test_validate_environment_azure_key_within_litellm():
|
||||
def test_validate_environment_azure_openai_api_key_within_secret_str():
|
||||
azure_openai_responses_apiconfig = AzureOpenAIResponsesAPIConfig()
|
||||
|
||||
with patch(
|
||||
"litellm.llms.azure.responses.transformation.get_secret_str"
|
||||
) as mock_get_secret_str:
|
||||
with patch("litellm.api_key", None), \
|
||||
patch("litellm.azure_key", None), \
|
||||
patch("litellm.llms.azure.common_utils.get_secret_str") as mock_get_secret_str:
|
||||
# Configure the mock to return "test-api-key" when called with "AZURE_OPENAI_API_KEY"
|
||||
mock_get_secret_str.side_effect = (
|
||||
lambda key: "test-api-key" if key == "AZURE_OPENAI_API_KEY" else None
|
||||
@@ -74,13 +74,19 @@ def test_validate_environment_azure_openai_api_key_within_secret_str():
|
||||
def test_validate_environment_azure_api_key_within_secret_str():
|
||||
azure_openai_responses_apiconfig = AzureOpenAIResponsesAPIConfig()
|
||||
|
||||
with patch(
|
||||
"litellm.llms.azure.responses.transformation.get_secret_str"
|
||||
) as mock_get_secret_str:
|
||||
# Configure the mock to return "test-api-key" when called with "AZURE_API_KEY"
|
||||
mock_get_secret_str.side_effect = (
|
||||
lambda key: "test-api-key" if key == "AZURE_API_KEY" else None
|
||||
)
|
||||
with patch("litellm.api_key", None), \
|
||||
patch("litellm.azure_key", None), \
|
||||
patch("litellm.llms.azure.common_utils.get_secret_str") as mock_get_secret_str:
|
||||
# Configure the mock to return None for "AZURE_OPENAI_API_KEY" and "test-api-key" for "AZURE_API_KEY"
|
||||
def mock_side_effect(key):
|
||||
if key == "AZURE_OPENAI_API_KEY":
|
||||
return None
|
||||
elif key == "AZURE_API_KEY":
|
||||
return "test-api-key"
|
||||
else:
|
||||
return None
|
||||
|
||||
mock_get_secret_str.side_effect = mock_side_effect
|
||||
|
||||
litellm_params = GenericLiteLLMParams()
|
||||
result = azure_openai_responses_apiconfig.validate_environment(
|
||||
|
||||
@@ -95,10 +95,6 @@ class TestMistralReasoningSupport:
|
||||
def test_get_mistral_reasoning_system_prompt(self):
|
||||
"""Test that the reasoning system prompt is properly formatted."""
|
||||
prompt = MistralConfig._get_mistral_reasoning_system_prompt()
|
||||
|
||||
assert "<think>" in prompt
|
||||
assert "</think>" in prompt
|
||||
assert "step-by-step" in prompt
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 50 # Ensure it's not empty
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from litellm.llms.ollama.completion.handler import ollama_embeddings, ollama_aembeddings
|
||||
import pytest
|
||||
|
||||
from litellm.llms.ollama.completion.handler import ollama_aembeddings, ollama_embeddings
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -26,8 +27,9 @@ def mock_encoding():
|
||||
|
||||
|
||||
def test_ollama_embeddings(mock_response_data, mock_embedding_response, mock_encoding):
|
||||
with patch("litellm.module_level_client.post") as mock_post, \
|
||||
patch("litellm.OllamaConfig.get_config", return_value={"truncate": 512}):
|
||||
with patch("litellm.module_level_client.post") as mock_post, patch(
|
||||
"litellm.OllamaConfig.get_config", return_value={"truncate": 512}
|
||||
):
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_response_data
|
||||
@@ -50,11 +52,17 @@ def test_ollama_embeddings(mock_response_data, mock_embedding_response, mock_enc
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_aembeddings(mock_response_data, mock_embedding_response, mock_encoding):
|
||||
with patch("litellm.module_level_aclient.post", new_callable=AsyncMock) as mock_post, \
|
||||
patch("litellm.OllamaConfig.get_config", return_value={"truncate": 512}):
|
||||
|
||||
mock_post.return_value.json.return_value = mock_response_data
|
||||
async def test_ollama_aembeddings(
|
||||
mock_response_data, mock_embedding_response, mock_encoding
|
||||
):
|
||||
mock_response = AsyncMock()
|
||||
# Make json() a regular synchronous method, not async
|
||||
mock_response.json = MagicMock(return_value=mock_response_data)
|
||||
with patch(
|
||||
"litellm.module_level_aclient.post", return_value=mock_response
|
||||
) as mock_post, patch(
|
||||
"litellm.OllamaConfig.get_config", return_value={"truncate": 512}
|
||||
):
|
||||
|
||||
response = await ollama_aembeddings(
|
||||
api_base="http://localhost:11434",
|
||||
@@ -78,9 +86,10 @@ def test_prompt_eval_fallback_when_missing(mock_embedding_response, mock_encodin
|
||||
# No "prompt_eval_count"
|
||||
}
|
||||
|
||||
with patch("litellm.module_level_client.post") as mock_post, \
|
||||
patch("litellm.OllamaConfig.get_config", return_value={}):
|
||||
|
||||
with patch("litellm.module_level_client.post") as mock_post, patch(
|
||||
"litellm.OllamaConfig.get_config", return_value={}
|
||||
):
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_data
|
||||
mock_post.return_value = mock_response
|
||||
@@ -99,4 +108,4 @@ def test_prompt_eval_fallback_when_missing(mock_embedding_response, mock_encodin
|
||||
assert response.usage.prompt_tokens == 5
|
||||
assert response.usage.total_tokens == 5
|
||||
assert response.usage.completion_tokens == 0
|
||||
assert response.data[0]['embedding'] == [0.1, 0.2, 0.3]
|
||||
assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -23,7 +24,9 @@ class TestVolcEngineConfig:
|
||||
)
|
||||
|
||||
assert mapped_params == {
|
||||
"thinking": {"type": "disabled"},
|
||||
"extra_body": {
|
||||
"thinking": {"type": "disabled"},
|
||||
}
|
||||
}
|
||||
|
||||
e2e_mapped_params = get_optional_params(
|
||||
@@ -33,6 +36,48 @@ class TestVolcEngineConfig:
|
||||
drop_params=False,
|
||||
)
|
||||
|
||||
assert "thinking" in e2e_mapped_params and e2e_mapped_params["thinking"] == {
|
||||
assert "thinking" in e2e_mapped_params["extra_body"] and e2e_mapped_params[
|
||||
"extra_body"
|
||||
]["thinking"] == {
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
def test_e2e_completion(self):
|
||||
from openai import OpenAI
|
||||
|
||||
from litellm import completion
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
client = OpenAI(api_key="test_api_key")
|
||||
|
||||
mock_raw_response = MagicMock()
|
||||
mock_raw_response.headers = {
|
||||
"x-request-id": "123",
|
||||
"openai-organization": "org-123",
|
||||
"x-ratelimit-limit-requests": "100",
|
||||
"x-ratelimit-remaining-requests": "99",
|
||||
}
|
||||
mock_raw_response.parse.return_value = ModelResponse()
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create", mock_raw_response
|
||||
) as mock_create:
|
||||
completion(
|
||||
model="volcengine/doubao-seed-1.6",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "**Tell me your model detail information.**",
|
||||
}
|
||||
],
|
||||
user="guest",
|
||||
stream=True,
|
||||
thinking={"type": "disabled"},
|
||||
client=client,
|
||||
)
|
||||
|
||||
mock_create.assert_called_once()
|
||||
print(mock_create.call_args.kwargs)
|
||||
assert mock_create.call_args.kwargs["extra_body"] == {
|
||||
"thinking": {"type": "disabled"},
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.proxy._types import NewUserRequest, ProxyException
|
||||
from litellm.proxy._types import LitellmUserRoles, NewUserRequest, ProxyException
|
||||
from litellm.proxy.management_endpoints.scim.scim_v2 import (
|
||||
UserProvisionerHelpers,
|
||||
_handle_team_membership_changes,
|
||||
@@ -58,6 +58,94 @@ async def test_create_user_existing_user_conflict(mocker):
|
||||
mocked_new_user.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_defaults_to_viewer(mocker, monkeypatch):
|
||||
"""If no role provided, new user should default to viewer"""
|
||||
|
||||
scim_user = SCIMUser(
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
userName="new-user",
|
||||
name=SCIMUserName(familyName="User", givenName="New"),
|
||||
emails=[SCIMUserEmail(value="new@example.com")],
|
||||
)
|
||||
|
||||
mock_prisma_client = mocker.MagicMock()
|
||||
mock_prisma_client.db = mocker.MagicMock()
|
||||
mock_prisma_client.db.litellm_usertable = mocker.MagicMock()
|
||||
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
|
||||
mock_prisma_client.db.litellm_usertable.find_first = AsyncMock(return_value=None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"litellm.default_internal_user_params", None, raising=False
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"litellm.proxy.management_endpoints.scim.scim_v2._get_prisma_client_or_raise_exception",
|
||||
AsyncMock(return_value=mock_prisma_client),
|
||||
)
|
||||
|
||||
new_user_mock = mocker.patch(
|
||||
"litellm.proxy.management_endpoints.scim.scim_v2.new_user",
|
||||
AsyncMock(return_value=NewUserRequest(user_id="new-user")),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"litellm.proxy.management_endpoints.scim.scim_v2.ScimTransformations.transform_litellm_user_to_scim_user",
|
||||
AsyncMock(return_value=scim_user),
|
||||
)
|
||||
|
||||
await create_user(user=scim_user)
|
||||
|
||||
called_args = new_user_mock.call_args.kwargs["data"]
|
||||
assert called_args.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_uses_default_internal_user_params_role(mocker, monkeypatch):
|
||||
"""If role is set in default_internal_user_params, new user should use that role"""
|
||||
|
||||
scim_user = SCIMUser(
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
userName="new-user",
|
||||
name=SCIMUserName(familyName="User", givenName="New"),
|
||||
emails=[SCIMUserEmail(value="new@example.com")],
|
||||
)
|
||||
|
||||
mock_prisma_client = mocker.MagicMock()
|
||||
mock_prisma_client.db = mocker.MagicMock()
|
||||
mock_prisma_client.db.litellm_usertable = mocker.MagicMock()
|
||||
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None)
|
||||
mock_prisma_client.db.litellm_usertable.find_first = AsyncMock(return_value=None)
|
||||
|
||||
# Set default_internal_user_params with a specific role
|
||||
default_params = {
|
||||
"user_role": LitellmUserRoles.PROXY_ADMIN,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
"litellm.default_internal_user_params", default_params, raising=False
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"litellm.proxy.management_endpoints.scim.scim_v2._get_prisma_client_or_raise_exception",
|
||||
AsyncMock(return_value=mock_prisma_client),
|
||||
)
|
||||
|
||||
new_user_mock = mocker.patch(
|
||||
"litellm.proxy.management_endpoints.scim.scim_v2.new_user",
|
||||
AsyncMock(return_value=NewUserRequest(user_id="new-user")),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"litellm.proxy.management_endpoints.scim.scim_v2.ScimTransformations.transform_litellm_user_to_scim_user",
|
||||
AsyncMock(return_value=scim_user),
|
||||
)
|
||||
|
||||
await create_user(user=scim_user)
|
||||
|
||||
called_args = new_user_mock.call_args.kwargs["data"]
|
||||
assert called_args.user_role == LitellmUserRoles.PROXY_ADMIN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_existing_user_by_email_no_email(mocker):
|
||||
"""Should return None when new_user_request has no email"""
|
||||
|
||||
@@ -696,11 +696,12 @@ async def test_get_generic_sso_response_with_additional_headers():
|
||||
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
|
||||
) as mock_create_provider:
|
||||
# Act
|
||||
result = await get_generic_sso_response(
|
||||
result, received_response = await get_generic_sso_response(
|
||||
request=mock_request,
|
||||
jwt_handler=mock_jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
sso_jwt_handler=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
@@ -756,11 +757,12 @@ async def test_get_generic_sso_response_with_empty_headers():
|
||||
"fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class
|
||||
) as mock_create_provider:
|
||||
# Act
|
||||
result = await get_generic_sso_response(
|
||||
result, received_response = await get_generic_sso_response(
|
||||
request=mock_request,
|
||||
jwt_handler=mock_jwt_handler,
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
sso_jwt_handler=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
||||
@@ -102,3 +102,25 @@ async def test_route_request_no_model_required_with_router_settings():
|
||||
|
||||
# Reset the mock for the next route
|
||||
llm_router.reset_mock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_request_no_model_required_with_router_settings_and_no_router():
|
||||
"""Test route types that don't require model parameter with router settings and no router"""
|
||||
from unittest.mock import patch
|
||||
|
||||
import litellm
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
|
||||
data = {
|
||||
"model": "my-model-id",
|
||||
"api_key": "my-api-key",
|
||||
"messages": [{"role": "user", "content": "what llm are you"}],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
litellm, "acompletion", return_value="fake_response"
|
||||
) as mock_completion:
|
||||
response = await route_request(data, None, "gpt-3.5-turbo", "acompletion")
|
||||
|
||||
mock_completion.assert_called_once_with(**data)
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
import httpx
|
||||
import json
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
import base64
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from abc import ABC, abstractmethod
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import json
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
class BaseVectorStoreTest(ABC):
|
||||
"""
|
||||
Abstract base test class that enforces a common test across all test classes.
|
||||
"""
|
||||
@abstractmethod
|
||||
def get_base_request_args(self) -> dict:
|
||||
"""Must return the base request args"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_base_create_vector_store_args(self) -> dict:
|
||||
"""Must return the base create vector store args"""
|
||||
pass
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_search_vector_store(self, sync_mode):
|
||||
litellm._turn_on_debug()
|
||||
litellm.set_verbose = True
|
||||
base_request_args = self.get_base_request_args()
|
||||
try:
|
||||
if sync_mode:
|
||||
response = litellm.vector_stores.search(
|
||||
query="Basic ping",
|
||||
**base_request_args
|
||||
)
|
||||
else:
|
||||
response = await litellm.vector_stores.asearch(
|
||||
query="Basic ping",
|
||||
**base_request_args
|
||||
)
|
||||
except litellm.InternalServerError:
|
||||
pytest.skip("Skipping test due to litellm.InternalServerError")
|
||||
|
||||
print("litellm response=", json.dumps(response, indent=4, default=str))
|
||||
|
||||
# Validate response structure
|
||||
self._validate_vector_store_response(response)
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_create_vector_store(self, sync_mode):
|
||||
litellm._turn_on_debug()
|
||||
litellm.set_verbose = True
|
||||
base_request_args = self.get_base_create_vector_store_args()
|
||||
|
||||
# Extract custom_llm_provider from base args if present
|
||||
create_args = base_request_args
|
||||
try:
|
||||
if sync_mode:
|
||||
response = litellm.vector_stores.create(
|
||||
name="Test Vector Store",
|
||||
**create_args
|
||||
)
|
||||
else:
|
||||
response = await litellm.vector_stores.acreate(
|
||||
name="Test Vector Store",
|
||||
**create_args
|
||||
)
|
||||
except litellm.InternalServerError:
|
||||
pytest.skip("Skipping test due to litellm.InternalServerError")
|
||||
except Exception as e:
|
||||
# If this is an authentication or permission error, skip the test
|
||||
if "authentication" in str(e).lower() or "permission" in str(e).lower() or "unauthorized" in str(e).lower():
|
||||
pytest.skip(f"Skipping test due to authentication/permission error: {e}")
|
||||
raise
|
||||
|
||||
print("litellm create response=", json.dumps(response, indent=4, default=str))
|
||||
|
||||
# Validate response structure
|
||||
self._validate_vector_store_create_response(response)
|
||||
|
||||
def _validate_vector_store_response(self, response):
|
||||
"""Validate the structure and content of a vector store search response"""
|
||||
|
||||
# Check that response is a dictionary
|
||||
assert isinstance(response, dict), f"Response should be a dict, got {type(response)}"
|
||||
|
||||
# Check required top-level fields
|
||||
required_fields = ['object', 'search_query', 'data']
|
||||
for field in required_fields:
|
||||
assert field in response, f"Missing required field '{field}' in response"
|
||||
|
||||
# Validate object field
|
||||
assert response['object'] == 'vector_store.search_results.page', \
|
||||
f"Expected object to be 'vector_store.search_results.page', got '{response['object']}'"
|
||||
|
||||
# Validate search_query field
|
||||
assert isinstance(response['search_query'], list), \
|
||||
f"search_query should be a list, got {type(response['search_query'])}"
|
||||
assert len(response['search_query']) > 0, "search_query should not be empty"
|
||||
assert all(isinstance(query, str) for query in response['search_query']), \
|
||||
"All items in search_query should be strings"
|
||||
|
||||
# Validate data field
|
||||
assert isinstance(response['data'], list), \
|
||||
f"data should be a list, got {type(response['data'])}"
|
||||
|
||||
# Validate each result in data
|
||||
for i, result in enumerate(response['data']):
|
||||
self._validate_search_result(result, i)
|
||||
|
||||
print(f"✅ Response validation passed: Found {len(response['data'])} search results")
|
||||
|
||||
def _validate_vector_store_create_response(self, response):
|
||||
"""Validate the structure and content of a vector store create response"""
|
||||
|
||||
# Check that response is a dictionary
|
||||
assert isinstance(response, dict), f"Response should be a dict, got {type(response)}"
|
||||
|
||||
# Check required top-level fields for create response
|
||||
required_fields = ['id', 'object', 'created_at']
|
||||
for field in required_fields:
|
||||
assert field in response, f"Missing required field '{field}' in create response"
|
||||
|
||||
# Validate object field
|
||||
assert response['object'] == 'vector_store', \
|
||||
f"Expected object to be 'vector_store', got '{response['object']}'"
|
||||
|
||||
# Validate id field
|
||||
assert isinstance(response['id'], str), \
|
||||
f"id should be a string, got {type(response['id'])}"
|
||||
assert len(response['id']) > 0, "id should not be empty"
|
||||
assert response['id'].startswith('vs_'), \
|
||||
f"id should start with 'vs_', got '{response['id']}'"
|
||||
|
||||
# Validate created_at field
|
||||
assert isinstance(response['created_at'], int), \
|
||||
f"created_at should be an integer, got {type(response['created_at'])}"
|
||||
assert response['created_at'] > 0, "created_at should be a positive timestamp"
|
||||
|
||||
# Validate optional fields if present
|
||||
if 'name' in response:
|
||||
assert isinstance(response['name'], str), \
|
||||
f"name should be a string, got {type(response['name'])}"
|
||||
|
||||
if 'bytes' in response:
|
||||
assert isinstance(response['bytes'], int), \
|
||||
f"bytes should be an integer, got {type(response['bytes'])}"
|
||||
assert response['bytes'] >= 0, "bytes should be non-negative"
|
||||
|
||||
if 'file_counts' in response:
|
||||
self._validate_file_counts(response['file_counts'])
|
||||
|
||||
if 'status' in response:
|
||||
valid_statuses = ['expired', 'in_progress', 'completed']
|
||||
assert response['status'] in valid_statuses, \
|
||||
f"status should be one of {valid_statuses}, got '{response['status']}'"
|
||||
|
||||
if 'expires_at' in response and response['expires_at'] is not None:
|
||||
assert isinstance(response['expires_at'], int), \
|
||||
f"expires_at should be an integer, got {type(response['expires_at'])}"
|
||||
|
||||
if 'last_active_at' in response and response['last_active_at'] is not None:
|
||||
assert isinstance(response['last_active_at'], int), \
|
||||
f"last_active_at should be an integer, got {type(response['last_active_at'])}"
|
||||
|
||||
if 'metadata' in response and response['metadata'] is not None:
|
||||
assert isinstance(response['metadata'], dict), \
|
||||
f"metadata should be a dict, got {type(response['metadata'])}"
|
||||
|
||||
print(f"✅ Create response validation passed: Vector store '{response['id']}' created successfully")
|
||||
|
||||
def _validate_file_counts(self, file_counts):
|
||||
"""Validate file_counts structure"""
|
||||
assert isinstance(file_counts, dict), \
|
||||
f"file_counts should be a dict, got {type(file_counts)}"
|
||||
|
||||
required_count_fields = ['in_progress', 'completed', 'failed', 'cancelled', 'total']
|
||||
for field in required_count_fields:
|
||||
assert field in file_counts, f"Missing required field '{field}' in file_counts"
|
||||
assert isinstance(file_counts[field], int), \
|
||||
f"{field} should be an integer, got {type(file_counts[field])}"
|
||||
assert file_counts[field] >= 0, f"{field} should be non-negative"
|
||||
|
||||
# Validate that total equals sum of other counts
|
||||
calculated_total = (
|
||||
file_counts['in_progress'] +
|
||||
file_counts['completed'] +
|
||||
file_counts['failed'] +
|
||||
file_counts['cancelled']
|
||||
)
|
||||
assert file_counts['total'] == calculated_total, \
|
||||
f"total should equal sum of other counts ({calculated_total}), got {file_counts['total']}"
|
||||
|
||||
def _validate_search_result(self, result, index):
|
||||
"""Validate an individual search result"""
|
||||
|
||||
# Check that result is a dictionary
|
||||
assert isinstance(result, dict), f"Result {index} should be a dict, got {type(result)}"
|
||||
|
||||
# Check required fields in each result
|
||||
required_result_fields = ['file_id', 'filename', 'score', 'attributes', 'content']
|
||||
for field in required_result_fields:
|
||||
assert field in result, f"Missing required field '{field}' in result {index}"
|
||||
|
||||
# Validate file_id
|
||||
assert isinstance(result['file_id'], str), \
|
||||
f"file_id should be a string, got {type(result['file_id'])} in result {index}"
|
||||
assert len(result['file_id']) > 0, f"file_id should not be empty in result {index}"
|
||||
|
||||
# Validate filename
|
||||
assert isinstance(result['filename'], str), \
|
||||
f"filename should be a string, got {type(result['filename'])} in result {index}"
|
||||
assert len(result['filename']) > 0, f"filename should not be empty in result {index}"
|
||||
|
||||
# Validate score
|
||||
assert isinstance(result['score'], (int, float)), \
|
||||
f"score should be a number, got {type(result['score'])} in result {index}"
|
||||
assert 0.0 <= result['score'] <= 1.0, \
|
||||
f"score should be between 0.0 and 1.0, got {result['score']} in result {index}"
|
||||
|
||||
# Validate attributes
|
||||
assert isinstance(result['attributes'], dict), \
|
||||
f"attributes should be a dict, got {type(result['attributes'])} in result {index}"
|
||||
|
||||
# Validate content
|
||||
assert isinstance(result['content'], list), \
|
||||
f"content should be a list, got {type(result['content'])} in result {index}"
|
||||
assert len(result['content']) > 0, f"content should not be empty in result {index}"
|
||||
|
||||
# Validate each content item
|
||||
for j, content_item in enumerate(result['content']):
|
||||
assert isinstance(content_item, dict), \
|
||||
f"Content item {j} in result {index} should be a dict, got {type(content_item)}"
|
||||
assert 'type' in content_item, \
|
||||
f"Content item {j} in result {index} missing 'type' field"
|
||||
assert 'text' in content_item, \
|
||||
f"Content item {j} in result {index} missing 'text' field"
|
||||
assert isinstance(content_item['text'], str), \
|
||||
f"Content text should be a string in item {j} of result {index}"
|
||||
assert len(content_item['text']) > 0, \
|
||||
f"Content text should not be empty in item {j} of result {index}"
|
||||
|
||||
print(f"✅ Result {index} validation passed: {result['filename']} (score: {result['score']:.4f})")
|
||||
@@ -0,0 +1,63 @@
|
||||
# conftest.py
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_teardown():
|
||||
"""
|
||||
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
|
||||
"""
|
||||
curr_dir = os.getcwd() # Get the current working directory
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the project directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
|
||||
importlib.reload(litellm)
|
||||
|
||||
try:
|
||||
if hasattr(litellm, "proxy") and hasattr(litellm.proxy, "proxy_server"):
|
||||
import litellm.proxy.proxy_server
|
||||
|
||||
importlib.reload(litellm.proxy.proxy_server)
|
||||
except Exception as e:
|
||||
print(f"Error reloading litellm.proxy.proxy_server: {e}")
|
||||
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
print(litellm)
|
||||
# from litellm import Router, completion, aembedding, acompletion, embedding
|
||||
yield
|
||||
|
||||
# Teardown code (executes after the yield point)
|
||||
loop.close() # Close the loop created earlier
|
||||
asyncio.set_event_loop(None) # Remove the reference to the loop
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
|
||||
custom_logger_tests = [
|
||||
item for item in items if "custom_logger" in item.parent.name
|
||||
]
|
||||
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
|
||||
|
||||
# Sort tests based on their names
|
||||
custom_logger_tests.sort(key=lambda x: x.name)
|
||||
other_tests.sort(key=lambda x: x.name)
|
||||
|
||||
# Reorder the items list
|
||||
items[:] = custom_logger_tests + other_tests
|
||||
@@ -0,0 +1,25 @@
|
||||
from base_vector_store_test import BaseVectorStoreTest
|
||||
import os
|
||||
import pytest
|
||||
|
||||
class TestAzureOpenAIVectorStore(BaseVectorStoreTest):
|
||||
def get_base_request_args(self) -> dict:
|
||||
"""Must return the base request args"""
|
||||
return {}
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_search_vector_store(self, sync_mode):
|
||||
pass
|
||||
|
||||
|
||||
def get_base_create_vector_store_args(self) -> dict:
|
||||
"""
|
||||
This is a real vector store on Azure
|
||||
"""
|
||||
return {
|
||||
"custom_llm_provider": "azure",
|
||||
"api_base": os.getenv("AZURE_RESPONSES_OPENAI_ENDPOINT"),
|
||||
"api_key": os.getenv("AZURE_RESPONSES_OPENAI_API_KEY"),
|
||||
"api_version": "2025-04-01-preview",
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
from base_vector_store_test import BaseVectorStoreTest
|
||||
|
||||
class TestOpenAIVectorStore(BaseVectorStoreTest):
|
||||
def get_base_request_args(self) -> dict:
|
||||
"""
|
||||
This is a real vector store on OpenAI
|
||||
"""
|
||||
return {
|
||||
"vector_store_id": "vs_685b14b1a1b88191bc27e04f1917fddd",
|
||||
"custom_llm_provider": "openai",
|
||||
}
|
||||
|
||||
|
||||
def get_base_create_vector_store_args(self) -> dict:
|
||||
"""
|
||||
This is a real vector store on OpenAI
|
||||
"""
|
||||
return {
|
||||
"custom_llm_provider": "openai",
|
||||
}
|
||||
@@ -132,7 +132,7 @@ const SSOModals: React.FC<SSOModalsProps> = ({
|
||||
}
|
||||
}
|
||||
|
||||
// Set form values with existing data
|
||||
// Set form values with existing data (excluding UI access control fields)
|
||||
const formValues = {
|
||||
sso_provider: selectedProvider,
|
||||
proxy_base_url: ssoData.values.proxy_base_url,
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
import React, { useEffect, useState } from "react";
|
||||
import { Form, Button as Button2, Select, message } from "antd";
|
||||
import { Text, TextInput } from "@tremor/react";
|
||||
import { getSSOSettings, updateSSOSettings } from "./networking";
|
||||
|
||||
interface UIAccessControlFormProps {
|
||||
accessToken: string | null;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
// Separate UI Access Control Form Component
|
||||
const UIAccessControlForm: React.FC<UIAccessControlFormProps> = ({ accessToken, onSuccess }) => {
|
||||
const [form] = Form.useForm();
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
// Load existing UI access control settings
|
||||
useEffect(() => {
|
||||
const loadUIAccessSettings = async () => {
|
||||
if (accessToken) {
|
||||
try {
|
||||
const ssoData = await getSSOSettings(accessToken);
|
||||
if (ssoData && ssoData.values) {
|
||||
// Handle nested ui_access_mode structure
|
||||
const uiAccessMode = ssoData.values.ui_access_mode;
|
||||
let formValues = {};
|
||||
|
||||
if (uiAccessMode && typeof uiAccessMode === 'object') {
|
||||
formValues = {
|
||||
ui_access_mode_type: uiAccessMode.type,
|
||||
restricted_sso_group: uiAccessMode.restricted_sso_group,
|
||||
sso_group_jwt_field: uiAccessMode.sso_group_jwt_field,
|
||||
};
|
||||
} else if (typeof uiAccessMode === 'string') {
|
||||
// Handle legacy flat structure
|
||||
formValues = {
|
||||
ui_access_mode_type: uiAccessMode,
|
||||
restricted_sso_group: ssoData.values.restricted_sso_group,
|
||||
sso_group_jwt_field: ssoData.values.team_ids_jwt_field || ssoData.values.sso_group_jwt_field,
|
||||
};
|
||||
}
|
||||
|
||||
form.setFieldsValue(formValues);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to load UI access settings:", error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
loadUIAccessSettings();
|
||||
}, [accessToken, form]);
|
||||
|
||||
const handleUIAccessSubmit = async (formValues: Record<string, any>) => {
|
||||
if (!accessToken) {
|
||||
message.error("No access token available");
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
// Transform form data to match API expected structure
|
||||
const apiPayload = {
|
||||
ui_access_mode: {
|
||||
type: formValues.ui_access_mode_type,
|
||||
restricted_sso_group: formValues.restricted_sso_group,
|
||||
sso_group_jwt_field: formValues.sso_group_jwt_field,
|
||||
}
|
||||
};
|
||||
|
||||
await updateSSOSettings(accessToken, apiPayload);
|
||||
onSuccess();
|
||||
} catch (error) {
|
||||
console.error("Failed to save UI access settings:", error);
|
||||
message.error("Failed to save UI access settings");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div style={{ padding: '16px' }}>
|
||||
<div style={{ marginBottom: '16px' }}>
|
||||
<Text style={{ fontSize: '14px', color: '#6b7280' }}>
|
||||
Configure who can access the UI interface and how group information is extracted from JWT tokens.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleUIAccessSubmit}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
label="UI Access Mode"
|
||||
name="ui_access_mode_type"
|
||||
tooltip="Controls who can access the UI interface"
|
||||
>
|
||||
<Select placeholder="Select access mode">
|
||||
<Select.Option value="all_authenticated_users">All Authenticated Users</Select.Option>
|
||||
<Select.Option value="restricted_sso_group">Restricted SSO Group</Select.Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
noStyle
|
||||
shouldUpdate={(prevValues, currentValues) => prevValues.ui_access_mode_type !== currentValues.ui_access_mode_type}
|
||||
>
|
||||
{({ getFieldValue }) => {
|
||||
const uiAccessModeType = getFieldValue('ui_access_mode_type');
|
||||
return uiAccessModeType === 'restricted_sso_group' ? (
|
||||
<Form.Item
|
||||
label="Restricted SSO Group"
|
||||
name="restricted_sso_group"
|
||||
rules={[{ required: true, message: "Please enter the restricted SSO group" }]}
|
||||
>
|
||||
<TextInput placeholder="ui-access-group" />
|
||||
</Form.Item>
|
||||
) : null;
|
||||
}}
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label="SSO Group JWT Field"
|
||||
name="sso_group_jwt_field"
|
||||
tooltip="JWT field name that contains team/group information. Use dot notation to access nested fields."
|
||||
>
|
||||
<TextInput placeholder="groups" />
|
||||
</Form.Item>
|
||||
|
||||
<div style={{ textAlign: "right", marginTop: "16px" }}>
|
||||
<Button2
|
||||
type="primary"
|
||||
htmlType="submit"
|
||||
loading={loading}
|
||||
style={{
|
||||
backgroundColor: '#6366f1',
|
||||
borderColor: '#6366f1'
|
||||
}}
|
||||
>
|
||||
Update UI Access Control
|
||||
</Button2>
|
||||
</div>
|
||||
</Form>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default UIAccessControlForm;
|
||||
@@ -44,6 +44,7 @@ import { InvitationLink } from "./onboarding_link";
|
||||
import SSOModals from "./SSOModals";
|
||||
import { ssoProviderConfigs } from './SSOModals';
|
||||
import SCIMConfig from "./SCIM";
|
||||
import UIAccessControlForm from "./UIAccessControlForm";
|
||||
|
||||
interface AdminPanelProps {
|
||||
searchParams: any;
|
||||
@@ -97,6 +98,7 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||
const [isAllowedIPModalVisible, setIsAllowedIPModalVisible] = useState(false);
|
||||
const [isAddIPModalVisible, setIsAddIPModalVisible] = useState(false);
|
||||
const [isDeleteIPModalVisible, setIsDeleteIPModalVisible] = useState(false);
|
||||
const [isUIAccessControlModalVisible, setIsUIAccessControlModalVisible] = useState(false);
|
||||
const [allowedIPs, setAllowedIPs] = useState<string[]>([]);
|
||||
const [ipToDelete, setIPToDelete] = useState<string | null>(null);
|
||||
const [ssoConfigured, setSsoConfigured] = useState<boolean>(false);
|
||||
@@ -532,6 +534,14 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
const handleUIAccessControlOk = () => {
|
||||
setIsUIAccessControlModalVisible(false);
|
||||
};
|
||||
|
||||
const handleUIAccessControlCancel = () => {
|
||||
setIsUIAccessControlModalVisible(false);
|
||||
};
|
||||
|
||||
console.log(`admins: ${admins?.length}`);
|
||||
return (
|
||||
<div className="w-full m-2 mt-2 p-8">
|
||||
@@ -563,6 +573,14 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||
Allowed IPs
|
||||
</Button>
|
||||
</div>
|
||||
<div>
|
||||
<Button
|
||||
style={{ width: '150px' }}
|
||||
onClick={() => premiumUser === true ? setIsUIAccessControlModalVisible(true) : message.error("Only premium users can configure UI access control")}
|
||||
>
|
||||
UI Access Control
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
@@ -654,6 +672,24 @@ const AdminPanel: React.FC<AdminPanelProps> = ({
|
||||
>
|
||||
<p>Are you sure you want to delete the IP address: {ipToDelete}?</p>
|
||||
</Modal>
|
||||
|
||||
{/* UI Access Control Modal */}
|
||||
<Modal
|
||||
title="UI Access Control Settings"
|
||||
visible={isUIAccessControlModalVisible}
|
||||
width={600}
|
||||
footer={null}
|
||||
onOk={handleUIAccessControlOk}
|
||||
onCancel={handleUIAccessControlCancel}
|
||||
>
|
||||
<UIAccessControlForm
|
||||
accessToken={accessToken}
|
||||
onSuccess={() => {
|
||||
handleUIAccessControlOk();
|
||||
message.success("UI Access Control settings updated successfully");
|
||||
}}
|
||||
/>
|
||||
</Modal>
|
||||
</div>
|
||||
<Callout title="Login without SSO" color="teal">
|
||||
If you need to login without sso, you can access{" "}
|
||||
|
||||
@@ -260,6 +260,10 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
||||
updateData.team_member_budget = Number(values.team_member_budget);
|
||||
}
|
||||
|
||||
if (values.team_member_key_duration !== undefined) {
|
||||
updateData.team_member_key_duration = values.team_member_key_duration;
|
||||
}
|
||||
|
||||
// Handle object_permission updates
|
||||
if (values.vector_stores !== undefined || values.mcp_servers !== undefined) {
|
||||
updateData.object_permission = {
|
||||
@@ -453,6 +457,15 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
||||
<NumericalInput step={0.01} precision={2} style={{ width: "100%" }} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item label="Team Member Key Duration" name="team_member_key_duration" tooltip="Set a limit to the duration of a team member's key.">
|
||||
<Select placeholder="n/a">
|
||||
<Select.Option value="1d">1 day</Select.Option>
|
||||
<Select.Option value="1w">1 week</Select.Option>
|
||||
<Select.Option value="1mo">1 month</Select.Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
|
||||
<Form.Item label="Reset Budget" name="budget_duration">
|
||||
<Select placeholder="n/a">
|
||||
<Select.Option value="24h">daily</Select.Option>
|
||||
@@ -561,9 +574,19 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
||||
<div>RPM: {info.rpm_limit || 'Unlimited'}</div>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="font-medium">Budget</Text>
|
||||
<div>Max: {info.max_budget !== null ? `$${info.max_budget}` : 'No Limit'}</div>
|
||||
<div>Reset: {info.budget_duration || 'Never'}</div>
|
||||
<Text className="font-medium">Team Budget</Text>
|
||||
<div>Max Budget: {info.max_budget !== null ? `$${info.max_budget}` : 'No Limit'}</div>
|
||||
<div>Budget Reset: {info.budget_duration || 'Never'}</div>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="font-medium">
|
||||
Team Member Settings{' '}
|
||||
<Tooltip title="These are limits on individual team members">
|
||||
<InfoCircleOutlined style={{ marginLeft: '4px' }} />
|
||||
</Tooltip>
|
||||
</Text>
|
||||
<div>Max Budget: {info.team_member_budget_table?.max_budget || 'No Limit'}</div>
|
||||
<div>Key Duration: {info.metadata?.team_member_key_duration || 'No Limit'}</div>
|
||||
</div>
|
||||
<div>
|
||||
<Text className="font-medium">Organization ID</Text>
|
||||
|
||||
@@ -1060,6 +1060,17 @@ const Teams: React.FC<TeamProps> = ({
|
||||
>
|
||||
<NumericalInput step={0.01} precision={2} width={200} />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label="Team Member Key Duration"
|
||||
name="team_member_key_duration"
|
||||
tooltip="Set a limit to the duration of a team member's key."
|
||||
>
|
||||
<Select2 defaultValue={null} placeholder="n/a">
|
||||
<Select2.Option value="1d">1 day</Select2.Option>
|
||||
<Select2.Option value="1w">1 week</Select2.Option>
|
||||
<Select2.Option value="1mo">1 month</Select2.Option>
|
||||
</Select2>
|
||||
</Form.Item>
|
||||
<Form.Item label="Metadata" name="metadata" help="Additional team metadata. Enter metadata as JSON object.">
|
||||
<Input.TextArea rows={4} />
|
||||
</Form.Item>
|
||||
|
||||
Reference in New Issue
Block a user