From b9feb43dac4fd2c2085240c62ca13753b960efce Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 24 Jun 2025 13:22:56 -0700 Subject: [PATCH 01/12] [Bug Fix] SCIM - Ensure new user roles are applied (#12015) * SCIM fix new user roles * test_create_user_defaults_to_viewer * test_create_user_uses_default_internal_user_params_role * fix default user for SCIM * fix linting error --- .../management_endpoints/scim/scim_v2.py | 14 +++ .../scim/test_scim_v2_endpoints.py | 90 ++++++++++++++++++- 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/management_endpoints/scim/scim_v2.py b/litellm/proxy/management_endpoints/scim/scim_v2.py index b88413f180..1ef4cdc974 100644 --- a/litellm/proxy/management_endpoints/scim/scim_v2.py +++ b/litellm/proxy/management_endpoints/scim/scim_v2.py @@ -18,6 +18,7 @@ from fastapi import ( Response, ) +import litellm from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import ( @@ -361,6 +362,18 @@ async def create_user( # Create user in database user_id = user.userName or str(uuid.uuid4()) metadata = _build_scim_metadata(user_data["given_name"], user_data["family_name"]) + + default_role: Optional[ + Literal[ + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + ] + ] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + if litellm.default_internal_user_params: + default_role = litellm.default_internal_user_params.get("user_role") + new_user_request = NewUserRequest( user_id=user_id, user_email=user_data["user_email"], @@ -368,6 +381,7 @@ async def create_user( teams=user_data["teams"], metadata=metadata, auto_create_key=False, + user_role=default_role, ) # Check if user with email already exists and update if found diff --git a/tests/test_litellm/proxy/management_endpoints/scim/test_scim_v2_endpoints.py b/tests/test_litellm/proxy/management_endpoints/scim/test_scim_v2_endpoints.py index aab700e2a7..f93728b90c 100644 --- a/tests/test_litellm/proxy/management_endpoints/scim/test_scim_v2_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/scim/test_scim_v2_endpoints.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock import pytest from fastapi import HTTPException -from litellm.proxy._types import NewUserRequest, ProxyException +from litellm.proxy._types import LitellmUserRoles, NewUserRequest, ProxyException from litellm.proxy.management_endpoints.scim.scim_v2 import ( UserProvisionerHelpers, _handle_team_membership_changes, @@ -58,6 +58,94 @@ async def test_create_user_existing_user_conflict(mocker): mocked_new_user.assert_not_called() +@pytest.mark.asyncio +async def test_create_user_defaults_to_viewer(mocker, monkeypatch): + """If no role provided, new user should default to viewer""" + + scim_user = SCIMUser( + schemas=["urn:ietf:params:scim:schemas:core:2.0:User"], + userName="new-user", + name=SCIMUserName(familyName="User", givenName="New"), + emails=[SCIMUserEmail(value="new@example.com")], + ) + + mock_prisma_client = mocker.MagicMock() + mock_prisma_client.db = mocker.MagicMock() + mock_prisma_client.db.litellm_usertable = mocker.MagicMock() + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None) + mock_prisma_client.db.litellm_usertable.find_first = AsyncMock(return_value=None) + + monkeypatch.setattr( + "litellm.default_internal_user_params", None, raising=False + ) + + mocker.patch( + "litellm.proxy.management_endpoints.scim.scim_v2._get_prisma_client_or_raise_exception", + AsyncMock(return_value=mock_prisma_client), + ) + + new_user_mock = mocker.patch( + "litellm.proxy.management_endpoints.scim.scim_v2.new_user", + AsyncMock(return_value=NewUserRequest(user_id="new-user")), + ) + + mocker.patch( + "litellm.proxy.management_endpoints.scim.scim_v2.ScimTransformations.transform_litellm_user_to_scim_user", + AsyncMock(return_value=scim_user), + ) + + await create_user(user=scim_user) + + called_args = new_user_mock.call_args.kwargs["data"] + assert called_args.user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY + + +@pytest.mark.asyncio +async def test_create_user_uses_default_internal_user_params_role(mocker, monkeypatch): + """If role is set in default_internal_user_params, new user should use that role""" + + scim_user = SCIMUser( + schemas=["urn:ietf:params:scim:schemas:core:2.0:User"], + userName="new-user", + name=SCIMUserName(familyName="User", givenName="New"), + emails=[SCIMUserEmail(value="new@example.com")], + ) + + mock_prisma_client = mocker.MagicMock() + mock_prisma_client.db = mocker.MagicMock() + mock_prisma_client.db.litellm_usertable = mocker.MagicMock() + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None) + mock_prisma_client.db.litellm_usertable.find_first = AsyncMock(return_value=None) + + # Set default_internal_user_params with a specific role + default_params = { + "user_role": LitellmUserRoles.PROXY_ADMIN, + } + monkeypatch.setattr( + "litellm.default_internal_user_params", default_params, raising=False + ) + + mocker.patch( + "litellm.proxy.management_endpoints.scim.scim_v2._get_prisma_client_or_raise_exception", + AsyncMock(return_value=mock_prisma_client), + ) + + new_user_mock = mocker.patch( + "litellm.proxy.management_endpoints.scim.scim_v2.new_user", + AsyncMock(return_value=NewUserRequest(user_id="new-user")), + ) + + mocker.patch( + "litellm.proxy.management_endpoints.scim.scim_v2.ScimTransformations.transform_litellm_user_to_scim_user", + AsyncMock(return_value=scim_user), + ) + + await create_user(user=scim_user) + + called_args = new_user_mock.call_args.kwargs["data"] + assert called_args.user_role == LitellmUserRoles.PROXY_ADMIN + + @pytest.mark.asyncio async def test_handle_existing_user_by_email_no_email(mocker): """Should return None when new_user_request has no email""" From c9a8198d125ed0bd533b159e7f28f55c2a763da8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 24 Jun 2025 13:24:23 -0700 Subject: [PATCH 02/12] docs(self_serve.md): clarify team must be created before setting as default team --- docs/my-website/docs/proxy/self_serve.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index 30c46cfb0a..948062d7fd 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -292,6 +292,10 @@ Let's also set the default models to `no-default-models`. This means a user can +:::info +Team must be created before setting it as the default team. +::: + ```yaml default_internal_user_params: # Default Params used when a new user signs in Via SSO user_role: "internal_user" # one of "internal_user", "internal_user_viewer", From 8da22be199a22b1a4a6b3a054d0299e242fca270 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 24 Jun 2025 13:25:17 -0700 Subject: [PATCH 03/12] docs(self_serve.md): update doc --- docs/my-website/docs/proxy/self_serve.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index 948062d7fd..e7860b4247 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -297,12 +297,13 @@ Team must be created before setting it as the default team. ::: ```yaml -default_internal_user_params: # Default Params used when a new user signs in Via SSO - user_role: "internal_user" # one of "internal_user", "internal_user_viewer", - models: ["no-default-models"] # Optional[List[str]], optional): models to be used by the user - teams: # Optional[List[NewUserRequestTeam]], optional): teams to be used by the user - - team_id: "team_id_1" # Required[str]: team_id to be used by the user - user_role: "user" # Optional[str], optional): Default role in the team. Values: "user" or "admin". Defaults to "user" +litellm_settings: + default_internal_user_params: # Default Params used when a new user signs in Via SSO + user_role: "internal_user" # one of "internal_user", "internal_user_viewer", + models: ["no-default-models"] # Optional[List[str]], optional): models to be used by the user + teams: # Optional[List[NewUserRequestTeam]], optional): teams to be used by the user + - team_id: "team_id_1" # Required[str]: team_id to be used by the user + user_role: "user" # Optional[str], optional): Default role in the team. Values: "user" or "admin". Defaults to "user" ``` From 1467a99aab22994da726f98e4adf9f07be80089e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 24 Jun 2025 13:45:58 -0700 Subject: [PATCH 04/12] [Fix] Magistral small system prompt diverges too much from the official recommendation (#12007) * fix mistral _get_mistral_reasoning_system_prompt * fix test_get_mistral_reasoning_system_prompt --- .../mistral/mistral_chat_transformation.py | 27 ++++++++++++++----- .../test_mistral_chat_transformation.py | 4 --- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/litellm/llms/mistral/mistral_chat_transformation.py b/litellm/llms/mistral/mistral_chat_transformation.py index 871fd8c7ea..e281e05553 100644 --- a/litellm/llms/mistral/mistral_chat_transformation.py +++ b/litellm/llms/mistral/mistral_chat_transformation.py @@ -6,7 +6,7 @@ Why separate file? Make it easy to see how transformation works Docs - https://docs.mistral.ai/api/ """ -from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, overload, cast +from typing import Any, Coroutine, List, Literal, Optional, Tuple, Union, cast, overload from litellm.litellm_core_utils.prompt_templates.common_utils import ( handle_messages_with_content_list_to_str_conversion, @@ -107,15 +107,28 @@ class MistralConfig(OpenAIGPTConfig): def _get_mistral_reasoning_system_prompt() -> str: """ Returns the system prompt for Mistral reasoning models. - Based on Mistral's documentation: https://docs.mistral.ai/capabilities/reasoning/ + Based on Mistral's documentation: https://huggingface.co/mistralai/Magistral-Small-2506 + + Mistral recommends the following system prompt for reasoning: """ - return """When solving problems, think step-by-step in tags before providing your final answer. Use the following format: + return """ + [SYSTEM_PROMPT]system_prompt + A user will ask you to solve a task. You should first draft your thinking process (inner monologue) until you have derived the final answer. Afterwards, write a self-contained summary of your thoughts (i.e. your summary should be succinct but contain all the critical steps you needed to reach the conclusion). You should use Markdown to format your response. Write both your thoughts and summary in the same language as the task posed by the user. NEVER use \boxed{} in your response. - -Your step-by-step reasoning process. Be thorough and work through the problem carefully. - + Your thinking process must follow the template below: + + Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate a correct answer. + -Then provide a clear, concise answer based on your reasoning.""" + Here, provide a concise summary that reflects your reasoning and presents a clear final answer to the user. Don't mention that this is a summary. + + Problem: + + [/SYSTEM_PROMPT][INST]user_message[/INST] + reasoning_traces + + assistant_response[INST]user_message[/INST] + """ def map_openai_params( self, diff --git a/tests/test_litellm/llms/mistral/test_mistral_chat_transformation.py b/tests/test_litellm/llms/mistral/test_mistral_chat_transformation.py index c46413799f..be0e73d850 100644 --- a/tests/test_litellm/llms/mistral/test_mistral_chat_transformation.py +++ b/tests/test_litellm/llms/mistral/test_mistral_chat_transformation.py @@ -95,10 +95,6 @@ class TestMistralReasoningSupport: def test_get_mistral_reasoning_system_prompt(self): """Test that the reasoning system prompt is properly formatted.""" prompt = MistralConfig._get_mistral_reasoning_system_prompt() - - assert "" in prompt - assert "" in prompt - assert "step-by-step" in prompt assert isinstance(prompt, str) assert len(prompt) > 50 # Ensure it's not empty From 97da33494adbe63852544b5fe2cce6cb1f71ed0d Mon Sep 17 00:00:00 2001 From: Cole McIntosh <82463175+colesmcintosh@users.noreply.github.com> Date: Tue, 24 Jun 2025 16:38:48 -0600 Subject: [PATCH 05/12] Refactor unpack_defs to use iterative approach instead of recursion (#12017) * Refactor unpack_defs to use iterative approach instead of recursion - Replace recursive depth-first traversal with iterative queue-based approach - Add collections.deque import for efficient queue operations - Avoid potential stack overflow issues with deeply nested schemas - Maintain same functionality while improving performance and safety * Remove unused import of Set in common_utils.py * Enhance type hinting for queue in unpack_defs function in common_utils.py * Enhance unpack_defs function to handle key validation for parent structures in common_utils.py - Added checks to ensure that the parent is a dictionary or list and that the key is of the appropriate type (string for dicts, integer for lists) before assigning the resolved schema. - This improves the robustness of the unpack_defs function when dealing with various schema structures. --- .../prompt_templates/common_utils.py | 72 +++++++++---------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 6b70f690f9..626c8b7f29 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -16,7 +16,6 @@ from typing import ( Optional, Union, cast, - Set, ) from litellm.types.llms.openai import ( @@ -508,6 +507,7 @@ def unpack_defs(schema: dict, defs: dict) -> None: """ import copy + from collections import deque # Combine the defs handed down by the caller with defs/definitions found on # the current node. Local keys shadow parent keys to match JSON-schema @@ -518,14 +518,17 @@ def unpack_defs(schema: dict, defs: dict) -> None: **schema.get("definitions", {}), } - def _walk_and_resolve(node: Any, active_defs: dict, seen: Set[int]): # type: ignore[name-defined] - """Depth-first resolver that replaces ``{"$ref": "#/defs/Foo"}`` with - the *actual* ``Foo`` schema. - """ - - # Avoid infinite recursion on self-referential schemas + # Use iterative approach with queue to avoid recursion + # Each item in queue is (node, parent_container, key/index, active_defs, seen_ids) + queue: deque[tuple[Any, Union[dict, list, None], Union[str, int, None], dict, set]] = deque([(schema, None, None, root_defs, set())]) + + while queue: + node, parent, key, active_defs, seen = queue.popleft() + + # Avoid infinite loops on self-referential schemas if id(node) in seen: - return node + continue + seen = seen.copy() # Create new set for this branch seen.add(id(node)) # ----------------------------- dict ----------------------------- @@ -536,7 +539,7 @@ def unpack_defs(schema: dict, defs: dict) -> None: target_schema = active_defs.get(ref_name) # Unknown reference – leave untouched if target_schema is None: - return node + continue # Merge defs from the target to capture nested definitions child_defs = { @@ -545,12 +548,24 @@ def unpack_defs(schema: dict, defs: dict) -> None: **target_schema.get("definitions", {}), } - # Recursively resolve the target *copy* to avoid mutating the - # shared definition map. - resolved = _walk_and_resolve(copy.deepcopy(target_schema), child_defs, seen) - return resolved + # Replace the reference with resolved copy + resolved = copy.deepcopy(target_schema) + if parent is not None and key is not None: + if isinstance(parent, dict) and isinstance(key, str): + parent[key] = resolved + elif isinstance(parent, list) and isinstance(key, int): + parent[key] = resolved + else: + # This is the root schema itself + schema.clear() + schema.update(resolved) + resolved = schema + + # Add resolved node to queue for further processing + queue.append((resolved, parent, key, child_defs, seen)) + continue - # --- Case 2: regular dict – recurse into its values --- + # --- Case 2: regular dict – process its values --- # Update defs with any nested $defs/definitions present *here*. current_defs = { **active_defs, @@ -558,32 +573,15 @@ def unpack_defs(schema: dict, defs: dict) -> None: **node.get("definitions", {}), } - for key, val in list(node.items()): - node[key] = _walk_and_resolve(val, current_defs, seen) - return node + # Add all dict values to queue + for k, v in node.items(): + queue.append((v, node, k, current_defs, seen)) # ---------------------------- list ------------------------------ - if isinstance(node, list): + elif isinstance(node, list): + # Add all list items to queue for idx, item in enumerate(node): - node[idx] = _walk_and_resolve(item, active_defs, seen) - return node - - # -------------------------- primitive --------------------------- - return node - - # Kick off traversal - resolved_root = _walk_and_resolve(schema, root_defs, set()) - # If the resolver returned a *different* dict (e.g., the root itself was a - # $ref), mirror the changes back into the original object so that callers - # holding a reference to ``schema`` see the updated structure. - if resolved_root is not schema: - schema.clear() - if isinstance(resolved_root, dict): - schema.update(resolved_root) - else: - # In the very unlikely case the root was resolved to a non-dict - # (e.g., a primitive), replace in-place via a sentinel key. - schema["__resolved_value__"] = resolved_root # type: ignore + queue.append((item, node, idx, active_defs, seen)) def _get_image_mime_type_from_url(url: str) -> Optional[str]: From 2bb8048864285c1a62fe5c89fca3bc31865d1b59 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 24 Jun 2025 15:52:43 -0700 Subject: [PATCH 06/12] [Feat] Add OpenAI Search Vector Store Operation (#12018) * add BaseVectorStoreTransformation * fix BaseVectorStoreTransformation * add OpenAIVectorStoreTransformation * fix transform * add search, asearch vector stores * add skeleton for vector store searching * fix VectorStoreSearchOptionalRequestParams * fix VectorStoreRequestUtils * fix litellm.asearch/litellm.search * fix BaseVectorStoreConfig * add vector_store_search_handler to llm http handler * use llm http handler for searching vector stores * fix base vector store config * fix vector_store_search_handler * async_vector_store_search_handler * add conftest * add BaseVectorStoreTest * move litellm.integrations.vector_store_integrations * fix working OAI OpenAIVectorStoreConfig * add Search vector store * add OpenAI Vector Stores --- .../docs/completion/knowledgebase.md | 1 + litellm/__init__.py | 1 + .../base_vector_store.py | 0 .../bedrock_vector_store.py | 409 ++++++++++++++++++ .../vector_stores/bedrock_vector_store.py | 4 +- .../custom_logger_registry.py | 4 +- litellm/litellm_core_utils/litellm_logging.py | 4 +- .../base_llm/vector_store/transformation.py | 73 ++++ litellm/llms/custom_httpx/llm_http_handler.py | 147 ++++++- .../openai/vector_stores/transformation.py | 107 +++++ litellm/types/vector_stores.py | 11 + litellm/utils.py | 22 +- litellm/vector_stores/__init__.py | 4 + litellm/vector_stores/main.py | 232 ++++++++++ litellm/vector_stores/utils.py | 28 ++ .../test_bedrock_knowledgebase_hook.py | 2 +- .../base_vector_store_test.py | 136 ++++++ tests/vector_store_tests/conftest.py | 63 +++ .../test_openai_vector_store.py | 11 + 19 files changed, 1252 insertions(+), 7 deletions(-) rename litellm/integrations/{vector_stores => vector_store_integrations}/base_vector_store.py (100%) create mode 100644 litellm/integrations/vector_store_integrations/bedrock_vector_store.py create mode 100644 litellm/llms/base_llm/vector_store/transformation.py create mode 100644 litellm/llms/openai/vector_stores/transformation.py create mode 100644 litellm/vector_stores/__init__.py create mode 100644 litellm/vector_stores/main.py create mode 100644 litellm/vector_stores/utils.py create mode 100644 tests/vector_store_tests/base_vector_store_test.py create mode 100644 tests/vector_store_tests/conftest.py create mode 100644 tests/vector_store_tests/test_openai_vector_store.py diff --git a/docs/my-website/docs/completion/knowledgebase.md b/docs/my-website/docs/completion/knowledgebase.md index 033dccea20..b3e6a06aa9 100644 --- a/docs/my-website/docs/completion/knowledgebase.md +++ b/docs/my-website/docs/completion/knowledgebase.md @@ -17,6 +17,7 @@ LiteLLM integrates with vector stores, allowing your models to access your organ ## Supported Vector Stores - [Bedrock Knowledge Bases](https://aws.amazon.com/bedrock/knowledge-bases/) +- [OpenAI Vector Stores](https://platform.openai.com/docs/api-reference/vector-stores/search) ## Quick Start diff --git a/litellm/__init__.py b/litellm/__init__.py index 5407dd85d5..2921c0600d 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1141,6 +1141,7 @@ from .router import Router from .assistants.main import * from .batches.main import * from .images.main import * +from .vector_stores import * from .batch_completion.main import * # type: ignore from .rerank_api.main import * from .llms.anthropic.experimental_pass_through.messages.handler import * diff --git a/litellm/integrations/vector_stores/base_vector_store.py b/litellm/integrations/vector_store_integrations/base_vector_store.py similarity index 100% rename from litellm/integrations/vector_stores/base_vector_store.py rename to litellm/integrations/vector_store_integrations/base_vector_store.py diff --git a/litellm/integrations/vector_store_integrations/bedrock_vector_store.py b/litellm/integrations/vector_store_integrations/bedrock_vector_store.py new file mode 100644 index 0000000000..a00acefb6a --- /dev/null +++ b/litellm/integrations/vector_store_integrations/bedrock_vector_store.py @@ -0,0 +1,409 @@ +# +-------------------------------------------------------------+ +# +# Add Bedrock Knowledge Base Context to your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import json +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import litellm +from litellm._logging import verbose_logger, verbose_proxy_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.vector_store_integrations.base_vector_store import ( + BaseVectorStore, +) +from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.integrations.rag.bedrock_knowledgebase import ( + BedrockKBContent, + BedrockKBGuardrailConfiguration, + BedrockKBRequest, + BedrockKBResponse, + BedrockKBRetrievalConfiguration, + BedrockKBRetrievalQuery, + BedrockKBRetrievalResult, +) +from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage +from litellm.types.utils import StandardLoggingVectorStoreRequest +from litellm.types.vector_stores import ( + VectorStoreResultContent, + VectorStoreSearchResponse, + VectorStoreSearchResult, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams +else: + StandardCallbackDynamicParams = Any + + +class BedrockVectorStore(BaseVectorStore, BaseAWSLLM): + CONTENT_PREFIX_STRING = "Context: \n\n" + CUSTOM_LLM_PROVIDER = "bedrock" + + def __init__( + self, + **kwargs, + ): + self.async_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + + # store kwargs as optional_params + self.optional_params = kwargs + + super().__init__(**kwargs) + BaseAWSLLM.__init__(self) + + async def async_get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + prompt_id: Optional[str], + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + litellm_logging_obj: LiteLLMLoggingObj, + tools: Optional[List[Dict]] = None, + prompt_label: Optional[str] = None, + ) -> Tuple[str, List[AllMessageValues], dict]: + """ + Retrieves the context from the Bedrock Knowledge Base and appends it to the messages. + """ + if litellm.vector_store_registry is None: + return model, messages, non_default_params + + vector_store_ids = litellm.vector_store_registry.pop_vector_store_ids_to_run( + non_default_params=non_default_params, tools=tools + ) + vector_store_request_metadata: List[StandardLoggingVectorStoreRequest] = [] + if vector_store_ids: + for vector_store_id in vector_store_ids: + start_time = datetime.now() + query = self._get_kb_query_from_messages(messages) + bedrock_kb_response = await self.make_bedrock_kb_retrieve_request( + knowledge_base_id=vector_store_id, + query=query, + non_default_params=non_default_params, + ) + verbose_logger.debug( + f"Bedrock Knowledge Base Response: {bedrock_kb_response}" + ) + + ( + context_message, + context_string, + ) = self.get_chat_completion_message_from_bedrock_kb_response( + bedrock_kb_response + ) + if context_message is not None: + messages.append(context_message) + + ################################################################################################# + ########## LOGGING for Standard Logging Payload, Langfuse, s3, LiteLLM DB etc. ################## + ################################################################################################# + vector_store_search_response: VectorStoreSearchResponse = ( + self.transform_bedrock_kb_response_to_vector_store_search_response( + bedrock_kb_response=bedrock_kb_response, query=query + ) + ) + vector_store_request_metadata.append( + StandardLoggingVectorStoreRequest( + vector_store_id=vector_store_id, + query=query, + vector_store_search_response=vector_store_search_response, + custom_llm_provider=self.CUSTOM_LLM_PROVIDER, + start_time=start_time.timestamp(), + end_time=datetime.now().timestamp(), + ) + ) + + litellm_logging_obj.model_call_details[ + "vector_store_request_metadata" + ] = vector_store_request_metadata + + return model, messages, non_default_params + + def transform_bedrock_kb_response_to_vector_store_search_response( + self, + bedrock_kb_response: BedrockKBResponse, + query: str, + ) -> VectorStoreSearchResponse: + """ + Transform a BedrockKBResponse to a VectorStoreSearchResponse + """ + retrieval_results: Optional[ + List[BedrockKBRetrievalResult] + ] = bedrock_kb_response.get("retrievalResults", None) + vector_store_search_response: VectorStoreSearchResponse = ( + VectorStoreSearchResponse(search_query=query, data=[]) + ) + if retrieval_results is None: + return vector_store_search_response + + vector_search_response_data: List[VectorStoreSearchResult] = [] + for retrieval_result in retrieval_results: + content: Optional[BedrockKBContent] = retrieval_result.get("content", None) + if content is None: + continue + content_text: Optional[str] = content.get("text", None) + if content_text is None: + continue + vector_store_search_result: VectorStoreSearchResult = ( + VectorStoreSearchResult( + score=retrieval_result.get("score", None), + content=[VectorStoreResultContent(text=content_text, type="text")], + ) + ) + vector_search_response_data.append(vector_store_search_result) + vector_store_search_response["data"] = vector_search_response_data + return vector_store_search_response + + def _get_kb_query_from_messages(self, messages: List[AllMessageValues]) -> str: + """ + Uses the text `content` field of the last message in the list of messages + """ + if len(messages) == 0: + return "" + last_message = messages[-1] + last_message_content = last_message.get("content", None) + if last_message_content is None: + return "" + if isinstance(last_message_content, str): + return last_message_content + elif isinstance(last_message_content, list): + return "\n".join([item.get("text", "") for item in last_message_content]) + return "" + + def _prepare_request( + self, + credentials: Any, + data: BedrockKBRequest, + optional_params: dict, + aws_region_name: str, + api_base: str, + extra_headers: Optional[dict] = None, + ) -> Any: + """ + Prepare a signed AWS request. + + Args: + credentials: AWS credentials + data: Request data + optional_params: Additional parameters + aws_region_name: AWS region name + api_base: Base API URL + extra_headers: Additional headers + + Returns: + AWSRequest: A signed AWS request + """ + try: + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + encoded_data = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + + request = AWSRequest( + method="POST", url=api_base, data=encoded_data, headers=headers + ) + sigv4.add_auth(request) + if extra_headers is not None and "Authorization" in extra_headers: + # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + + return request.prepare() + + async def make_bedrock_kb_retrieve_request( + self, + knowledge_base_id: str, + query: str, + guardrail_id: Optional[str] = None, + guardrail_version: Optional[str] = None, + next_token: Optional[str] = None, + retrieval_configuration: Optional[BedrockKBRetrievalConfiguration] = None, + non_default_params: Optional[dict] = None, + ) -> BedrockKBResponse: + """ + Make a Bedrock Knowledge Base retrieve request. + + Args: + knowledge_base_id (str): The unique identifier of the knowledge base to query + query (str): The query text to search for + guardrail_id (Optional[str]): The guardrail ID to apply + guardrail_version (Optional[str]): The version of the guardrail to apply + next_token (Optional[str]): Token for pagination + retrieval_configuration (Optional[BedrockKBRetrievalConfiguration]): Configuration for the retrieval process + + Returns: + BedrockKBRetrievalResponse: A typed response object containing the retrieval results + """ + from fastapi import HTTPException + + non_default_params = non_default_params or {} + credentials_dict: Dict[str, Any] = {} + if litellm.vector_store_registry is not None: + credentials_dict = ( + litellm.vector_store_registry.get_credentials_for_vector_store( + knowledge_base_id + ) + ) + + credentials = self.get_credentials( + aws_access_key_id=credentials_dict.get( + "aws_access_key_id", non_default_params.get("aws_access_key_id", None) + ), + aws_secret_access_key=credentials_dict.get( + "aws_secret_access_key", + non_default_params.get("aws_secret_access_key", None), + ), + aws_session_token=credentials_dict.get( + "aws_session_token", non_default_params.get("aws_session_token", None) + ), + aws_region_name=credentials_dict.get( + "aws_region_name", non_default_params.get("aws_region_name", None) + ), + aws_session_name=credentials_dict.get( + "aws_session_name", non_default_params.get("aws_session_name", None) + ), + aws_profile_name=credentials_dict.get( + "aws_profile_name", non_default_params.get("aws_profile_name", None) + ), + aws_role_name=credentials_dict.get( + "aws_role_name", non_default_params.get("aws_role_name", None) + ), + aws_web_identity_token=credentials_dict.get( + "aws_web_identity_token", + non_default_params.get("aws_web_identity_token", None), + ), + aws_sts_endpoint=credentials_dict.get( + "aws_sts_endpoint", non_default_params.get("aws_sts_endpoint", None) + ), + ) + aws_region_name = self.get_aws_region_name_for_non_llm_api_calls( + aws_region_name=credentials_dict.get( + "aws_region_name", non_default_params.get("aws_region_name", None) + ), + ) + + # Prepare request data + request_data: BedrockKBRequest = BedrockKBRequest( + retrievalQuery=BedrockKBRetrievalQuery(text=query), + ) + if next_token: + request_data["nextToken"] = next_token + if retrieval_configuration: + request_data["retrievalConfiguration"] = retrieval_configuration + if guardrail_id and guardrail_version: + request_data["guardrailConfiguration"] = BedrockKBGuardrailConfiguration( + guardrailId=guardrail_id, guardrailVersion=guardrail_version + ) + verbose_logger.debug( + f"Request Data: {json.dumps(request_data, indent=4, default=str)}" + ) + + # Prepare the request + api_base = f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com/knowledgebases/{knowledge_base_id}/retrieve" + + prepared_request = self._prepare_request( + credentials=credentials, + data=request_data, + optional_params=self.optional_params, + aws_region_name=aws_region_name, + api_base=api_base, + ) + + verbose_proxy_logger.debug( + "Bedrock Knowledge Base request body: %s, url %s, headers: %s", + request_data, + prepared_request.url, + prepared_request.headers, + ) + + response = await self.async_handler.post( + url=prepared_request.url, + data=prepared_request.body, # type: ignore + headers=prepared_request.headers, # type: ignore + ) + + verbose_proxy_logger.debug("Bedrock Knowledge Base response: %s", response.text) + + if response.status_code == 200: + response_data = response.json() + return BedrockKBResponse(**response_data) + else: + verbose_proxy_logger.error( + "Bedrock Knowledge Base: error in response. Status code: %s, response: %s", + response.status_code, + response.text, + ) + raise HTTPException( + status_code=response.status_code, + detail={ + "error": "Error calling Bedrock Knowledge Base", + "response": response.text, + }, + ) + + @staticmethod + def get_initialized_custom_logger() -> Optional[CustomLogger]: + from litellm.litellm_core_utils.litellm_logging import ( + _init_custom_logger_compatible_class, + ) + + return _init_custom_logger_compatible_class( + logging_integration="bedrock_vector_store", + internal_usage_cache=None, + llm_router=None, + ) + + @staticmethod + def get_chat_completion_message_from_bedrock_kb_response( + response: BedrockKBResponse, + ) -> Tuple[Optional[ChatCompletionUserMessage], str]: + """ + Retrieves the context from the Bedrock Knowledge Base response and returns a ChatCompletionUserMessage object. + """ + retrieval_results: Optional[List[BedrockKBRetrievalResult]] = response.get( + "retrievalResults", None + ) + if retrieval_results is None: + return None, "" + + # string to combine the context from the knowledge base + context_string: str = BedrockVectorStore.CONTENT_PREFIX_STRING + for retrieval_result in retrieval_results: + retrieval_result_content: Optional[BedrockKBContent] = ( + retrieval_result.get("content", None) or {} + ) + if retrieval_result_content is None: + continue + retrieval_result_text: Optional[str] = retrieval_result_content.get( + "text", None + ) + if retrieval_result_text is None: + continue + context_string += retrieval_result_text + message = ChatCompletionUserMessage( + role="user", + content=context_string, + ) + return message, context_string diff --git a/litellm/integrations/vector_stores/bedrock_vector_store.py b/litellm/integrations/vector_stores/bedrock_vector_store.py index 0523dac8ed..a00acefb6a 100644 --- a/litellm/integrations/vector_stores/bedrock_vector_store.py +++ b/litellm/integrations/vector_stores/bedrock_vector_store.py @@ -12,7 +12,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import litellm from litellm._logging import verbose_logger, verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger -from litellm.integrations.vector_stores.base_vector_store import BaseVectorStore +from litellm.integrations.vector_store_integrations.base_vector_store import ( + BaseVectorStore, +) from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM from litellm.llms.custom_httpx.http_handler import ( get_async_httpx_client, diff --git a/litellm/litellm_core_utils/custom_logger_registry.py b/litellm/litellm_core_utils/custom_logger_registry.py index 20dc2b0c90..1b75cc3e3d 100644 --- a/litellm/litellm_core_utils/custom_logger_registry.py +++ b/litellm/litellm_core_utils/custom_logger_registry.py @@ -32,7 +32,9 @@ from litellm.integrations.opentelemetry import OpenTelemetry from litellm.integrations.opik.opik import OpikLogger from litellm.integrations.prometheus import PrometheusLogger from litellm.integrations.s3_v2 import S3Logger -from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore +from litellm.integrations.vector_store_integrations.bedrock_vector_store import ( + BedrockVectorStore, +) from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index c608dfc8d8..aa0280210d 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -54,7 +54,9 @@ from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.deepeval.deepeval import DeepEvalLogger from litellm.integrations.mlflow import MlflowLogger -from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore +from litellm.integrations.vector_store_integrations.bedrock_vector_store import ( + BedrockVectorStore, +) from litellm.litellm_core_utils.get_litellm_params import get_litellm_params from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( StandardBuiltInToolCostTracking, diff --git a/litellm/llms/base_llm/vector_store/transformation.py b/litellm/llms/base_llm/vector_store/transformation.py new file mode 100644 index 0000000000..4df63bd2f9 --- /dev/null +++ b/litellm/llms/base_llm/vector_store/transformation.py @@ -0,0 +1,73 @@ +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import httpx + +from litellm.types.router import GenericLiteLLMParams +from litellm.types.vector_stores import ( + VectorStoreSearchOptionalRequestParams, + VectorStoreSearchResponse, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + from ..chat.transformation import BaseLLMException as _BaseLLMException + + LiteLLMLoggingObj = _LiteLLMLoggingObj + BaseLLMException = _BaseLLMException +else: + LiteLLMLoggingObj = Any + BaseLLMException = Any + +class BaseVectorStoreConfig: + @abstractmethod + def transform_search_vector_store_request( + self, + vector_store_id: str, + query: Union[str, List[str]], + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams, + api_base: str, + ) -> Tuple[str, Dict]: + pass + + @abstractmethod + def transform_search_vector_store_response(self, response: httpx.Response) -> VectorStoreSearchResponse: + pass + + + @abstractmethod + def validate_environment( + self, headers: dict, litellm_params: Optional[GenericLiteLLMParams] + ) -> dict: + return {} + + @abstractmethod + def get_complete_url( + self, + api_base: Optional[str], + litellm_params: dict, + ) -> str: + """ + OPTIONAL + + Get the complete url for the request + + Some providers need `model` in `api_base` + """ + if api_base is None: + raise ValueError("api_base is required") + return api_base + + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + from ..chat.transformation import BaseLLMException + + raise BaseLLMException( + status_code=status_code, + message=error_message, + headers=headers, + ) + diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 4ed0016d8b..dfc5ae5277 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -35,6 +35,7 @@ from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig +from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -60,6 +61,10 @@ from litellm.types.rerank import OptionalRerankParams, RerankResponse from litellm.types.responses.main import DeleteResponseResult from litellm.types.router import GenericLiteLLMParams from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse +from litellm.types.vector_stores import ( + VectorStoreSearchOptionalRequestParams, + VectorStoreSearchResponse, +) from litellm.utils import ( CustomStreamWrapper, ImageResponse, @@ -2342,7 +2347,7 @@ class BaseLLMHTTPHandler: self, e: Exception, provider_config: Union[ - BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig, BaseImageEditConfig + BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig, BaseImageEditConfig, BaseVectorStoreConfig ], ): status_code = getattr(e, "status_code", 500) @@ -2613,3 +2618,143 @@ class BaseLLMHTTPHandler: raw_response=response, logging_obj=logging_obj, ) + + ###### VECTOR STORE HANDLER ###### + async def async_vector_store_search_handler( + self, + vector_store_id: str, + query: Union[str, List[str]], + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams, + vector_store_provider_config: BaseVectorStoreConfig, + custom_llm_provider: str, + litellm_params: GenericLiteLLMParams, + logging_obj: LiteLLMLoggingObj, + extra_headers: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + _is_async: bool = False, + ) -> VectorStoreSearchResponse: + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider), + params={"ssl_verify": litellm_params.get("ssl_verify", None)}, + ) + else: + async_httpx_client = client + + headers = vector_store_provider_config.validate_environment( + headers=extra_headers or {}, + litellm_params=litellm_params + ) + + if extra_headers: + headers.update(extra_headers) + + api_base = vector_store_provider_config.get_complete_url( + api_base=litellm_params.api_base, + litellm_params=dict(litellm_params), + ) + + url, request_body = vector_store_provider_config.transform_search_vector_store_request( + vector_store_id=vector_store_id, + query=query, + vector_store_search_optional_params=vector_store_search_optional_params, + api_base=api_base, + ) + + logging_obj.pre_call( + input="", + api_key="", + additional_args={ + "complete_input_dict": request_body, + "api_base": api_base, + "headers": headers, + }, + ) + + try: + response = await async_httpx_client.post(url=url, headers=headers, json=request_body, timeout=timeout) + except Exception as e: + raise self._handle_error(e=e, provider_config=vector_store_provider_config) + + return vector_store_provider_config.transform_search_vector_store_response( + response=response, + ) + + def vector_store_search_handler( + self, + vector_store_id: str, + query: Union[str, List[str]], + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams, + vector_store_provider_config: BaseVectorStoreConfig, + custom_llm_provider: str, + litellm_params: GenericLiteLLMParams, + logging_obj: LiteLLMLoggingObj, + extra_headers: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + _is_async: bool = False, + ) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]: + if _is_async: + return self.async_vector_store_search_handler( + vector_store_id=vector_store_id, + query=query, + vector_store_search_optional_params=vector_store_search_optional_params, + vector_store_provider_config=vector_store_provider_config, + litellm_params=litellm_params, + logging_obj=logging_obj, + custom_llm_provider=custom_llm_provider, + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout, + client=client, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client( + params={"ssl_verify": litellm_params.get("ssl_verify", None)} + ) + else: + sync_httpx_client = client + + headers = vector_store_provider_config.validate_environment( + headers=extra_headers or {}, + litellm_params=litellm_params + ) + + if extra_headers: + headers.update(extra_headers) + + api_base = vector_store_provider_config.get_complete_url( + api_base=litellm_params.api_base, + litellm_params=dict(litellm_params), + ) + + url, request_body = vector_store_provider_config.transform_search_vector_store_request( + vector_store_id=vector_store_id, + query=query, + vector_store_search_optional_params=vector_store_search_optional_params, + api_base=api_base, + ) + + logging_obj.pre_call( + input="", + api_key="", + additional_args={ + "complete_input_dict": request_body, + "api_base": api_base, + "headers": headers, + }, + ) + + try: + response = sync_httpx_client.post(url=url, headers=headers, json=request_body) + except Exception as e: + raise self._handle_error(e=e, provider_config=vector_store_provider_config) + + return vector_store_provider_config.transform_search_vector_store_response( + response=response, + ) + diff --git a/litellm/llms/openai/vector_stores/transformation.py b/litellm/llms/openai/vector_stores/transformation.py new file mode 100644 index 0000000000..9865782568 --- /dev/null +++ b/litellm/llms/openai/vector_stores/transformation.py @@ -0,0 +1,107 @@ +from typing import Dict, List, Optional, Tuple, Union, cast + +import httpx + +import litellm +from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.router import GenericLiteLLMParams +from litellm.types.vector_stores import ( + VectorStoreSearchOptionalRequestParams, + VectorStoreSearchRequest, + VectorStoreSearchResponse, +) + + +class OpenAIVectorStoreConfig(BaseVectorStoreConfig): + ASSISTANTS_HEADER_KEY = "OpenAI-Beta" + ASSISTANTS_HEADER_VALUE = "assistants=v2" + + def validate_environment( + self, headers: dict, litellm_params: Optional[GenericLiteLLMParams] + ) -> dict: + litellm_params = litellm_params or GenericLiteLLMParams() + api_key = ( + litellm_params.api_key + or litellm.api_key + or litellm.openai_key + or get_secret_str("OPENAI_API_KEY") + ) + headers.update( + { + "Authorization": f"Bearer {api_key}", + } + ) + + ######################################################### + # Ensure OpenAI Assistants header is includes + ######################################################### + if self.ASSISTANTS_HEADER_KEY not in headers: + headers.update( + { + self.ASSISTANTS_HEADER_KEY: self.ASSISTANTS_HEADER_VALUE, + } + ) + + return headers + + def get_complete_url( + self, + api_base: Optional[str], + litellm_params: dict, + ) -> str: + """ + Get the Base endpoint for OpenAI Vector Stores API + """ + api_base = ( + api_base + or litellm.api_base + or get_secret_str("OPENAI_BASE_URL") + or get_secret_str("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + + # Remove trailing slashes + api_base = api_base.rstrip("/") + + return f"{api_base}/vector_stores" + + + def transform_search_vector_store_request( + self, + vector_store_id: str, + query: Union[str, List[str]], + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams, + api_base: str, + ) -> Tuple[str, Dict]: + url = f"{api_base}/{vector_store_id}/search" + typed_request_body = VectorStoreSearchRequest( + query=query, + filters=vector_store_search_optional_params.get("filters", None), + max_num_results=vector_store_search_optional_params.get("max_num_results", None), + ranking_options=vector_store_search_optional_params.get("ranking_options", None), + rewrite_query=vector_store_search_optional_params.get("rewrite_query", None), + ) + + dict_request_body = cast(dict, typed_request_body) + return url, dict_request_body + + + + def transform_search_vector_store_response(self, response: httpx.Response) -> VectorStoreSearchResponse: + try: + response_json = response.json() + return VectorStoreSearchResponse( + **response_json + ) + except Exception as e: + raise self.get_error_class( + error_message=str(e), + status_code=response.status_code, + headers=response.headers + ) + + + + + \ No newline at end of file diff --git a/litellm/types/vector_stores.py b/litellm/types/vector_stores.py index cd8280ac96..0fcf7a2ab0 100644 --- a/litellm/types/vector_stores.py +++ b/litellm/types/vector_stores.py @@ -85,3 +85,14 @@ class VectorStoreSearchResponse(TypedDict, total=False): ] # Always "vector_store.search_results.page" search_query: Optional[str] data: Optional[List[VectorStoreSearchResult]] + +class VectorStoreSearchOptionalRequestParams(TypedDict, total=False): + """TypedDict for Optional parameters supported by the vector store search API.""" + filters: Optional[Dict] + max_num_results: Optional[int] + ranking_options: Optional[Dict] + rewrite_query: Optional[bool] + +class VectorStoreSearchRequest(VectorStoreSearchOptionalRequestParams, total=False): + """Request body for searching a vector store""" + query: Union[str, List[str]] \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index dacc168045..25f46d4bc5 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -78,7 +78,9 @@ from litellm.constants import ( ) from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger -from litellm.integrations.vector_stores.base_vector_store import BaseVectorStore +from litellm.integrations.vector_store_integrations.base_vector_store import ( + BaseVectorStore, +) from litellm.litellm_core_utils.core_helpers import ( map_finish_reason, process_response_headers, @@ -242,6 +244,7 @@ from litellm.llms.base_llm.image_variations.transformation import ( from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig +from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig from ._logging import _is_debugging_on, verbose_logger from .caching.caching import ( @@ -6902,13 +6905,28 @@ class ProviderConfigManager: def get_provider_vector_store_config( provider: LlmProviders, ) -> Optional[CustomLogger]: - from litellm.integrations.vector_stores.bedrock_vector_store import ( + from litellm.integrations.vector_store_integrations.bedrock_vector_store import ( BedrockVectorStore, ) if LlmProviders.BEDROCK == provider: return BedrockVectorStore.get_initialized_custom_logger() return None + + + @staticmethod + def get_provider_vector_stores_config( + provider: LlmProviders, + ) -> Optional[BaseVectorStoreConfig]: + """ + v2 vector store config, use this for new vector store integrations + """ + if litellm.LlmProviders.OPENAI == provider: + from litellm.llms.openai.vector_stores.transformation import ( + OpenAIVectorStoreConfig, + ) + return OpenAIVectorStoreConfig() + return None @staticmethod def get_provider_image_generation_config( diff --git a/litellm/vector_stores/__init__.py b/litellm/vector_stores/__init__.py new file mode 100644 index 0000000000..6546f9159e --- /dev/null +++ b/litellm/vector_stores/__init__.py @@ -0,0 +1,4 @@ +from .main import asearch, search +from .vector_store_registry import VectorStoreRegistry + +__all__ = ["search", "asearch", "VectorStoreRegistry"] diff --git a/litellm/vector_stores/main.py b/litellm/vector_stores/main.py new file mode 100644 index 0000000000..245a3f6fa3 --- /dev/null +++ b/litellm/vector_stores/main.py @@ -0,0 +1,232 @@ +""" +LiteLLM SDK Functions for Creating and Searching Vector Stores +""" +import asyncio +import contextvars +from functools import partial +from typing import Any, Coroutine, Dict, List, Optional, Union + +import httpx + +import litellm +from litellm.constants import request_timeout +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler +from litellm.types.router import GenericLiteLLMParams +from litellm.types.vector_stores import ( + VectorStoreResultContent, + VectorStoreSearchOptionalRequestParams, + VectorStoreSearchResponse, + VectorStoreSearchResult, +) +from litellm.utils import ProviderConfigManager, client +from litellm.vector_stores.utils import VectorStoreRequestUtils + +####### ENVIRONMENT VARIABLES ################### +# Initialize any necessary instances or variables here +base_llm_http_handler = BaseLLMHTTPHandler() +################################################# + + +def mock_vector_store_search_response( + mock_results: Optional[List[VectorStoreSearchResult]] = None, +): + """Mock response for vector store search""" + if mock_results is None: + mock_results = [ + VectorStoreSearchResult( + score=0.95, + content=[ + VectorStoreResultContent( + text="This is a sample search result from the vector store.", + type="text" + ) + ] + ) + ] + + return VectorStoreSearchResponse( + object="vector_store.search_results.page", + search_query="sample query", + data=mock_results, + ) + + +@client +async def asearch( + vector_store_id: str, + query: Union[str, List[str]], + filters: Optional[Dict] = None, + max_num_results: Optional[int] = None, + ranking_options: Optional[Dict] = None, + rewrite_query: Optional[bool] = None, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Optional[Dict[str, Any]] = None, + extra_query: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + # LiteLLM specific params, + custom_llm_provider: Optional[str] = None, + **kwargs, +) -> VectorStoreSearchResponse: + """ + Async: Search a vector store for relevant chunks based on a query and file attributes filter. + """ + local_vars = locals() + try: + loop = asyncio.get_event_loop() + kwargs["asearch"] = True + + # get custom llm provider so we can use this for mapping exceptions + if custom_llm_provider is None: + custom_llm_provider = "openai" # Default to OpenAI for vector stores + + func = partial( + search, + vector_store_id=vector_store_id, + query=query, + filters=filters, + max_num_results=max_num_results, + ranking_options=ranking_options, + rewrite_query=rewrite_query, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + custom_llm_provider=custom_llm_provider, + **kwargs, + ) + + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response + + return response + except Exception as e: + raise litellm.exception_type( + model=None, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=local_vars, + extra_kwargs=kwargs, + ) + + +@client +def search( + vector_store_id: str, + query: Union[str, List[str]], + filters: Optional[Dict] = None, + max_num_results: Optional[int] = None, + ranking_options: Optional[Dict] = None, + rewrite_query: Optional[bool] = None, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Optional[Dict[str, Any]] = None, + extra_query: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + # LiteLLM specific params, + custom_llm_provider: Optional[str] = None, + **kwargs, +) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]: + """ + Search a vector store for relevant chunks based on a query and file attributes filter. + + Args: + vector_store_id: The ID of the vector store to search. + query: A query string or array for the search. + filters: Optional filter to apply based on file attributes. + max_num_results: Maximum number of results to return (1-50, default 10). + ranking_options: Optional ranking options for search. + rewrite_query: Whether to rewrite the natural language query for vector search. + + Returns: + VectorStoreSearchResponse containing the search results. + """ + local_vars = locals() + try: + litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore + litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None) + _is_async = kwargs.pop("asearch", False) is True + + # get llm provider logic + litellm_params = GenericLiteLLMParams(**kwargs) + + ## MOCK RESPONSE LOGIC + if litellm_params.mock_response and isinstance( + litellm_params.mock_response, (str, list) + ): + mock_results = None + if isinstance(litellm_params.mock_response, list): + mock_results = litellm_params.mock_response + return mock_vector_store_search_response(mock_results=mock_results) + + # Default to OpenAI for vector stores + if custom_llm_provider is None: + custom_llm_provider = "openai" + + # get provider config - using vector store custom logger for now + vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) + + if vector_store_provider_config is None: + raise ValueError( + f"Vector store search is not supported for {custom_llm_provider}" + ) + + local_vars.update(kwargs) + + # Get VectorStoreSearchOptionalRequestParams with only valid parameters + vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams = ( + VectorStoreRequestUtils.get_requested_vector_store_search_optional_param( + local_vars + ) + ) + + # Pre Call logging + litellm_logging_obj.update_environment_variables( + model=None, + optional_params={ + "vector_store_id": vector_store_id, + "query": query, + **vector_store_search_optional_params, + }, + litellm_params={ + "litellm_call_id": litellm_call_id, + "vector_store_id": vector_store_id, + }, + custom_llm_provider=custom_llm_provider, + ) + + response = base_llm_http_handler.vector_store_search_handler( + vector_store_id=vector_store_id, + query=query, + vector_store_search_optional_params=vector_store_search_optional_params, + vector_store_provider_config=vector_store_provider_config, + custom_llm_provider=custom_llm_provider, + litellm_params=litellm_params, + logging_obj=litellm_logging_obj, + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout or request_timeout, + _is_async=_is_async, + client=kwargs.get("client"), + ) + + return response + except Exception as e: + raise litellm.exception_type( + model=None, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=local_vars, + extra_kwargs=kwargs, + ) \ No newline at end of file diff --git a/litellm/vector_stores/utils.py b/litellm/vector_stores/utils.py new file mode 100644 index 0000000000..7817526d00 --- /dev/null +++ b/litellm/vector_stores/utils.py @@ -0,0 +1,28 @@ +from typing import Any, Dict, cast, get_type_hints + +from litellm.types.vector_stores import VectorStoreSearchOptionalRequestParams + + +class VectorStoreRequestUtils: + """Helper utils for constructing Vector Store search requests""" + + @staticmethod + def get_requested_vector_store_search_optional_param( + params: Dict[str, Any], + ) -> VectorStoreSearchOptionalRequestParams: + """ + Filter parameters to only include those defined in VectorStoreSearchOptionalRequestParams. + + Args: + params: Dictionary of parameters to filter + + Returns: + VectorStoreSearchOptionalRequestParams instance with only the valid parameters + """ + valid_keys = get_type_hints(VectorStoreSearchOptionalRequestParams).keys() + filtered_params = { + k: v for k, v in params.items() if k in valid_keys and v is not None + } + + return cast(VectorStoreSearchOptionalRequestParams, filtered_params) + diff --git a/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py b/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py index 532a5c11c5..4c75d48527 100644 --- a/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py +++ b/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py @@ -19,7 +19,7 @@ import pytest import litellm from litellm import completion from litellm._logging import verbose_logger -from litellm.integrations.vector_stores.bedrock_vector_store import BedrockVectorStore +from litellm.integrations.vector_store_integrations.bedrock_vector_store import BedrockVectorStore from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import StandardLoggingPayload, StandardLoggingVectorStoreRequest diff --git a/tests/vector_store_tests/base_vector_store_test.py b/tests/vector_store_tests/base_vector_store_test.py new file mode 100644 index 0000000000..d23012847b --- /dev/null +++ b/tests/vector_store_tests/base_vector_store_test.py @@ -0,0 +1,136 @@ +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, Mock, patch +import os +import uuid +import time +import base64 + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from abc import ABC, abstractmethod +from litellm.integrations.custom_logger import CustomLogger +import json +from litellm.types.utils import StandardLoggingPayload + +class BaseVectorStoreTest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. + """ + @abstractmethod + def get_base_request_args(self) -> dict: + """Must return the base request args""" + pass + + @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.asyncio + async def test_basic_search_vector_store(self, sync_mode): + litellm._turn_on_debug() + litellm.set_verbose = True + base_request_args = self.get_base_request_args() + try: + if sync_mode: + response = litellm.vector_stores.search( + query="Basic ping", + **base_request_args + ) + else: + response = await litellm.vector_stores.asearch( + query="Basic ping", + **base_request_args + ) + except litellm.InternalServerError: + pytest.skip("Skipping test due to litellm.InternalServerError") + + print("litellm response=", json.dumps(response, indent=4, default=str)) + + # Validate response structure + self._validate_vector_store_response(response) + + def _validate_vector_store_response(self, response): + """Validate the structure and content of a vector store search response""" + + # Check that response is a dictionary + assert isinstance(response, dict), f"Response should be a dict, got {type(response)}" + + # Check required top-level fields + required_fields = ['object', 'search_query', 'data'] + for field in required_fields: + assert field in response, f"Missing required field '{field}' in response" + + # Validate object field + assert response['object'] == 'vector_store.search_results.page', \ + f"Expected object to be 'vector_store.search_results.page', got '{response['object']}'" + + # Validate search_query field + assert isinstance(response['search_query'], list), \ + f"search_query should be a list, got {type(response['search_query'])}" + assert len(response['search_query']) > 0, "search_query should not be empty" + assert all(isinstance(query, str) for query in response['search_query']), \ + "All items in search_query should be strings" + + # Validate data field + assert isinstance(response['data'], list), \ + f"data should be a list, got {type(response['data'])}" + + # Validate each result in data + for i, result in enumerate(response['data']): + self._validate_search_result(result, i) + + print(f"✅ Response validation passed: Found {len(response['data'])} search results") + + def _validate_search_result(self, result, index): + """Validate an individual search result""" + + # Check that result is a dictionary + assert isinstance(result, dict), f"Result {index} should be a dict, got {type(result)}" + + # Check required fields in each result + required_result_fields = ['file_id', 'filename', 'score', 'attributes', 'content'] + for field in required_result_fields: + assert field in result, f"Missing required field '{field}' in result {index}" + + # Validate file_id + assert isinstance(result['file_id'], str), \ + f"file_id should be a string, got {type(result['file_id'])} in result {index}" + assert len(result['file_id']) > 0, f"file_id should not be empty in result {index}" + + # Validate filename + assert isinstance(result['filename'], str), \ + f"filename should be a string, got {type(result['filename'])} in result {index}" + assert len(result['filename']) > 0, f"filename should not be empty in result {index}" + + # Validate score + assert isinstance(result['score'], (int, float)), \ + f"score should be a number, got {type(result['score'])} in result {index}" + assert 0.0 <= result['score'] <= 1.0, \ + f"score should be between 0.0 and 1.0, got {result['score']} in result {index}" + + # Validate attributes + assert isinstance(result['attributes'], dict), \ + f"attributes should be a dict, got {type(result['attributes'])} in result {index}" + + # Validate content + assert isinstance(result['content'], list), \ + f"content should be a list, got {type(result['content'])} in result {index}" + assert len(result['content']) > 0, f"content should not be empty in result {index}" + + # Validate each content item + for j, content_item in enumerate(result['content']): + assert isinstance(content_item, dict), \ + f"Content item {j} in result {index} should be a dict, got {type(content_item)}" + assert 'type' in content_item, \ + f"Content item {j} in result {index} missing 'type' field" + assert 'text' in content_item, \ + f"Content item {j} in result {index} missing 'text' field" + assert isinstance(content_item['text'], str), \ + f"Content text should be a string in item {j} of result {index}" + assert len(content_item['text']) > 0, \ + f"Content text should not be empty in item {j} of result {index}" + + print(f"✅ Result {index} validation passed: {result['filename']} (score: {result['score']:.4f})") diff --git a/tests/vector_store_tests/conftest.py b/tests/vector_store_tests/conftest.py new file mode 100644 index 0000000000..b3561d8a62 --- /dev/null +++ b/tests/vector_store_tests/conftest.py @@ -0,0 +1,63 @@ +# conftest.py + +import importlib +import os +import sys + +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm + + +@pytest.fixture(scope="function", autouse=True) +def setup_and_teardown(): + """ + This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. + """ + curr_dir = os.getcwd() # Get the current working directory + sys.path.insert( + 0, os.path.abspath("../..") + ) # Adds the project directory to the system path + + import litellm + from litellm import Router + + importlib.reload(litellm) + + try: + if hasattr(litellm, "proxy") and hasattr(litellm.proxy, "proxy_server"): + import litellm.proxy.proxy_server + + importlib.reload(litellm.proxy.proxy_server) + except Exception as e: + print(f"Error reloading litellm.proxy.proxy_server: {e}") + + import asyncio + + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) + print(litellm) + # from litellm import Router, completion, aembedding, acompletion, embedding + yield + + # Teardown code (executes after the yield point) + loop.close() # Close the loop created earlier + asyncio.set_event_loop(None) # Remove the reference to the loop + + +def pytest_collection_modifyitems(config, items): + # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests + custom_logger_tests = [ + item for item in items if "custom_logger" in item.parent.name + ] + other_tests = [item for item in items if "custom_logger" not in item.parent.name] + + # Sort tests based on their names + custom_logger_tests.sort(key=lambda x: x.name) + other_tests.sort(key=lambda x: x.name) + + # Reorder the items list + items[:] = custom_logger_tests + other_tests diff --git a/tests/vector_store_tests/test_openai_vector_store.py b/tests/vector_store_tests/test_openai_vector_store.py new file mode 100644 index 0000000000..20980f2ff5 --- /dev/null +++ b/tests/vector_store_tests/test_openai_vector_store.py @@ -0,0 +1,11 @@ +from base_vector_store_test import BaseVectorStoreTest + +class TestOpenAIVectorStore(BaseVectorStoreTest): + def get_base_request_args(self) -> dict: + """ + This is a real vector store on OpenAI + """ + return { + "vector_store_id": "vs_685b14b1a1b88191bc27e04f1917fddd", + "custom_llm_provider": "openai", + } \ No newline at end of file From d6cc3847808dabbee472943f9977a1be25ab99fe Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 24 Jun 2025 20:46:48 -0700 Subject: [PATCH 07/12] [Feat] OpenAI/Azure OpenAI - Add support for creating vector stores on LiteLLM (#12021) * add create/acreate vector store * add azure config * add _base_validate_azure_environment * fix base test * add get_base_create_vector_store_args * use base llm for headers responses api * add _get_base_azure_url * fix AzureOpenAIVectorStoreConfig * TestAzureOpenAIVectorStore * fix azure openai vector store * fix test comment * fix unused imports * test_validate_environment_azure_api_key_within_secret_str * test_azure_transformation.py --- .../docs/completion/knowledgebase.md | 1 + litellm/llms/azure/common_utils.py | 80 ++++++- .../llms/azure/responses/transformation.py | 76 +------ .../azure/vector_stores/transformation.py | 27 +++ .../base_llm/vector_store/transformation.py | 13 ++ litellm/llms/custom_httpx/llm_http_handler.py | 130 +++++++++++ .../openai/vector_stores/transformation.py | 33 +++ litellm/types/vector_stores.py | 71 +++++- litellm/utils.py | 5 + litellm/vector_stores/__init__.py | 4 +- litellm/vector_stores/main.py | 202 ++++++++++++++++++ litellm/vector_stores/utils.py | 25 ++- .../response/test_azure_transformation.py | 26 ++- .../base_vector_store_test.py | 119 +++++++++++ .../test_azure_vector_store.py | 25 +++ .../test_openai_vector_store.py | 9 + 16 files changed, 764 insertions(+), 82 deletions(-) create mode 100644 litellm/llms/azure/vector_stores/transformation.py create mode 100644 tests/vector_store_tests/test_azure_vector_store.py diff --git a/docs/my-website/docs/completion/knowledgebase.md b/docs/my-website/docs/completion/knowledgebase.md index b3e6a06aa9..a1c926274c 100644 --- a/docs/my-website/docs/completion/knowledgebase.md +++ b/docs/my-website/docs/completion/knowledgebase.md @@ -18,6 +18,7 @@ LiteLLM integrates with vector stores, allowing your models to access your organ ## Supported Vector Stores - [Bedrock Knowledge Bases](https://aws.amazon.com/bedrock/knowledge-bases/) - [OpenAI Vector Stores](https://platform.openai.com/docs/api-reference/vector-stores/search) +- [Azure Vector Stores](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/file-search?tabs=python#vector-stores ## Quick Start diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index ed62175289..f2a8defe13 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -1,12 +1,11 @@ import json import os -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Literal, Optional, Union, cast import httpx from openai import AsyncAzureOpenAI, AzureOpenAI import litellm -from litellm.types.router import GenericLiteLLMParams from litellm._logging import verbose_logger from litellm.caching.caching import DualCache from litellm.llms.base_llm.chat.transformation import BaseLLMException @@ -15,6 +14,8 @@ from litellm.secret_managers.get_azure_ad_token_provider import ( get_azure_ad_token_provider, ) from litellm.secret_managers.main import get_secret_str +from litellm.types.router import GenericLiteLLMParams +from litellm.utils import _add_path_to_api_base azure_ad_cache = DualCache() @@ -613,3 +614,78 @@ class BaseAzureLLM(BaseOpenAILLM): else: client = AzureOpenAI(**azure_client_params) # type: ignore return client + + @staticmethod + def _base_validate_azure_environment( + headers: dict, litellm_params: Optional[GenericLiteLLMParams] + ) -> dict: + litellm_params = litellm_params or GenericLiteLLMParams() + api_key = ( + litellm_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret_str("AZURE_OPENAI_API_KEY") + or get_secret_str("AZURE_API_KEY") + ) + + if api_key: + headers["api-key"] = api_key + return headers + + ### Fallback to Azure AD token-based authentication if no API key is available + ### Retrieves Azure AD token and adds it to the Authorization header + azure_ad_token = get_azure_ad_token(litellm_params) + if azure_ad_token: + headers["Authorization"] = f"Bearer {azure_ad_token}" + + return headers + + @staticmethod + def _get_base_azure_url( + api_base: Optional[str], + litellm_params: Optional[Union[GenericLiteLLMParams, Dict[str, Any]]], + route: Literal["/openai/responses", "/openai/vector_stores"] + ) -> str: + api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") + if api_base is None: + raise ValueError( + f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`" + ) + original_url = httpx.URL(api_base) + + # Extract api_version or use default + litellm_params = litellm_params or {} + api_version = cast(Optional[str], litellm_params.get("api_version")) + + # Create a new dictionary with existing params + query_params = dict(original_url.params) + + # Add api_version if needed + if "api-version" not in query_params and api_version: + query_params["api-version"] = api_version + + # Add the path to the base URL + if route not in api_base: + new_url = _add_path_to_api_base( + api_base=api_base, ending_path=route + ) + else: + new_url = api_base + + if BaseAzureLLM._is_azure_v1_api_version(api_version): + # ensure the request go to /openai/v1 and not just /openai + if "/openai/v1" not in new_url: + parsed_url = httpx.URL(new_url) + new_url = str(parsed_url.copy_with(path=parsed_url.path.replace("/openai", "/openai/v1"))) + + + # Use the new query_params dictionary + final_url = httpx.URL(new_url).copy_with(params=query_params) + + return str(final_url) + + @staticmethod + def _is_azure_v1_api_version(api_version: Optional[str]) -> bool: + if api_version is None: + return False + return api_version == "preview" or api_version == "latest" diff --git a/litellm/llms/azure/responses/transformation.py b/litellm/llms/azure/responses/transformation.py index d5425ef439..e6f48179e4 100644 --- a/litellm/llms/azure/responses/transformation.py +++ b/litellm/llms/azure/responses/transformation.py @@ -1,16 +1,11 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple -import httpx - -import litellm from litellm._logging import verbose_logger -from litellm.llms.azure.common_utils import get_azure_ad_token +from litellm.llms.azure.common_utils import BaseAzureLLM from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig -from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import * from litellm.types.responses.main import * from litellm.types.router import GenericLiteLLMParams -from litellm.utils import _add_path_to_api_base if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj @@ -24,27 +19,11 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig): def validate_environment( self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams] ) -> dict: - litellm_params = litellm_params or GenericLiteLLMParams() - api_key = ( - litellm_params.api_key - or litellm.api_key - or litellm.azure_key - or get_secret_str("AZURE_OPENAI_API_KEY") - or get_secret_str("AZURE_API_KEY") + return BaseAzureLLM._base_validate_azure_environment( + headers=headers, + litellm_params=litellm_params ) - if api_key: - headers["api-key"] = api_key - return headers - - ### Fallback to Azure AD token-based authentication if no API key is available - ### Retrieves Azure AD token and adds it to the Authorization header - azure_ad_token = get_azure_ad_token(litellm_params) - if azure_ad_token: - headers["Authorization"] = f"Bearer {azure_ad_token}" - - return headers - def get_complete_url( self, api_base: Optional[str], @@ -66,47 +45,12 @@ class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig): - A complete URL string, e.g., "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview" """ - api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") - if api_base is None: - raise ValueError( - f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`" - ) - original_url = httpx.URL(api_base) - - # Extract api_version or use default - api_version = cast(Optional[str], litellm_params.get("api_version")) - - # Create a new dictionary with existing params - query_params = dict(original_url.params) - - # Add api_version if needed - if "api-version" not in query_params and api_version: - query_params["api-version"] = api_version - - # Add the path to the base URL - if "/openai/responses" not in api_base: - new_url = _add_path_to_api_base( - api_base=api_base, ending_path="/openai/responses" - ) - else: - new_url = api_base - - if self._is_azure_v1_api_version(api_version): - # ensure the request go to /openai/v1 and not just /openai - if "/openai/v1" not in new_url: - parsed_url = httpx.URL(new_url) - new_url = str(parsed_url.copy_with(path=parsed_url.path.replace("/openai", "/openai/v1"))) - - - # Use the new query_params dictionary - final_url = httpx.URL(new_url).copy_with(params=query_params) - - return str(final_url) + return BaseAzureLLM._get_base_azure_url( + api_base=api_base, + litellm_params=litellm_params, + route="/openai/responses" + ) - def _is_azure_v1_api_version(self, api_version: Optional[str]) -> bool: - if api_version is None: - return False - return api_version == "preview" or api_version == "latest" ######################################################### ########## DELETE RESPONSE API TRANSFORMATION ############## diff --git a/litellm/llms/azure/vector_stores/transformation.py b/litellm/llms/azure/vector_stores/transformation.py new file mode 100644 index 0000000000..f1cd81b2bf --- /dev/null +++ b/litellm/llms/azure/vector_stores/transformation.py @@ -0,0 +1,27 @@ +from typing import Optional + +from litellm.llms.azure.common_utils import BaseAzureLLM +from litellm.llms.openai.vector_stores.transformation import OpenAIVectorStoreConfig +from litellm.types.router import GenericLiteLLMParams + + +class AzureOpenAIVectorStoreConfig(OpenAIVectorStoreConfig): + def get_complete_url( + self, + api_base: Optional[str], + litellm_params: dict, + ) -> str: + return BaseAzureLLM._get_base_azure_url( + api_base=api_base, + litellm_params=litellm_params, + route="/openai/vector_stores" + ) + + + def validate_environment( + self, headers: dict, litellm_params: Optional[GenericLiteLLMParams] + ) -> dict: + return BaseAzureLLM._base_validate_azure_environment( + headers=headers, + litellm_params=litellm_params + ) \ No newline at end of file diff --git a/litellm/llms/base_llm/vector_store/transformation.py b/litellm/llms/base_llm/vector_store/transformation.py index 4df63bd2f9..6ec9d25ae5 100644 --- a/litellm/llms/base_llm/vector_store/transformation.py +++ b/litellm/llms/base_llm/vector_store/transformation.py @@ -5,6 +5,8 @@ import httpx from litellm.types.router import GenericLiteLLMParams from litellm.types.vector_stores import ( + VectorStoreCreateOptionalRequestParams, + VectorStoreCreateResponse, VectorStoreSearchOptionalRequestParams, VectorStoreSearchResponse, ) @@ -35,6 +37,17 @@ class BaseVectorStoreConfig: def transform_search_vector_store_response(self, response: httpx.Response) -> VectorStoreSearchResponse: pass + @abstractmethod + def transform_create_vector_store_request( + self, + vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams, + api_base: str, + ) -> Tuple[str, Dict]: + pass + + @abstractmethod + def transform_create_vector_store_response(self, response: httpx.Response) -> VectorStoreCreateResponse: + pass @abstractmethod def validate_environment( diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index dfc5ae5277..92f9d6d958 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -62,6 +62,8 @@ from litellm.types.responses.main import DeleteResponseResult from litellm.types.router import GenericLiteLLMParams from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse from litellm.types.vector_stores import ( + VectorStoreCreateOptionalRequestParams, + VectorStoreCreateResponse, VectorStoreSearchOptionalRequestParams, VectorStoreSearchResponse, ) @@ -2758,3 +2760,131 @@ class BaseLLMHTTPHandler: response=response, ) + async def async_vector_store_create_handler( + self, + vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams, + vector_store_provider_config: BaseVectorStoreConfig, + custom_llm_provider: str, + litellm_params: GenericLiteLLMParams, + logging_obj: LiteLLMLoggingObj, + extra_headers: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + _is_async: bool = False, + ) -> VectorStoreCreateResponse: + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider), + params={"ssl_verify": litellm_params.get("ssl_verify", None)}, + ) + else: + async_httpx_client = client + + headers = vector_store_provider_config.validate_environment( + headers=extra_headers or {}, + litellm_params=litellm_params + ) + + if extra_headers: + headers.update(extra_headers) + + api_base = vector_store_provider_config.get_complete_url( + api_base=litellm_params.api_base, + litellm_params=dict(litellm_params), + ) + + url, request_body = vector_store_provider_config.transform_create_vector_store_request( + vector_store_create_optional_params=vector_store_create_optional_params, + api_base=api_base, + ) + + logging_obj.pre_call( + input="", + api_key="", + additional_args={ + "complete_input_dict": request_body, + "api_base": api_base, + "headers": headers, + }, + ) + + try: + response = await async_httpx_client.post(url=url, headers=headers, json=request_body, timeout=timeout) + except Exception as e: + raise self._handle_error(e=e, provider_config=vector_store_provider_config) + + return vector_store_provider_config.transform_create_vector_store_response( + response=response, + ) + + def vector_store_create_handler( + self, + vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams, + vector_store_provider_config: BaseVectorStoreConfig, + custom_llm_provider: str, + litellm_params: GenericLiteLLMParams, + logging_obj: LiteLLMLoggingObj, + extra_headers: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + _is_async: bool = False, + ) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]: + if _is_async: + return self.async_vector_store_create_handler( + vector_store_create_optional_params=vector_store_create_optional_params, + vector_store_provider_config=vector_store_provider_config, + litellm_params=litellm_params, + logging_obj=logging_obj, + custom_llm_provider=custom_llm_provider, + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout, + client=client, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client( + params={"ssl_verify": litellm_params.get("ssl_verify", None)} + ) + else: + sync_httpx_client = client + + headers = vector_store_provider_config.validate_environment( + headers=extra_headers or {}, + litellm_params=litellm_params + ) + + if extra_headers: + headers.update(extra_headers) + + api_base = vector_store_provider_config.get_complete_url( + api_base=litellm_params.api_base, + litellm_params=dict(litellm_params), + ) + + url, request_body = vector_store_provider_config.transform_create_vector_store_request( + vector_store_create_optional_params=vector_store_create_optional_params, + api_base=api_base, + ) + + logging_obj.pre_call( + input="", + api_key="", + additional_args={ + "complete_input_dict": request_body, + "api_base": api_base, + "headers": headers, + }, + ) + + try: + response = sync_httpx_client.post(url=url, headers=headers, json=request_body) + except Exception as e: + raise self._handle_error(e=e, provider_config=vector_store_provider_config) + + return vector_store_provider_config.transform_create_vector_store_response( + response=response, + ) + diff --git a/litellm/llms/openai/vector_stores/transformation.py b/litellm/llms/openai/vector_stores/transformation.py index 9865782568..11d76937ab 100644 --- a/litellm/llms/openai/vector_stores/transformation.py +++ b/litellm/llms/openai/vector_stores/transformation.py @@ -7,6 +7,9 @@ from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreCon from litellm.secret_managers.main import get_secret_str from litellm.types.router import GenericLiteLLMParams from litellm.types.vector_stores import ( + VectorStoreCreateOptionalRequestParams, + VectorStoreCreateRequest, + VectorStoreCreateResponse, VectorStoreSearchOptionalRequestParams, VectorStoreSearchRequest, VectorStoreSearchResponse, @@ -101,6 +104,36 @@ class OpenAIVectorStoreConfig(BaseVectorStoreConfig): headers=response.headers ) + def transform_create_vector_store_request( + self, + vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams, + api_base: str, + ) -> Tuple[str, Dict]: + url = api_base # Base URL for creating vector stores + typed_request_body = VectorStoreCreateRequest( + name=vector_store_create_optional_params.get("name", None), + file_ids=vector_store_create_optional_params.get("file_ids", None), + expires_after=vector_store_create_optional_params.get("expires_after", None), + chunking_strategy=vector_store_create_optional_params.get("chunking_strategy", None), + metadata=vector_store_create_optional_params.get("metadata", None), + ) + + dict_request_body = cast(dict, typed_request_body) + return url, dict_request_body + + def transform_create_vector_store_response(self, response: httpx.Response) -> VectorStoreCreateResponse: + try: + response_json = response.json() + return VectorStoreCreateResponse( + **response_json + ) + except Exception as e: + raise self.get_error_class( + error_message=str(e), + status_code=response.status_code, + headers=response.headers + ) + diff --git a/litellm/types/vector_stores.py b/litellm/types/vector_stores.py index 0fcf7a2ab0..3d7190d7f9 100644 --- a/litellm/types/vector_stores.py +++ b/litellm/types/vector_stores.py @@ -95,4 +95,73 @@ class VectorStoreSearchOptionalRequestParams(TypedDict, total=False): class VectorStoreSearchRequest(VectorStoreSearchOptionalRequestParams, total=False): """Request body for searching a vector store""" - query: Union[str, List[str]] \ No newline at end of file + query: Union[str, List[str]] + + +# Vector Store Creation Types +class VectorStoreExpirationPolicy(TypedDict, total=False): + """The expiration policy for a vector store""" + anchor: Literal["last_active_at"] # Anchor timestamp after which the expiration policy applies + days: int # Number of days after anchor time that the vector store will expire + + +class VectorStoreAutoChunkingStrategy(TypedDict, total=False): + """Auto chunking strategy configuration""" + type: Literal["auto"] # Always "auto" + + +class VectorStoreStaticChunkingStrategyConfig(TypedDict, total=False): + """Static chunking strategy configuration""" + max_chunk_size_tokens: int # Maximum number of tokens per chunk + chunk_overlap_tokens: int # Number of tokens to overlap between chunks + + +class VectorStoreStaticChunkingStrategy(TypedDict, total=False): + """Static chunking strategy""" + type: Literal["static"] # Always "static" + static: VectorStoreStaticChunkingStrategyConfig + + +class VectorStoreChunkingStrategy(TypedDict, total=False): + """Union type for chunking strategies""" + # This can be either auto or static + type: Literal["auto", "static"] + static: Optional[VectorStoreStaticChunkingStrategyConfig] + + +class VectorStoreFileCounts(TypedDict, total=False): + """File counts for a vector store""" + in_progress: int + completed: int + failed: int + cancelled: int + total: int + + +class VectorStoreCreateOptionalRequestParams(TypedDict, total=False): + """TypedDict for Optional parameters supported by the vector store create API.""" + name: Optional[str] # Name of the vector store + file_ids: Optional[List[str]] # List of File IDs that the vector store should use + expires_after: Optional[VectorStoreExpirationPolicy] # Expiration policy for the vector store + chunking_strategy: Optional[VectorStoreChunkingStrategy] # Chunking strategy for the files + metadata: Optional[Dict[str, str]] # Set of key-value pairs for metadata + + +class VectorStoreCreateRequest(VectorStoreCreateOptionalRequestParams, total=False): + """Request body for creating a vector store""" + pass # All fields are optional for vector store creation + + +class VectorStoreCreateResponse(TypedDict, total=False): + """Response after creating a vector store""" + id: str # ID of the vector store + object: Literal["vector_store"] # Always "vector_store" + created_at: int # Unix timestamp of when the vector store was created + name: Optional[str] # Name of the vector store + bytes: int # Size of the vector store in bytes + file_counts: VectorStoreFileCounts # File counts for the vector store + status: Literal["expired", "in_progress", "completed"] # Status of the vector store + expires_after: Optional[VectorStoreExpirationPolicy] # Expiration policy + expires_at: Optional[int] # Unix timestamp of when the vector store expires + last_active_at: Optional[int] # Unix timestamp of when the vector store was last active + metadata: Optional[Dict[str, str]] # Metadata associated with the vector store \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index 25f46d4bc5..05328c3750 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6926,6 +6926,11 @@ class ProviderConfigManager: OpenAIVectorStoreConfig, ) return OpenAIVectorStoreConfig() + elif litellm.LlmProviders.AZURE == provider: + from litellm.llms.azure.vector_stores.transformation import ( + AzureOpenAIVectorStoreConfig, + ) + return AzureOpenAIVectorStoreConfig() return None @staticmethod diff --git a/litellm/vector_stores/__init__.py b/litellm/vector_stores/__init__.py index 6546f9159e..6bcc654032 100644 --- a/litellm/vector_stores/__init__.py +++ b/litellm/vector_stores/__init__.py @@ -1,4 +1,4 @@ -from .main import asearch, search +from .main import acreate, asearch, create, search from .vector_store_registry import VectorStoreRegistry -__all__ = ["search", "asearch", "VectorStoreRegistry"] +__all__ = ["search", "asearch", "create", "acreate", "VectorStoreRegistry"] diff --git a/litellm/vector_stores/main.py b/litellm/vector_stores/main.py index 245a3f6fa3..80d6341146 100644 --- a/litellm/vector_stores/main.py +++ b/litellm/vector_stores/main.py @@ -14,6 +14,9 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler from litellm.types.router import GenericLiteLLMParams from litellm.types.vector_stores import ( + VectorStoreCreateOptionalRequestParams, + VectorStoreCreateResponse, + VectorStoreFileCounts, VectorStoreResultContent, VectorStoreSearchOptionalRequestParams, VectorStoreSearchResponse, @@ -52,6 +55,205 @@ def mock_vector_store_search_response( ) +def mock_vector_store_create_response( + mock_response: Optional[VectorStoreCreateResponse] = None, +): + """Mock response for vector store create""" + if mock_response is None: + mock_response = VectorStoreCreateResponse( + id="vs_mock123", + object="vector_store", + created_at=1699061776, + name="Mock Vector Store", + bytes=0, + file_counts=VectorStoreFileCounts( + in_progress=0, + completed=0, + failed=0, + cancelled=0, + total=0, + ), + status="completed", + expires_after=None, + expires_at=None, + last_active_at=None, + metadata=None, + ) + + return mock_response + + +@client +async def acreate( + name: Optional[str] = None, + file_ids: Optional[List[str]] = None, + expires_after: Optional[Dict] = None, + chunking_strategy: Optional[Dict] = None, + metadata: Optional[Dict[str, str]] = None, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Optional[Dict[str, Any]] = None, + extra_query: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + # LiteLLM specific params, + custom_llm_provider: Optional[str] = None, + **kwargs, +) -> VectorStoreCreateResponse: + """ + Async: Create a vector store. + """ + local_vars = locals() + try: + loop = asyncio.get_event_loop() + kwargs["acreate"] = True + + # get custom llm provider so we can use this for mapping exceptions + if custom_llm_provider is None: + custom_llm_provider = "openai" # Default to OpenAI for vector stores + + func = partial( + create, + name=name, + file_ids=file_ids, + expires_after=expires_after, + chunking_strategy=chunking_strategy, + metadata=metadata, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + custom_llm_provider=custom_llm_provider, + **kwargs, + ) + + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response + + return response + except Exception as e: + raise litellm.exception_type( + model=None, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=local_vars, + extra_kwargs=kwargs, + ) + + +@client +def create( + name: Optional[str] = None, + file_ids: Optional[List[str]] = None, + expires_after: Optional[Dict] = None, + chunking_strategy: Optional[Dict] = None, + metadata: Optional[Dict[str, str]] = None, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Optional[Dict[str, Any]] = None, + extra_query: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + # LiteLLM specific params, + custom_llm_provider: Optional[str] = None, + **kwargs, +) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]: + """ + Create a vector store. + + Args: + name: The name of the vector store. + file_ids: A list of File IDs that the vector store should use. + expires_after: The expiration policy for the vector store. + chunking_strategy: The chunking strategy used to chunk the file(s). + metadata: Set of 16 key-value pairs that can be attached to an object. + + Returns: + VectorStoreCreateResponse containing the created vector store details. + """ + local_vars = locals() + try: + litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore + litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None) + _is_async = kwargs.pop("acreate", False) is True + + # get llm provider logic + litellm_params = GenericLiteLLMParams(**kwargs) + + ## MOCK RESPONSE LOGIC + if litellm_params.mock_response and isinstance( + litellm_params.mock_response, dict + ): + return mock_vector_store_create_response( + mock_response=VectorStoreCreateResponse(**litellm_params.mock_response) + ) + + # Default to OpenAI for vector stores + if custom_llm_provider is None: + custom_llm_provider = "openai" + + # get provider config - using vector store custom logger for now + vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) + + if vector_store_provider_config is None: + raise ValueError( + f"Vector store create is not supported for {custom_llm_provider}" + ) + + local_vars.update(kwargs) + + # Get VectorStoreCreateOptionalRequestParams with only valid parameters + vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams = ( + VectorStoreRequestUtils.get_requested_vector_store_create_optional_param( + local_vars + ) + ) + + # Pre Call logging + litellm_logging_obj.update_environment_variables( + model=None, + optional_params={ + "name": name, + **vector_store_create_optional_params, + }, + litellm_params={ + "litellm_call_id": litellm_call_id, + }, + custom_llm_provider=custom_llm_provider, + ) + + response = base_llm_http_handler.vector_store_create_handler( + vector_store_create_optional_params=vector_store_create_optional_params, + vector_store_provider_config=vector_store_provider_config, + custom_llm_provider=custom_llm_provider, + litellm_params=litellm_params, + logging_obj=litellm_logging_obj, + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout or request_timeout, + _is_async=_is_async, + client=kwargs.get("client"), + ) + + return response + except Exception as e: + raise litellm.exception_type( + model=None, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=local_vars, + extra_kwargs=kwargs, + ) + + @client async def asearch( vector_store_id: str, diff --git a/litellm/vector_stores/utils.py b/litellm/vector_stores/utils.py index 7817526d00..b7eb7790ad 100644 --- a/litellm/vector_stores/utils.py +++ b/litellm/vector_stores/utils.py @@ -1,6 +1,9 @@ from typing import Any, Dict, cast, get_type_hints -from litellm.types.vector_stores import VectorStoreSearchOptionalRequestParams +from litellm.types.vector_stores import ( + VectorStoreCreateOptionalRequestParams, + VectorStoreSearchOptionalRequestParams, +) class VectorStoreRequestUtils: @@ -26,3 +29,23 @@ class VectorStoreRequestUtils: return cast(VectorStoreSearchOptionalRequestParams, filtered_params) + @staticmethod + def get_requested_vector_store_create_optional_param( + params: Dict[str, Any], + ) -> VectorStoreCreateOptionalRequestParams: + """ + Filter parameters to only include those defined in VectorStoreCreateOptionalRequestParams. + + Args: + params: Dictionary of parameters to filter + + Returns: + VectorStoreCreateOptionalRequestParams instance with only the valid parameters + """ + valid_keys = get_type_hints(VectorStoreCreateOptionalRequestParams).keys() + filtered_params = { + k: v for k, v in params.items() if k in valid_keys and v is not None + } + + return cast(VectorStoreCreateOptionalRequestParams, filtered_params) + diff --git a/tests/test_litellm/llms/azure/response/test_azure_transformation.py b/tests/test_litellm/llms/azure/response/test_azure_transformation.py index c210f91170..8edd87a12f 100644 --- a/tests/test_litellm/llms/azure/response/test_azure_transformation.py +++ b/tests/test_litellm/llms/azure/response/test_azure_transformation.py @@ -54,9 +54,9 @@ def test_validate_environment_azure_key_within_litellm(): def test_validate_environment_azure_openai_api_key_within_secret_str(): azure_openai_responses_apiconfig = AzureOpenAIResponsesAPIConfig() - with patch( - "litellm.llms.azure.responses.transformation.get_secret_str" - ) as mock_get_secret_str: + with patch("litellm.api_key", None), \ + patch("litellm.azure_key", None), \ + patch("litellm.llms.azure.common_utils.get_secret_str") as mock_get_secret_str: # Configure the mock to return "test-api-key" when called with "AZURE_OPENAI_API_KEY" mock_get_secret_str.side_effect = ( lambda key: "test-api-key" if key == "AZURE_OPENAI_API_KEY" else None @@ -74,13 +74,19 @@ def test_validate_environment_azure_openai_api_key_within_secret_str(): def test_validate_environment_azure_api_key_within_secret_str(): azure_openai_responses_apiconfig = AzureOpenAIResponsesAPIConfig() - with patch( - "litellm.llms.azure.responses.transformation.get_secret_str" - ) as mock_get_secret_str: - # Configure the mock to return "test-api-key" when called with "AZURE_API_KEY" - mock_get_secret_str.side_effect = ( - lambda key: "test-api-key" if key == "AZURE_API_KEY" else None - ) + with patch("litellm.api_key", None), \ + patch("litellm.azure_key", None), \ + patch("litellm.llms.azure.common_utils.get_secret_str") as mock_get_secret_str: + # Configure the mock to return None for "AZURE_OPENAI_API_KEY" and "test-api-key" for "AZURE_API_KEY" + def mock_side_effect(key): + if key == "AZURE_OPENAI_API_KEY": + return None + elif key == "AZURE_API_KEY": + return "test-api-key" + else: + return None + + mock_get_secret_str.side_effect = mock_side_effect litellm_params = GenericLiteLLMParams() result = azure_openai_responses_apiconfig.validate_environment( diff --git a/tests/vector_store_tests/base_vector_store_test.py b/tests/vector_store_tests/base_vector_store_test.py index d23012847b..ac1ec704f3 100644 --- a/tests/vector_store_tests/base_vector_store_test.py +++ b/tests/vector_store_tests/base_vector_store_test.py @@ -26,6 +26,11 @@ class BaseVectorStoreTest(ABC): def get_base_request_args(self) -> dict: """Must return the base request args""" pass + + @abstractmethod + def get_base_create_vector_store_args(self) -> dict: + """Must return the base create vector store args""" + pass @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio @@ -52,6 +57,39 @@ class BaseVectorStoreTest(ABC): # Validate response structure self._validate_vector_store_response(response) + @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.asyncio + async def test_basic_create_vector_store(self, sync_mode): + litellm._turn_on_debug() + litellm.set_verbose = True + base_request_args = self.get_base_create_vector_store_args() + + # Extract custom_llm_provider from base args if present + create_args = base_request_args + try: + if sync_mode: + response = litellm.vector_stores.create( + name="Test Vector Store", + **create_args + ) + else: + response = await litellm.vector_stores.acreate( + name="Test Vector Store", + **create_args + ) + except litellm.InternalServerError: + pytest.skip("Skipping test due to litellm.InternalServerError") + except Exception as e: + # If this is an authentication or permission error, skip the test + if "authentication" in str(e).lower() or "permission" in str(e).lower() or "unauthorized" in str(e).lower(): + pytest.skip(f"Skipping test due to authentication/permission error: {e}") + raise + + print("litellm create response=", json.dumps(response, indent=4, default=str)) + + # Validate response structure + self._validate_vector_store_create_response(response) + def _validate_vector_store_response(self, response): """Validate the structure and content of a vector store search response""" @@ -84,6 +122,87 @@ class BaseVectorStoreTest(ABC): print(f"✅ Response validation passed: Found {len(response['data'])} search results") + def _validate_vector_store_create_response(self, response): + """Validate the structure and content of a vector store create response""" + + # Check that response is a dictionary + assert isinstance(response, dict), f"Response should be a dict, got {type(response)}" + + # Check required top-level fields for create response + required_fields = ['id', 'object', 'created_at'] + for field in required_fields: + assert field in response, f"Missing required field '{field}' in create response" + + # Validate object field + assert response['object'] == 'vector_store', \ + f"Expected object to be 'vector_store', got '{response['object']}'" + + # Validate id field + assert isinstance(response['id'], str), \ + f"id should be a string, got {type(response['id'])}" + assert len(response['id']) > 0, "id should not be empty" + assert response['id'].startswith('vs_'), \ + f"id should start with 'vs_', got '{response['id']}'" + + # Validate created_at field + assert isinstance(response['created_at'], int), \ + f"created_at should be an integer, got {type(response['created_at'])}" + assert response['created_at'] > 0, "created_at should be a positive timestamp" + + # Validate optional fields if present + if 'name' in response: + assert isinstance(response['name'], str), \ + f"name should be a string, got {type(response['name'])}" + + if 'bytes' in response: + assert isinstance(response['bytes'], int), \ + f"bytes should be an integer, got {type(response['bytes'])}" + assert response['bytes'] >= 0, "bytes should be non-negative" + + if 'file_counts' in response: + self._validate_file_counts(response['file_counts']) + + if 'status' in response: + valid_statuses = ['expired', 'in_progress', 'completed'] + assert response['status'] in valid_statuses, \ + f"status should be one of {valid_statuses}, got '{response['status']}'" + + if 'expires_at' in response and response['expires_at'] is not None: + assert isinstance(response['expires_at'], int), \ + f"expires_at should be an integer, got {type(response['expires_at'])}" + + if 'last_active_at' in response and response['last_active_at'] is not None: + assert isinstance(response['last_active_at'], int), \ + f"last_active_at should be an integer, got {type(response['last_active_at'])}" + + if 'metadata' in response and response['metadata'] is not None: + assert isinstance(response['metadata'], dict), \ + f"metadata should be a dict, got {type(response['metadata'])}" + + print(f"✅ Create response validation passed: Vector store '{response['id']}' created successfully") + + def _validate_file_counts(self, file_counts): + """Validate file_counts structure""" + assert isinstance(file_counts, dict), \ + f"file_counts should be a dict, got {type(file_counts)}" + + required_count_fields = ['in_progress', 'completed', 'failed', 'cancelled', 'total'] + for field in required_count_fields: + assert field in file_counts, f"Missing required field '{field}' in file_counts" + assert isinstance(file_counts[field], int), \ + f"{field} should be an integer, got {type(file_counts[field])}" + assert file_counts[field] >= 0, f"{field} should be non-negative" + + # Validate that total equals sum of other counts + calculated_total = ( + file_counts['in_progress'] + + file_counts['completed'] + + file_counts['failed'] + + file_counts['cancelled'] + ) + assert file_counts['total'] == calculated_total, \ + f"total should equal sum of other counts ({calculated_total}), got {file_counts['total']}" + def _validate_search_result(self, result, index): """Validate an individual search result""" diff --git a/tests/vector_store_tests/test_azure_vector_store.py b/tests/vector_store_tests/test_azure_vector_store.py new file mode 100644 index 0000000000..783417cb74 --- /dev/null +++ b/tests/vector_store_tests/test_azure_vector_store.py @@ -0,0 +1,25 @@ +from base_vector_store_test import BaseVectorStoreTest +import os +import pytest + +class TestAzureOpenAIVectorStore(BaseVectorStoreTest): + def get_base_request_args(self) -> dict: + """Must return the base request args""" + return {} + + @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.asyncio + async def test_basic_search_vector_store(self, sync_mode): + pass + + + def get_base_create_vector_store_args(self) -> dict: + """ + This is a real vector store on Azure + """ + return { + "custom_llm_provider": "azure", + "api_base": os.getenv("AZURE_RESPONSES_OPENAI_ENDPOINT"), + "api_key": os.getenv("AZURE_RESPONSES_OPENAI_API_KEY"), + "api_version": "2025-04-01-preview", + } \ No newline at end of file diff --git a/tests/vector_store_tests/test_openai_vector_store.py b/tests/vector_store_tests/test_openai_vector_store.py index 20980f2ff5..3e27be2f64 100644 --- a/tests/vector_store_tests/test_openai_vector_store.py +++ b/tests/vector_store_tests/test_openai_vector_store.py @@ -8,4 +8,13 @@ class TestOpenAIVectorStore(BaseVectorStoreTest): return { "vector_store_id": "vs_685b14b1a1b88191bc27e04f1917fddd", "custom_llm_provider": "openai", + } + + + def get_base_create_vector_store_args(self) -> dict: + """ + This is a real vector store on OpenAI + """ + return { + "custom_llm_provider": "openai", } \ No newline at end of file From 15dabf6573f1076cb9d39f265cac0a22b938994b Mon Sep 17 00:00:00 2001 From: Cole McIntosh <82463175+colesmcintosh@users.noreply.github.com> Date: Tue, 24 Jun 2025 21:48:08 -0600 Subject: [PATCH 08/12] docs(CLAUDE.md): add development guidance and architecture overview for Claude Code (#12011) --- CLAUDE.md | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..50bed6e43e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,89 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Commands + +### Installation +- `make install-dev` - Install core development dependencies +- `make install-proxy-dev` - Install proxy development dependencies with full feature set +- `make install-test-deps` - Install all test dependencies + +### Testing +- `make test` - Run all tests +- `make test-unit` - Run unit tests (tests/test_litellm) with 4 parallel workers +- `make test-integration` - Run integration tests (excludes unit tests) +- `pytest tests/` - Direct pytest execution + +### Code Quality +- `make lint` - Run all linting (Ruff, MyPy, Black, circular imports, import safety) +- `make format` - Apply Black code formatting +- `make lint-ruff` - Run Ruff linting only +- `make lint-mypy` - Run MyPy type checking only + +### Single Test Files +- `poetry run pytest tests/path/to/test_file.py -v` - Run specific test file +- `poetry run pytest tests/path/to/test_file.py::test_function -v` - Run specific test + +## Architecture Overview + +LiteLLM is a unified interface for 100+ LLM providers with two main components: + +### Core Library (`litellm/`) +- **Main entry point**: `litellm/main.py` - Contains core completion() function +- **Provider implementations**: `litellm/llms/` - Each provider has its own subdirectory +- **Router system**: `litellm/router.py` + `litellm/router_utils/` - Load balancing and fallback logic +- **Type definitions**: `litellm/types/` - Pydantic models and type hints +- **Integrations**: `litellm/integrations/` - Third-party observability, caching, logging +- **Caching**: `litellm/caching/` - Multiple cache backends (Redis, in-memory, S3, etc.) + +### Proxy Server (`litellm/proxy/`) +- **Main server**: `proxy_server.py` - FastAPI application +- **Authentication**: `auth/` - API key management, JWT, OAuth2 +- **Database**: `db/` - Prisma ORM with PostgreSQL/SQLite support +- **Management endpoints**: `management_endpoints/` - Admin APIs for keys, teams, models +- **Pass-through endpoints**: `pass_through_endpoints/` - Provider-specific API forwarding +- **Guardrails**: `guardrails/` - Safety and content filtering hooks +- **UI Dashboard**: Served from `_experimental/out/` (Next.js build) + +## Key Patterns + +### Provider Implementation +- Providers inherit from base classes in `litellm/llms/base.py` +- Each provider has transformation functions for input/output formatting +- Support both sync and async operations +- Handle streaming responses and function calling + +### Error Handling +- Provider-specific exceptions mapped to OpenAI-compatible errors +- Fallback logic handled by Router system +- Comprehensive logging through `litellm/_logging.py` + +### Configuration +- YAML config files for proxy server (see `proxy/example_config_yaml/`) +- Environment variables for API keys and settings +- Database schema managed via Prisma (`proxy/schema.prisma`) + +## Development Notes + +### Code Style +- Uses Black formatter, Ruff linter, MyPy type checker +- Pydantic v2 for data validation +- Async/await patterns throughout +- Type hints required for all public APIs + +### Testing Strategy +- Unit tests in `tests/test_litellm/` +- Integration tests for each provider in `tests/llm_translation/` +- Proxy tests in `tests/proxy_unit_tests/` +- Load tests in `tests/load_tests/` + +### Database Migrations +- Prisma handles schema migrations +- Migration files auto-generated with `prisma migrate dev` +- Always test migrations against both PostgreSQL and SQLite + +### Enterprise Features +- Enterprise-specific code in `enterprise/` directory +- Optional features enabled via environment variables +- Separate licensing and authentication for enterprise features \ No newline at end of file From ac15ca3014027159e352b153d5d5fc5a9a3f55c0 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Tue, 24 Jun 2025 21:58:07 -0700 Subject: [PATCH 09/12] Teams - Support default key expiry + UI - support enforcing access for members of specific SSO Group (#12023) * fix(team_endpoints.py): support setting default key expiry allows admin to set key expiry on all team member keys makes it easier to setup default team for experimentation * feat(key_management_endpoints.py): allows admin to set duration for keys created by team members * feat(team_endpoints.py): support team_member_key_duration on `/team/update` allows setting max time team member keys are valid for * fix(team_info.tsx): ui component to update team member key duration * fix(team_info.tsx): support updating team member key duration, if set * feat(teams.tsx): add team member key duration param ui component allow admin to set this on UI * feat(ui_sso.py): support restricting ui access by sso group allows controlling who can/can't access the UI * feat(ssomodals.tsx): add initial commit adding sso group access to admin ui * feat(proxy_server.py): support reading + writing ui_access_mode from db allows admin to configure allowed sso groups from UI * feat(ui_sso.py): support enforcing all teams on sso jwt handler if ui access mode set via ui, support reading the value and enforcing it * feat(ui/): ui component for controlling sso access group allow admin to only allow users within specific sso group to log into UI * fix(uiaccesscontrolform.tsx): fix field names * feat(ui_sso.py): return received sso response in the clientside error message - enables easier debugging * test: add unit tests * fix: minor fixes --- .../key_management_endpoints.py | 30 ++ litellm/proxy/_new_secret_config.yaml | 2 +- litellm/proxy/_types.py | 3 + litellm/proxy/auth/handle_jwt.py | 2 + litellm/proxy/litellm.log | 357 ++++++++++++++++++ .../key_management_endpoints.py | 74 ++-- .../management_endpoints/sso_helper_utils.py | 9 +- .../management_endpoints/team_endpoints.py | 2 + litellm/proxy/management_endpoints/ui_sso.py | 121 +++++- litellm/proxy/proxy_server.py | 8 +- .../proxy_setting_endpoints.py | 87 +++-- .../proxy/management_endpoints/ui_sso.py | 22 +- .../proxy/management_endpoints/test_ui_sso.py | 6 +- .../src/components/SSOModals.tsx | 2 +- .../src/components/UIAccessControlForm.tsx | 148 ++++++++ .../src/components/admins.tsx | 36 ++ .../src/components/team/team_info.tsx | 29 +- ui/litellm-dashboard/src/components/teams.tsx | 11 + 18 files changed, 864 insertions(+), 85 deletions(-) create mode 100644 enterprise/litellm_enterprise/proxy/management_endpoints/key_management_endpoints.py create mode 100644 litellm/proxy/litellm.log create mode 100644 ui/litellm-dashboard/src/components/UIAccessControlForm.tsx diff --git a/enterprise/litellm_enterprise/proxy/management_endpoints/key_management_endpoints.py b/enterprise/litellm_enterprise/proxy/management_endpoints/key_management_endpoints.py new file mode 100644 index 0000000000..19ce8090db --- /dev/null +++ b/enterprise/litellm_enterprise/proxy/management_endpoints/key_management_endpoints.py @@ -0,0 +1,30 @@ +from typing import Optional + +from litellm.proxy._types import GenerateKeyRequest, LiteLLM_TeamTable + + +def add_team_member_key_duration( + team_table: Optional[LiteLLM_TeamTable], + data: GenerateKeyRequest, +) -> GenerateKeyRequest: + if team_table is None: + return data + + if data.user_id is None: # only apply for team member keys, not service accounts + return data + + if ( + team_table.metadata is not None + and team_table.metadata.get("team_member_key_duration") is not None + ): + data.duration = team_table.metadata["team_member_key_duration"] + + return data + + +def apply_enterprise_key_management_params( + data: GenerateKeyRequest, + team_table: Optional[LiteLLM_TeamTable], +) -> GenerateKeyRequest: + data = add_team_member_key_duration(team_table, data) + return data diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 07e0a96f55..40243ae668 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,4 +1,4 @@ model_list: - model_name: gemini-2.5-pro litellm_params: - model: gemini/gemini-2.5-pro \ No newline at end of file + model: gemini/gemini-2.5-pro diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 7d358cad03..111ef89f7d 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1115,6 +1115,7 @@ class NewTeamRequest(TeamBase): team_member_budget: Optional[float] = ( None # allow user to set a budget for all team members ) + team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" model_config = ConfigDict(protected_namespaces=()) @@ -1157,6 +1158,7 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase): guardrails: Optional[List[str]] = None object_permission: Optional[LiteLLM_ObjectPermissionBase] = None team_member_budget: Optional[float] = None + team_member_key_duration: Optional[str] = None class ResetTeamBudgetRequest(LiteLLMPydanticObjectBase): @@ -2792,6 +2794,7 @@ LiteLLM_ManagementEndpoint_MetadataFields = [ LiteLLM_ManagementEndpoint_MetadataFields_Premium = [ "guardrails", "tags", + "team_member_key_duration", ] diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index fa91785e6b..046f66173d 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -164,7 +164,9 @@ class JWTHandler: self.litellm_jwtauth.team_ids_jwt_field is not None and token.get(self.litellm_jwtauth.team_ids_jwt_field) is not None ): + return token[self.litellm_jwtauth.team_ids_jwt_field] + return [] def get_end_user_id( diff --git a/litellm/proxy/litellm.log b/litellm/proxy/litellm.log new file mode 100644 index 0000000000..4f592f5cc0 --- /dev/null +++ b/litellm/proxy/litellm.log @@ -0,0 +1,357 @@ +18:10:09 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle +18:10:11 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist! +18:10:11 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:10:11 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:10:23 - LiteLLM Proxy:INFO: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback +18:10:27 - LiteLLM Proxy:INFO: ui_sso.py:495 - Starting SSO callback +18:10:27 - LiteLLM Proxy:INFO: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback +18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[] +18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:671 - user_defined_values for creating ui key: {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': 'proxy_admin', 'budget_duration': None} +18:10:28 - LiteLLM Proxy:INFO: utils.py:1856 - Data Inserted into Keys Table +18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:761 - user_id: krrishd; jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoia3JyaXNoZCIsImtleSI6InNrLTVvOXVVc0ZaaTVBRFBiWERoanhCZlEiLCJ1c2VyX2VtYWlsIjoia3JyaXNoZGhvbGFraWFAZ21haWwuY29tIiwidXNlcl9yb2xlIjoicHJveHlfYWRtaW4iLCJsb2dpbl9tZXRob2QiOiJzc28iLCJwcmVtaXVtX3VzZXIiOnRydWUsImF1dGhfaGVhZGVyX25hbWUiOiJBdXRob3JpemF0aW9uIiwiZGlzYWJsZWRfbm9uX2FkbWluX3BlcnNvbmFsX2tleV9jcmVhdGlvbiI6ZmFsc2UsInNlcnZlcl9yb290X3BhdGgiOiIvIn0.OiZdFjZ2wiMhFbMCwu2cZYXh7oV5BB8Vta-Ysk5JBQU +18:10:28 - LiteLLM Proxy:INFO: ui_sso.py:764 - Redirecting to http://localhost:4000/ui/?login=success +18:10:30 - LiteLLM Proxy:ERROR: key_management_endpoints.py:2275 - Error in list_keys: Server disconnected without sending a response. +Traceback (most recent call last): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions + yield + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request + resp = await self._pool.handle_async_request(req) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request + raise exc from None + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request + response = await connection.handle_async_request( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 101, in handle_async_request + return await self._connection.handle_async_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 143, in handle_async_request + raise exc + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 113, in handle_async_request + ) = await self._receive_response_headers(**kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 186, in _receive_response_headers + event = await self._receive_event(timeout=timeout) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 238, in _receive_event + raise RemoteProtocolError(msg) +httpcore.RemoteProtocolError: Server disconnected without sending a response. + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/management_endpoints/key_management_endpoints.py", line 2255, in list_keys + response = await _list_key_helper( + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/management_endpoints/key_management_endpoints.py", line 2434, in _list_key_helper + total_count = await prisma_client.db.litellm_verificationtoken.count( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 10157, in count + resp = await self._client._execute( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute + return await self._engine.query(builder.build(), tx_id=self._tx_id) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query + return await self.request( + ^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request + response = await self.session.request(method, url, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request + return Response(await self.session.request(method, url, **kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request + return await self.send(request, auth=auth, follow_redirects=follow_redirects) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send + response = await self._send_handling_auth( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth + response = await self._send_handling_redirects( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects + response = await self._send_single_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request + response = await transport.handle_async_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request + with map_httpcore_exceptions(): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__ + self.gen.throw(typ, value, traceback) + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions + raise mapped_exc(message) from exc +httpx.RemoteProtocolError: Server disconnected without sending a response. +18:10:30 - LiteLLM Proxy:ERROR: proxy_server.py:2730 - litellm.proxy_server.py::add_deployment() - Error getting new models from DB - All connection attempts failed +Traceback (most recent call last): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions + yield + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request + resp = await self._pool.handle_async_request(req) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request + raise exc from None + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request + response = await connection.handle_async_request( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 99, in handle_async_request + raise exc + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 76, in handle_async_request + stream = await self._connect(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 122, in _connect + stream = await self._network_backend.connect_tcp(**kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/auto.py", line 30, in connect_tcp + return await self._backend.connect_tcp( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 112, in connect_tcp + with map_exceptions(exc_map): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__ + self.gen.throw(typ, value, traceback) + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_exceptions.py", line 14, in map_exceptions + raise to_exc(exc) from exc +httpcore.ConnectError: All connection attempts failed + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2728, in _get_models_from_db + new_models = await prisma_client.db.litellm_proxymodeltable.find_many() + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 2540, in find_many + resp = await self._client._execute( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute + return await self._engine.query(builder.build(), tx_id=self._tx_id) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query + return await self.request( + ^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request + response = await self.session.request(method, url, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request + return Response(await self.session.request(method, url, **kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request + return await self.send(request, auth=auth, follow_redirects=follow_redirects) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send + response = await self._send_handling_auth( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth + response = await self._send_handling_redirects( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects + response = await self._send_single_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request + response = await transport.handle_async_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request + with map_httpcore_exceptions(): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__ + self.gen.throw(typ, value, traceback) + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions + raise mapped_exc(message) from exc +httpx.ConnectError: All connection attempts failed +18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed +18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed +18:10:30 - LiteLLM Proxy:ERROR: proxy_server.py:2778 - litellm.proxy.proxy_server.py::ProxyConfig:add_deployment - All connection attempts failed +Traceback (most recent call last): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions + yield + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request + resp = await self._pool.handle_async_request(req) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request + raise exc from None + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request + response = await connection.handle_async_request( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 99, in handle_async_request + raise exc + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 76, in handle_async_request + stream = await self._connect(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 122, in _connect + stream = await self._network_backend.connect_tcp(**kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/auto.py", line 30, in connect_tcp + return await self._backend.connect_tcp( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_backends/anyio.py", line 112, in connect_tcp + with map_exceptions(exc_map): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__ + self.gen.throw(typ, value, traceback) + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_exceptions.py", line 14, in map_exceptions + raise to_exc(exc) from exc +httpcore.ConnectError: All connection attempts failed + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2760, in add_deployment + await self._update_llm_router( + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2418, in _update_llm_router + config_data = await proxy_config.get_config() + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 1584, in get_config + config = await self._update_config_from_db( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2706, in _update_config_from_db + responses = await asyncio.gather(*_tasks) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/db/log_db_metrics.py", line 99, in wrapper + raise e + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/db/log_db_metrics.py", line 42, in wrapper + result = await func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/backoff/_async.py", line 151, in retry + ret = await target(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/utils.py", line 1418, in get_generic_data + raise e + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/utils.py", line 1392, in get_generic_data + response = await self.db.litellm_config.find_first( # type: ignore + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 11822, in find_first + resp = await self._client._execute( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute + return await self._engine.query(builder.build(), tx_id=self._tx_id) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query + return await self.request( + ^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request + response = await self.session.request(method, url, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request + return Response(await self.session.request(method, url, **kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request + return await self.send(request, auth=auth, follow_redirects=follow_redirects) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send + response = await self._send_handling_auth( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth + response = await self._send_handling_redirects( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects + response = await self._send_single_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request + response = await transport.handle_async_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request + with map_httpcore_exceptions(): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__ + self.gen.throw(typ, value, traceback) + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions + raise mapped_exc(message) from exc +httpx.ConnectError: All connection attempts failed +18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed +18:10:30 - LiteLLM Proxy:ERROR: utils.py:1404 - LiteLLM Prisma Client Exception get_generic_data: All connection attempts failed +18:10:30 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server +18:11:47 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle +18:11:49 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist! +18:11:50 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:11:50 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:12:00 - LiteLLM Proxy:ERROR: proxy_server.py:2925 - litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - Server disconnected without sending a response. +Traceback (most recent call last): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 101, in map_httpcore_exceptions + yield + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 394, in handle_async_request + resp = await self._pool.handle_async_request(req) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 216, in handle_async_request + raise exc from None + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection_pool.py", line 196, in handle_async_request + response = await connection.handle_async_request( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/connection.py", line 101, in handle_async_request + return await self._connection.handle_async_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 143, in handle_async_request + raise exc + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 113, in handle_async_request + ) = await self._receive_response_headers(**kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 186, in _receive_response_headers + event = await self._receive_event(timeout=timeout) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpcore/_async/http11.py", line 238, in _receive_event + raise RemoteProtocolError(msg) +httpcore.RemoteProtocolError: Server disconnected without sending a response. + +The above exception was the direct cause of the following exception: + +Traceback (most recent call last): + File "/Users/krrishdholakia/Documents/litellm/litellm/proxy/proxy_server.py", line 2916, in get_credentials + credentials = await prisma_client.db.litellm_credentialstable.find_many() + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/actions.py", line 1502, in find_many + resp = await self._client._execute( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_base_client.py", line 543, in _execute + return await self._engine.query(builder.build(), tx_id=self._tx_id) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_query.py", line 402, in query + return await self.request( + ^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/engine/_http.py", line 217, in request + response = await self.session.request(method, url, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/prisma/_async_http.py", line 26, in request + return Response(await self.session.request(method, url, **kwargs)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1540, in request + return await self.send(request, auth=auth, follow_redirects=follow_redirects) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1629, in send + response = await self._send_handling_auth( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1657, in _send_handling_auth + response = await self._send_handling_redirects( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1694, in _send_handling_redirects + response = await self._send_single_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_client.py", line 1730, in _send_single_request + response = await transport.handle_async_request(request) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 393, in handle_async_request + with map_httpcore_exceptions(): + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/contextlib.py", line 155, in __exit__ + self.gen.throw(typ, value, traceback) + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_transports/default.py", line 118, in map_httpcore_exceptions + raise mapped_exc(message) from exc +httpx.RemoteProtocolError: Server disconnected without sending a response. +18:12:01 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server +18:12:14 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle +18:12:16 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist! +18:12:16 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:12:16 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:12:21 - LiteLLM Proxy:INFO: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback +18:12:26 - LiteLLM Proxy:INFO: ui_sso.py:495 - Starting SSO callback +18:12:26 - LiteLLM Proxy:INFO: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback +18:12:26 - LiteLLM Proxy:INFO: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[] +18:12:27 - LiteLLM Proxy:INFO: ui_sso.py:672 - user_defined_values for creating ui key: {'models': [], 'user_id': 'krrishd', 'user_email': 'krrishdholakia@gmail.com', 'max_budget': None, 'user_role': 'proxy_admin', 'budget_duration': None} +18:12:27 - LiteLLM Proxy:INFO: utils.py:1856 - Data Inserted into Keys Table +18:12:27 - LiteLLM Proxy:INFO: ui_sso.py:762 - user_id: krrishd; jwt_token: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoia3JyaXNoZCIsImtleSI6InNrLUQzMEFpdW9lckU3YlMyakFXWVFLd1EiLCJ1c2VyX2VtYWlsIjoia3JyaXNoZGhvbGFraWFAZ21haWwuY29tIiwidXNlcl9yb2xlIjoicHJveHlfYWRtaW4iLCJsb2dpbl9tZXRob2QiOiJzc28iLCJwcmVtaXVtX3VzZXIiOnRydWUsImF1dGhfaGVhZGVyX25hbWUiOiJBdXRob3JpemF0aW9uIiwiZGlzYWJsZWRfbm9uX2FkbWluX3BlcnNvbmFsX2tleV9jcmVhdGlvbiI6ZmFsc2UsInNlcnZlcl9yb290X3BhdGgiOiIvIn0.EzYP86hw12J4WHLe6ZZz4YgVNGPnxM_PHqLjINH2_-U +18:12:27 - LiteLLM Proxy:INFO: ui_sso.py:765 - Redirecting to http://localhost:4000/ui/?login=success +18:12:31 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server +18:15:07 - LiteLLM Router:INFO: router.py:660 - Routing strategy: simple-shuffle +18:15:09 - LiteLLM Proxy:INFO: utils.py:1317 - All necessary views exist! +18:15:09 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:15:09 - LiteLLM Router:WARNING: router.py:4862 - Error upserting deployment: vertex_project, and vertex_location must be set in litellm_params for pass-through endpoints., ignoring and continuing with other deployments. +18:15:17 - LiteLLM Proxy:INFO: utils.py:1916 - Data Inserted into Config Table +18:15:28 - LiteLLM Proxy:INFO: ui_sso.py:129 - Redirecting to SSO login for http://localhost:4000/sso/callback +18:15:32 - LiteLLM Proxy:INFO: ui_sso.py:495 - Starting SSO callback +18:15:32 - LiteLLM Proxy:INFO: ui_sso.py:550 - Redirecting to http://localhost:4000/sso/callback +18:15:32 - LiteLLM Proxy:INFO: ui_sso.py:581 - SSO callback result: id='krrishd' email='krrishdholakia@gmail.com' first_name=None last_name=None display_name='a3f1c107-04dc-4c93-ae60-7f32eb4b05ce' picture=None provider=None team_ids=[] +18:15:37 - LiteLLM Proxy:INFO: proxy_server.py:490 - Shutting down LiteLLM Proxy Server diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index fdc8e73dad..0b1c7f523c 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -512,6 +512,20 @@ async def generate_key_fn( # noqa: PLR0915 }, ) + # APPLY ENTERPRISE KEY MANAGEMENT PARAMS + try: + from litellm_enterprise.proxy.management_endpoints.key_management_endpoints import ( + apply_enterprise_key_management_params, + ) + + data = apply_enterprise_key_management_params(data, team_table) + except Exception as e: + verbose_proxy_logger.info( + "litellm.proxy.proxy_server.generate_key_fn(): Enterprise key management params not applied - {}".format( + str(e) + ) + ) + # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable _budget_id = data.budget_id if prisma_client is not None and data.soft_budget is not None: @@ -536,7 +550,7 @@ async def generate_key_fn( # noqa: PLR0915 # ADD METADATA FIELDS # Set Management Endpoint Metadata Fields for field in LiteLLM_ManagementEndpoint_MetadataFields_Premium: - if getattr(data, field) is not None: + if getattr(data, field, None) is not None: _set_object_metadata_field( object_data=data, field_name=field, @@ -589,9 +603,9 @@ async def generate_key_fn( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response[ - "soft_budget" - ] = data.soft_budget # include the user-input soft budget in the response + response["soft_budget"] = ( + data.soft_budget + ) # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -667,9 +681,9 @@ async def _set_object_permission( data=data_json["object_permission"], ) ) - data_json[ - "object_permission_id" - ] = created_object_permission.object_permission_id + data_json["object_permission_id"] = ( + created_object_permission.object_permission_id + ) # delete the object_permission from the data_json data_json.pop("object_permission") @@ -1652,10 +1666,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[ - LiteLLM_VerificationToken - ] = await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} + _keys_being_deleted: List[LiteLLM_VerificationToken] = ( + await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} + ) ) if len(_keys_being_deleted) == 0: @@ -1763,9 +1777,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[ - List - ] = await prisma_client.db.litellm_proxymodeltable.find_many() + models: Optional[List] = ( + await prisma_client.db.litellm_proxymodeltable.find_many() + ) except Exception: models = None # 2. process model table @@ -2057,11 +2071,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[ - BaseModel - ] = await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, + complete_user_info_db_obj: Optional[BaseModel] = ( + await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, + ) ) if complete_user_info_db_obj is None: @@ -2147,10 +2161,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[ - List[BaseModel] - ] = await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} + teams: Optional[List[BaseModel]] = ( + await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} + ) ) if teams is None: return [] @@ -2403,12 +2417,14 @@ async def _list_key_helper( where=where, # type: ignore skip=skip, # type: ignore take=size, # type: ignore - order=order_by - if order_by - else [ - {"created_at": "desc"}, - {"token": "desc"}, # fallback sort - ], + order=( + order_by + if order_by + else [ + {"created_at": "desc"}, + {"token": "desc"}, # fallback sort + ] + ), include={"object_permission": True}, ) diff --git a/litellm/proxy/management_endpoints/sso_helper_utils.py b/litellm/proxy/management_endpoints/sso_helper_utils.py index 45906b2fce..7b296a6646 100644 --- a/litellm/proxy/management_endpoints/sso_helper_utils.py +++ b/litellm/proxy/management_endpoints/sso_helper_utils.py @@ -1,9 +1,14 @@ +from typing import Dict, Union + from litellm.proxy._types import LitellmUserRoles -def check_is_admin_only_access(ui_access_mode: str) -> bool: +def check_is_admin_only_access(ui_access_mode: Union[str, Dict]) -> bool: """Checks ui access mode is admin_only""" - return ui_access_mode == "admin_only" + if isinstance(ui_access_mode, str): + return ui_access_mode == "admin_only" + else: + return False def has_admin_ui_access(user_role: str) -> bool: diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 31d2877ea6..3d4ae4e945 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -262,6 +262,7 @@ async def new_team( # noqa: PLR0915 - guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails) - object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission. - team_member_budget: Optional[float] - The maximum budget allocated to an individual team member. + - team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo" Returns: - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. @@ -688,6 +689,7 @@ async def update_team( - guardrails: Optional[List[str]] - Guardrails for the team. [Docs](https://docs.litellm.ai/docs/proxy/guardrails) - object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission. - team_member_budget: Optional[float] - The maximum budget allocated to an individual team member. + - team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo" Example - update team TPM Limit ``` diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index 29a078ec06..77fdb64b0e 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -145,7 +145,11 @@ async def google_login(request: Request): # noqa: PLR0915 return HTMLResponse(content=html_form, status_code=200) -def generic_response_convertor(response, jwt_handler: JWTHandler): +def generic_response_convertor( + response, + jwt_handler: JWTHandler, + sso_jwt_handler: Optional[JWTHandler] = None, +): generic_user_id_attribute_name = os.getenv( "GENERIC_USER_ID_ATTRIBUTE", "preferred_username" ) @@ -171,6 +175,13 @@ def generic_response_convertor(response, jwt_handler: JWTHandler): f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}" ) + all_teams = [] + if sso_jwt_handler is not None: + team_ids = sso_jwt_handler.get_team_ids_from_jwt(cast(dict, response)) + all_teams.extend(team_ids) + + team_ids = jwt_handler.get_team_ids_from_jwt(cast(dict, response)) + all_teams.extend(team_ids) return CustomOpenID( id=response.get(generic_user_id_attribute_name), display_name=response.get(generic_user_display_name_attribute_name), @@ -178,20 +189,24 @@ def generic_response_convertor(response, jwt_handler: JWTHandler): first_name=response.get(generic_user_first_name_attribute_name), last_name=response.get(generic_user_last_name_attribute_name), provider=response.get(generic_provider_attribute_name), - team_ids=jwt_handler.get_team_ids_from_jwt(cast(dict, response)), + team_ids=all_teams, ) async def get_generic_sso_response( request: Request, jwt_handler: JWTHandler, + sso_jwt_handler: Optional[ + JWTHandler + ], # sso specific jwt handler - used for restricted sso group access control generic_client_id: str, redirect_url: str, -) -> Union[OpenID, dict]: +) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response # make generic sso provider from fastapi_sso.sso.base import DiscoveryDocument from fastapi_sso.sso.generic import create_provider + received_response: Optional[dict] = None generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None) @@ -242,9 +257,12 @@ async def get_generic_sso_response( ) def response_convertor(response, client): + nonlocal received_response # return for user debugging + received_response = response return generic_response_convertor( response=response, jwt_handler=jwt_handler, + sso_jwt_handler=sso_jwt_handler, ) SSOProvider = create_provider( @@ -284,7 +302,7 @@ async def get_generic_sso_response( ) raise e verbose_proxy_logger.debug("generic result: %s", result) - return result or {} + return result or {}, received_response async def create_team_member_add_task(team_id, user_info): @@ -480,6 +498,8 @@ async def check_and_update_if_proxy_admin_id( async def auth_callback(request: Request): # noqa: PLR0915 """Verify login""" verbose_proxy_logger.info("Starting SSO callback") + from litellm.proxy._types import LiteLLM_JWTAuth + from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.management_endpoints.key_management_endpoints import ( generate_key_helper_fn, ) @@ -490,7 +510,6 @@ async def auth_callback(request: Request): # noqa: PLR0915 premium_user, prisma_client, proxy_logging_obj, - ui_access_mode, user_api_key_cache, user_custom_sso, ) @@ -502,9 +521,25 @@ async def auth_callback(request: Request): # noqa: PLR0915 status_code=500, detail=CommonProxyErrors.db_not_connected_error.value ) + sso_jwt_handler: Optional[JWTHandler] = None + ui_access_mode = general_settings.get("ui_access_mode", None) + if ui_access_mode is not None and isinstance(ui_access_mode, dict): + sso_jwt_handler = JWTHandler() + sso_jwt_handler.update_environment( + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + litellm_jwtauth=LiteLLM_JWTAuth( + team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get( + "sso_group_jwt_field", None + ), + ), + leeway=0, + ) + microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) + received_response: Optional[dict] = None # get url from request if master_key is None: raise ProxyException( @@ -532,11 +567,12 @@ async def auth_callback(request: Request): # noqa: PLR0915 redirect_url=redirect_url, ) elif generic_client_id is not None: - result = await get_generic_sso_response( + result, received_response = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, redirect_url=redirect_url, + sso_jwt_handler=sso_jwt_handler, ) if result is None: @@ -547,6 +583,7 @@ async def auth_callback(request: Request): # noqa: PLR0915 # User is Authe'd in - generate key for the UI to access Proxy verbose_proxy_logger.info(f"SSO callback result: {result}") + user_email: Optional[str] = getattr(result, "email", None) user_id: Optional[str] = getattr(result, "id", None) if result is not None else None @@ -612,6 +649,13 @@ async def auth_callback(request: Request): # noqa: PLR0915 budget_duration=internal_user_budget_duration, ) + # (IF SET) Verify user is in restricted SSO group + SSOAuthenticationHandler.verify_user_in_restricted_sso_group( + general_settings=general_settings, + result=result, + received_response=received_response, + ) + user_info = await get_user_info_from_db( result=result, prisma_client=prisma_client, @@ -1055,6 +1099,44 @@ class SSOAuthenticationHandler: sso_teams = getattr(result, "team_ids", []) await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) + @staticmethod + def verify_user_in_restricted_sso_group( + general_settings: Dict, + result: Optional[Union[CustomOpenID, OpenID, dict]], + received_response: Optional[dict], + ) -> Literal[True]: + """ + when ui_access_mode.type == "restricted_sso_group": + + - result.team_ids should contain the restricted_sso_group + - if not, raise a ProxyException + - if so, return True + - if result.team_ids is None, return False + - if result.team_ids is an empty list, return False + - if result.team_ids is a list, return True if the restricted_sso_group is in the list, otherwise return False + """ + + ui_access_mode = cast( + Optional[Union[Dict, str]], general_settings.get("ui_access_mode") + ) + + if ui_access_mode is None: + return True + if isinstance(ui_access_mode, str): + return True + team_ids = getattr(result, "team_ids", []) + + if ui_access_mode.get("type") == "restricted_sso_group": + restricted_sso_group = ui_access_mode.get("restricted_sso_group") + if restricted_sso_group not in team_ids: + raise ProxyException( + message=f"User is not in the restricted SSO group: {restricted_sso_group}. User groups: {team_ids}. Received SSO response: {received_response}", + type=ProxyErrorTypes.auth_error, + param="restricted_sso_group", + code=status.HTTP_403_FORBIDDEN, + ) + return True + @staticmethod async def create_litellm_team_from_sso_group( litellm_team_id: str, @@ -1551,7 +1633,29 @@ async def debug_sso_callback(request: Request): from fastapi.responses import HTMLResponse - from litellm.proxy.proxy_server import jwt_handler + from litellm.proxy._types import LiteLLM_JWTAuth + from litellm.proxy.auth.handle_jwt import JWTHandler + from litellm.proxy.proxy_server import ( + general_settings, + jwt_handler, + prisma_client, + user_api_key_cache, + ) + + sso_jwt_handler: Optional[JWTHandler] = None + ui_access_mode = general_settings.get("ui_access_mode", None) + if ui_access_mode is not None and isinstance(ui_access_mode, dict): + sso_jwt_handler = JWTHandler() + sso_jwt_handler.update_environment( + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + litellm_jwtauth=LiteLLM_JWTAuth( + team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get( + "sso_group_jwt_field", None + ), + ), + leeway=0, + ) microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) @@ -1580,11 +1684,12 @@ async def debug_sso_callback(request: Request): ) elif generic_client_id is not None: - result = await get_generic_sso_response( + result, _ = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, redirect_url=redirect_url, + sso_jwt_handler=sso_jwt_handler, ) # If result is None, return a basic error message diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 95173d1c09..0a8abdd19e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -905,7 +905,7 @@ health_check_results: Dict[str, Union[int, List[Dict[str, Any]]]] = {} queue: List = [] litellm_proxy_budget_name = "litellm-proxy-budget" litellm_proxy_admin_name = LITELLM_PROXY_ADMIN_NAME -ui_access_mode: Literal["admin", "all"] = "all" +ui_access_mode: Union[Literal["admin", "all"], Dict] = "all" proxy_budget_rescheduler_min_time = PROXY_BUDGET_RESCHEDULER_MIN_TIME proxy_budget_rescheduler_max_time = PROXY_BUDGET_RESCHEDULER_MAX_TIME proxy_batch_write_at = PROXY_BATCH_WRITE_AT @@ -1435,11 +1435,13 @@ class ProxyConfig: - Do not write restricted params like 'api_key' to the database - if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`) """ + if prisma_client is not None and ( general_settings.get("store_model_in_db", False) is True or store_model_in_db ): # if using - db for config - models are in ModelTable + new_config.pop("model_list", None) await prisma_client.insert_data(data=new_config, table_name="config") else: @@ -2625,6 +2627,10 @@ class ProxyConfig: pass_through_endpoints=general_settings["pass_through_endpoints"] ) + ## UI ACCESS MODE ## + if "ui_access_mode" in _general_settings: + general_settings["ui_access_mode"] = _general_settings["ui_access_mode"] + def _update_config_fields( self, current_config: dict, diff --git a/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py b/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py index 25991862fa..884afd02bc 100644 --- a/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py +++ b/litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py @@ -7,7 +7,10 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.types.proxy.management_endpoints.ui_sso import DefaultTeamSSOParams, SSOConfig +from litellm.types.proxy.management_endpoints.ui_sso import ( + DefaultTeamSSOParams, + SSOConfig, +) router = APIRouter() @@ -18,26 +21,29 @@ class IPAddress(BaseModel): class SettingsResponse(BaseModel): """Base response model for settings with values and schema information""" - + values: Dict[str, Any] """The current configuration values""" - + field_schema: Dict[str, Any] """Schema information including descriptions and property types for UI display""" class SSOSettingsResponse(SettingsResponse): """Response model for SSO settings""" + pass class InternalUserSettingsResponse(SettingsResponse): """Response model for internal user settings""" + pass class DefaultTeamSettingsResponse(SettingsResponse): """Response model for default team settings""" + pass @@ -166,7 +172,10 @@ async def _get_settings_with_schema( # Add descriptions to the response result = { "values": settings_dict, - "field_schema": {"description": schema.get("description", ""), "properties": {}}, + "field_schema": { + "description": schema.get("description", ""), + "properties": {}, + }, } # Add property descriptions @@ -322,20 +331,21 @@ async def get_sso_settings(): Returns a structured object with values and descriptions for UI display. """ import os + from litellm.proxy.proxy_server import proxy_config - + # Load existing config to get both environment variables and general settings config = await proxy_config.get_config() general_settings = config.get("general_settings", {}) or {} environment_variables = config.get("environment_variables", {}) or {} - + # Get user_email from general_settings proxy_admin_email = general_settings.get("proxy_admin_email", None) - + # Helper function to get env var value (first from config, then from environment) def get_env_value(env_var_name: str): return environment_variables.get(env_var_name) or os.getenv(env_var_name) - + # Get current environment variables for SSO sso_config = SSOConfig( google_client_id=get_env_value("GOOGLE_CLIENT_ID"), @@ -351,27 +361,31 @@ async def get_sso_settings(): proxy_base_url=get_env_value("PROXY_BASE_URL"), user_email=proxy_admin_email, # Get from config instead of environment ) - + # Get the schema for UI display from pydantic import TypeAdapter + schema = TypeAdapter(SSOConfig).json_schema(by_alias=True) - + # Convert to dict for response sso_dict = sso_config.model_dump() - + # Add descriptions to the response result = { "values": sso_dict, - "field_schema": {"description": schema.get("description", ""), "properties": {}}, + "field_schema": { + "description": schema.get("description", ""), + "properties": {}, + }, } - + # Add property descriptions for field_name, field_info in schema["properties"].items(): result["field_schema"]["properties"][field_name] = { "description": field_info.get("description", ""), "type": field_info.get("type", "string"), } - + return result @@ -384,51 +398,56 @@ async def update_sso_settings(sso_config: SSOConfig): """ Update SSO configuration by saving to both environment variables and config file. """ - from litellm.proxy.proxy_server import proxy_config import os - + + from litellm.proxy.proxy_server import proxy_config + # Update environment variables env_var_mapping = { - 'google_client_id': 'GOOGLE_CLIENT_ID', - 'google_client_secret': 'GOOGLE_CLIENT_SECRET', - 'microsoft_client_id': 'MICROSOFT_CLIENT_ID', - 'microsoft_client_secret': 'MICROSOFT_CLIENT_SECRET', - 'microsoft_tenant': 'MICROSOFT_TENANT', - 'generic_client_id': 'GENERIC_CLIENT_ID', - 'generic_client_secret': 'GENERIC_CLIENT_SECRET', - 'generic_authorization_endpoint': 'GENERIC_AUTHORIZATION_ENDPOINT', - 'generic_token_endpoint': 'GENERIC_TOKEN_ENDPOINT', - 'generic_userinfo_endpoint': 'GENERIC_USERINFO_ENDPOINT', - 'proxy_base_url': 'PROXY_BASE_URL', + "google_client_id": "GOOGLE_CLIENT_ID", + "google_client_secret": "GOOGLE_CLIENT_SECRET", + "microsoft_client_id": "MICROSOFT_CLIENT_ID", + "microsoft_client_secret": "MICROSOFT_CLIENT_SECRET", + "microsoft_tenant": "MICROSOFT_TENANT", + "generic_client_id": "GENERIC_CLIENT_ID", + "generic_client_secret": "GENERIC_CLIENT_SECRET", + "generic_authorization_endpoint": "GENERIC_AUTHORIZATION_ENDPOINT", + "generic_token_endpoint": "GENERIC_TOKEN_ENDPOINT", + "generic_userinfo_endpoint": "GENERIC_USERINFO_ENDPOINT", + "proxy_base_url": "PROXY_BASE_URL", } - + # Load existing config config = await proxy_config.get_config() - + # Update config with new environment variables if "environment_variables" not in config: config["environment_variables"] = {} - + # Update general_settings for user_email (admin email) if "general_settings" not in config: config["general_settings"] = {} - + # Update environment variables in config and in memory sso_data = sso_config.model_dump(exclude_none=True) for field_name, value in sso_data.items(): - if field_name == 'user_email' and value is not None: + + if field_name == "user_email" and value is not None: # Store user_email in general_settings instead of environment variables config["general_settings"]["proxy_admin_email"] = value + elif field_name == "ui_access_mode" and value is not None: + + config["general_settings"]["ui_access_mode"] = value elif field_name in env_var_mapping and value is not None: env_var_name = env_var_mapping[field_name] # Update in config config["environment_variables"][env_var_name] = value # Update in runtime environment os.environ[env_var_name] = value - + # Save the updated config await proxy_config.save_config(new_config=config) - + return { "message": "SSO settings updated successfully", "status": "success", diff --git a/litellm/types/proxy/management_endpoints/ui_sso.py b/litellm/types/proxy/management_endpoints/ui_sso.py index f6838b6170..3cdb5cb639 100644 --- a/litellm/types/proxy/management_endpoints/ui_sso.py +++ b/litellm/types/proxy/management_endpoints/ui_sso.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Optional, TypedDict +from typing import List, Literal, Optional, TypedDict, Union from pydantic import Field @@ -31,6 +31,14 @@ class MicrosoftServicePrincipalTeam(TypedDict, total=False): principalId: Optional[str] +class AccessControl_UI_AccessMode(LiteLLMPydanticObjectBase): + """Model for Controlling UI Access Mode via SSO Groups""" + + type: Literal["restricted_sso_group"] + restricted_sso_group: str + sso_group_jwt_field: str + + class SSOConfig(LiteLLMPydanticObjectBase): """ Configuration for SSO environment variables and settings @@ -45,7 +53,7 @@ class SSOConfig(LiteLLMPydanticObjectBase): default=None, description="Google OAuth Client Secret for SSO authentication", ) - + # Microsoft SSO microsoft_client_id: Optional[str] = Field( default=None, @@ -59,7 +67,7 @@ class SSOConfig(LiteLLMPydanticObjectBase): default=None, description="Microsoft Azure Tenant ID for SSO authentication", ) - + # Generic/Okta SSO generic_client_id: Optional[str] = Field( default=None, @@ -81,7 +89,7 @@ class SSOConfig(LiteLLMPydanticObjectBase): default=None, description="User info endpoint URL for generic OAuth provider", ) - + # Common settings proxy_base_url: Optional[str] = Field( default=None, @@ -92,6 +100,12 @@ class SSOConfig(LiteLLMPydanticObjectBase): description="Email of the proxy admin user", ) + # Access Mode + ui_access_mode: Optional[Union[AccessControl_UI_AccessMode, str]] = Field( + default=None, + description="Access mode for the UI", + ) + class DefaultTeamSSOParams(LiteLLMPydanticObjectBase): """ diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index 2ce4cf2938..60199b335a 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -696,11 +696,12 @@ async def test_get_generic_sso_response_with_additional_headers(): "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class ) as mock_create_provider: # Act - result = await get_generic_sso_response( + result, received_response = await get_generic_sso_response( request=mock_request, jwt_handler=mock_jwt_handler, generic_client_id=generic_client_id, redirect_url=redirect_url, + sso_jwt_handler=None, ) # Assert @@ -756,11 +757,12 @@ async def test_get_generic_sso_response_with_empty_headers(): "fastapi_sso.sso.generic.create_provider", return_value=mock_sso_class ) as mock_create_provider: # Act - result = await get_generic_sso_response( + result, received_response = await get_generic_sso_response( request=mock_request, jwt_handler=mock_jwt_handler, generic_client_id=generic_client_id, redirect_url=redirect_url, + sso_jwt_handler=None, ) # Assert diff --git a/ui/litellm-dashboard/src/components/SSOModals.tsx b/ui/litellm-dashboard/src/components/SSOModals.tsx index 0779edb89d..763242aeb6 100644 --- a/ui/litellm-dashboard/src/components/SSOModals.tsx +++ b/ui/litellm-dashboard/src/components/SSOModals.tsx @@ -132,7 +132,7 @@ const SSOModals: React.FC = ({ } } - // Set form values with existing data + // Set form values with existing data (excluding UI access control fields) const formValues = { sso_provider: selectedProvider, proxy_base_url: ssoData.values.proxy_base_url, diff --git a/ui/litellm-dashboard/src/components/UIAccessControlForm.tsx b/ui/litellm-dashboard/src/components/UIAccessControlForm.tsx new file mode 100644 index 0000000000..def9bf6bb9 --- /dev/null +++ b/ui/litellm-dashboard/src/components/UIAccessControlForm.tsx @@ -0,0 +1,148 @@ +import React, { useEffect, useState } from "react"; +import { Form, Button as Button2, Select, message } from "antd"; +import { Text, TextInput } from "@tremor/react"; +import { getSSOSettings, updateSSOSettings } from "./networking"; + +interface UIAccessControlFormProps { + accessToken: string | null; + onSuccess: () => void; +} + +// Separate UI Access Control Form Component +const UIAccessControlForm: React.FC = ({ accessToken, onSuccess }) => { + const [form] = Form.useForm(); + const [loading, setLoading] = useState(false); + + // Load existing UI access control settings + useEffect(() => { + const loadUIAccessSettings = async () => { + if (accessToken) { + try { + const ssoData = await getSSOSettings(accessToken); + if (ssoData && ssoData.values) { + // Handle nested ui_access_mode structure + const uiAccessMode = ssoData.values.ui_access_mode; + let formValues = {}; + + if (uiAccessMode && typeof uiAccessMode === 'object') { + formValues = { + ui_access_mode_type: uiAccessMode.type, + restricted_sso_group: uiAccessMode.restricted_sso_group, + sso_group_jwt_field: uiAccessMode.sso_group_jwt_field, + }; + } else if (typeof uiAccessMode === 'string') { + // Handle legacy flat structure + formValues = { + ui_access_mode_type: uiAccessMode, + restricted_sso_group: ssoData.values.restricted_sso_group, + sso_group_jwt_field: ssoData.values.team_ids_jwt_field || ssoData.values.sso_group_jwt_field, + }; + } + + form.setFieldsValue(formValues); + } + } catch (error) { + console.error("Failed to load UI access settings:", error); + } + } + }; + + loadUIAccessSettings(); + }, [accessToken, form]); + + const handleUIAccessSubmit = async (formValues: Record) => { + if (!accessToken) { + message.error("No access token available"); + return; + } + + setLoading(true); + try { + // Transform form data to match API expected structure + const apiPayload = { + ui_access_mode: { + type: formValues.ui_access_mode_type, + restricted_sso_group: formValues.restricted_sso_group, + sso_group_jwt_field: formValues.sso_group_jwt_field, + } + }; + + await updateSSOSettings(accessToken, apiPayload); + onSuccess(); + } catch (error) { + console.error("Failed to save UI access settings:", error); + message.error("Failed to save UI access settings"); + } finally { + setLoading(false); + } + }; + + return ( +
+
+ + Configure who can access the UI interface and how group information is extracted from JWT tokens. + +
+ +
+ + + + + prevValues.ui_access_mode_type !== currentValues.ui_access_mode_type} + > + {({ getFieldValue }) => { + const uiAccessModeType = getFieldValue('ui_access_mode_type'); + return uiAccessModeType === 'restricted_sso_group' ? ( + + + + ) : null; + }} + + + + + + +
+ + Update UI Access Control + +
+
+
+ ); +}; + +export default UIAccessControlForm; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/admins.tsx b/ui/litellm-dashboard/src/components/admins.tsx index 6b3ddc094d..a876aa9d51 100644 --- a/ui/litellm-dashboard/src/components/admins.tsx +++ b/ui/litellm-dashboard/src/components/admins.tsx @@ -44,6 +44,7 @@ import { InvitationLink } from "./onboarding_link"; import SSOModals from "./SSOModals"; import { ssoProviderConfigs } from './SSOModals'; import SCIMConfig from "./SCIM"; +import UIAccessControlForm from "./UIAccessControlForm"; interface AdminPanelProps { searchParams: any; @@ -97,6 +98,7 @@ const AdminPanel: React.FC = ({ const [isAllowedIPModalVisible, setIsAllowedIPModalVisible] = useState(false); const [isAddIPModalVisible, setIsAddIPModalVisible] = useState(false); const [isDeleteIPModalVisible, setIsDeleteIPModalVisible] = useState(false); + const [isUIAccessControlModalVisible, setIsUIAccessControlModalVisible] = useState(false); const [allowedIPs, setAllowedIPs] = useState([]); const [ipToDelete, setIPToDelete] = useState(null); const [ssoConfigured, setSsoConfigured] = useState(false); @@ -532,6 +534,14 @@ const AdminPanel: React.FC = ({ } }; + const handleUIAccessControlOk = () => { + setIsUIAccessControlModalVisible(false); + }; + + const handleUIAccessControlCancel = () => { + setIsUIAccessControlModalVisible(false); + }; + console.log(`admins: ${admins?.length}`); return (
@@ -563,6 +573,14 @@ const AdminPanel: React.FC = ({ Allowed IPs
+
+ +
@@ -654,6 +672,24 @@ const AdminPanel: React.FC = ({ >

Are you sure you want to delete the IP address: {ipToDelete}?

+ + {/* UI Access Control Modal */} + + { + handleUIAccessControlOk(); + message.success("UI Access Control settings updated successfully"); + }} + /> + If you need to login without sso, you can access{" "} diff --git a/ui/litellm-dashboard/src/components/team/team_info.tsx b/ui/litellm-dashboard/src/components/team/team_info.tsx index 0d79f56262..439e7f4a6d 100644 --- a/ui/litellm-dashboard/src/components/team/team_info.tsx +++ b/ui/litellm-dashboard/src/components/team/team_info.tsx @@ -260,6 +260,10 @@ const TeamInfoView: React.FC = ({ updateData.team_member_budget = Number(values.team_member_budget); } + if (values.team_member_key_duration !== undefined) { + updateData.team_member_key_duration = values.team_member_key_duration; + } + // Handle object_permission updates if (values.vector_stores !== undefined || values.mcp_servers !== undefined) { updateData.object_permission = { @@ -453,6 +457,15 @@ const TeamInfoView: React.FC = ({ + + + + +