diff --git a/README.md b/README.md index 75a23faa5c..58ffa12c5e 100644 --- a/README.md +++ b/README.md @@ -374,7 +374,9 @@ Support for more providers. Missing a provider or LLM Platform, raise a [feature 1. (In root) create virtual environment `python -m venv .venv` 2. Activate virtual environment `source .venv/bin/activate` 3. Install dependencies `pip install -e ".[all]"` -4. Start proxy backend `python litellm/proxy_cli.py` +4. `pip install prisma` +5. `prisma generate` +6. Start proxy backend `python litellm/proxy/proxy_cli.py` ### Frontend 1. Navigate to `ui/litellm-dashboard` diff --git a/cookbook/ai_coding_tool_guides/index.json b/cookbook/ai_coding_tool_guides/index.json index f879292aef..3e71670d62 100644 --- a/cookbook/ai_coding_tool_guides/index.json +++ b/cookbook/ai_coding_tool_guides/index.json @@ -110,7 +110,7 @@ ] }, { - "title": "Use Web Search with Claude Code (across OpenAI/Anthropic/Gemini/etc.)", + "title": "Use Web Search with Claude Code (across Bedrock/OpenAI/Gemini/etc.)", "description": "This is a guide for using Web Search with Claude Code via LiteLLM.", "url": "https://docs.litellm.ai/docs/tutorials/claude_code_websearch", "date": "2026-01-17", diff --git a/deploy/charts/litellm-helm/Chart.yaml b/deploy/charts/litellm-helm/Chart.yaml index b37597c7c8..8a08f0b4e2 100644 --- a/deploy/charts/litellm-helm/Chart.yaml +++ b/deploy/charts/litellm-helm/Chart.yaml @@ -18,7 +18,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 1.0.0 +version: 1.1.0 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/deploy/charts/litellm-helm/templates/deployment.yaml b/deploy/charts/litellm-helm/templates/deployment.yaml index 682d97ae3b..c3e0055e38 100644 --- a/deploy/charts/litellm-helm/templates/deployment.yaml +++ b/deploy/charts/litellm-helm/templates/deployment.yaml @@ -10,7 +10,7 @@ metadata: {{- toYaml .Values.deploymentLabels | nindent 4 }} {{- end }} spec: - {{- if not .Values.autoscaling.enabled }} + {{- if and (not .Values.keda.enabled) (not .Values.autoscaling.enabled) }} replicas: {{ .Values.replicaCount }} {{- end }} selector: diff --git a/deploy/charts/litellm-helm/templates/keda.yaml b/deploy/charts/litellm-helm/templates/keda.yaml new file mode 100644 index 0000000000..fe5190fffc --- /dev/null +++ b/deploy/charts/litellm-helm/templates/keda.yaml @@ -0,0 +1,37 @@ +{{- if and .Values.keda.enabled (not .Values.autoscaling.enabled) }} +apiVersion: keda.sh/v1alpha1 +kind: ScaledObject +metadata: + name: {{ include "litellm.fullname" . }} + labels: + {{- include "litellm.labels" . | nindent 4 }} + {{- if .Values.keda.scaledObject.annotations }} + annotations: {{ toYaml .Values.keda.scaledObject.annotations | nindent 4 }} + {{- end }} +spec: + scaleTargetRef: + name: {{ include "litellm.fullname" . }} + pollingInterval: {{ .Values.keda.pollingInterval }} + cooldownPeriod: {{ .Values.keda.cooldownPeriod }} + minReplicaCount: {{ .Values.keda.minReplicas }} + maxReplicaCount: {{ .Values.keda.maxReplicas }} +{{- with .Values.keda.fallback }} + fallback: + failureThreshold: {{ .failureThreshold | default 3 }} + replicas: {{ .replicas | default $.Values.keda.maxReplicas }} +{{- end }} + triggers: +{{- with .Values.keda.triggers }} + {{- toYaml . | nindent 2 }} +{{- end }} + advanced: + restoreToOriginalReplicaCount: {{ .Values.keda.restoreToOriginalReplicaCount }} +{{- if .Values.keda.behavior }} + horizontalPodAutoscalerConfig: + behavior: +{{- with .Values.keda.behavior }} +{{- toYaml . | nindent 8 }} +{{- end }} + +{{- end }} +{{- end }} diff --git a/deploy/charts/litellm-helm/values.yaml b/deploy/charts/litellm-helm/values.yaml index e9e8e75a1f..5427175699 100644 --- a/deploy/charts/litellm-helm/values.yaml +++ b/deploy/charts/litellm-helm/values.yaml @@ -156,6 +156,40 @@ autoscaling: targetCPUUtilizationPercentage: 80 # targetMemoryUtilizationPercentage: 80 +# Autoscaling with keda is mutually exclusive with hpa +keda: + enabled: false + minReplicas: 1 + maxReplicas: 100 + pollingInterval: 30 + cooldownPeriod: 300 + # fallback: + # failureThreshold: 3 + # replicas: 11 + restoreToOriginalReplicaCount: false + scaledObject: + annotations: {} + triggers: [] + # - type: prometheus + # metadata: + # serverAddress: http://:9090 + # metricName: http_requests_total + # threshold: '100' + # query: sum(rate(http_requests_total{deployment="my-deployment"}[2m])) + behavior: {} + # scaleDown: + # stabilizationWindowSeconds: 300 + # policies: + # - type: Pods + # value: 1 + # periodSeconds: 180 + # scaleUp: + # stabilizationWindowSeconds: 300 + # policies: + # - type: Pods + # value: 2 + # periodSeconds: 60 + # Additional volumes on the output Deployment definition. volumes: [] # - name: foo diff --git a/docker/Dockerfile.health_check b/docker/Dockerfile.health_check new file mode 100644 index 0000000000..de62e4bd72 --- /dev/null +++ b/docker/Dockerfile.health_check @@ -0,0 +1,16 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Copy health check script and requirements +COPY scripts/health_check/health_check_client.py /app/health_check_client.py +COPY scripts/health_check/health_check_requirements.txt /app/requirements.txt + +# Install dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Make script executable +RUN chmod +x /app/health_check_client.py + +# Set entrypoint +ENTRYPOINT ["python", "/app/health_check_client.py"] diff --git a/docker/supervisord.conf b/docker/supervisord.conf index c6855fe652..877335804f 100644 --- a/docker/supervisord.conf +++ b/docker/supervisord.conf @@ -1,6 +1,8 @@ [supervisord] nodaemon=true loglevel=info +logfile=/tmp/supervisord.log +pidfile=/tmp/supervisord.pid [group:litellm] programs=main,health diff --git a/docs/my-website/docs/completion/token_usage.md b/docs/my-website/docs/completion/token_usage.md index 0bec6b3f90..d99564765a 100644 --- a/docs/my-website/docs/completion/token_usage.md +++ b/docs/my-website/docs/completion/token_usage.md @@ -100,7 +100,7 @@ from litellm import cost_per_token prompt_tokens = 5 completion_tokens = 10 -prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(model="gpt-3.5-turbo", prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)) +prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(model="gpt-3.5-turbo", prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar) ``` @@ -162,7 +162,7 @@ print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token **Dictionary** ```python -from litellm import register_model +import litellm litellm.register_model({ "gpt-4": { diff --git a/docs/my-website/docs/contributing.md b/docs/my-website/docs/contributing.md index a88013ff1b..be7222f6cb 100644 --- a/docs/my-website/docs/contributing.md +++ b/docs/my-website/docs/contributing.md @@ -1,45 +1,100 @@ # Contributing - UI -Here's how to run the LiteLLM UI locally for making changes: +Thanks for contributing to the LiteLLM UI! This guide will help you set up your local development environment. + + +## 1. Clone the repo -## 1. Clone the repo ```bash git clone https://github.com/BerriAI/litellm.git +cd litellm ``` -## 2. Start the UI + Proxy +## 2. Start the Proxy -**2.1 Start the proxy on port 4000** +Create a config file (e.g., `config.yaml`): -Tell the proxy where the UI is located -```bash -DATABASE_URL = "postgresql://:@:/" -LITELLM_MASTER_KEY = "sk-1234" -STORE_MODEL_IN_DB = "True" +```yaml +model_list: + - model_name: gpt-4o + litellm_params: + model: openai/gpt-4o + +general_settings: + master_key: sk-1234 + database_url: postgresql://:@:/ + store_model_in_db: true ``` +Start the proxy on port 4000: + ```bash -cd litellm/litellm/proxy -python3 proxy_cli.py --config /path/to/config.yaml --port 4000 +poetry run litellm --config config.yaml --port 4000 ``` -**2.2 Start the UI** +The UI comes pre-built in the repo. Access it at `http://localhost:4000/ui` -Set the mode as development (this will assume the proxy is running on localhost:4000) -```bash -npm install # install dependencies -``` +## 3. UI Development + +There are two options for UI development: + +### Option A: Development Mode (Hot Reload) + +This runs the UI on port 3000 with hot reload. The proxy runs on port 4000. ```bash -cd litellm/ui/litellm-dashboard - +cd ui/litellm-dashboard +npm install npm run dev - -# starts on http://0.0.0.0:3000 ``` -## 3. Go to local UI +**Login flow:** +1. Go to `http://localhost:3000` +2. You'll be redirected to `http://localhost:4000/ui` for login +3. After logging in, manually navigate back to `http://localhost:3000/` +4. You're now authenticated and can develop with hot reload + +:::note +If you experience redirect loops or authentication issues, clear your browser cookies for localhost or use Build Mode instead. +::: + +### Option B: Build Mode + +This builds the UI and copies it to the proxy. Changes require rebuilding. + +1. Make your code changes in `ui/litellm-dashboard/src/` + +2. Build the UI +```bash +cd ui/litellm-dashboard +npm install +npm run build +``` + +After building, copy the output to the proxy: ```bash -http://0.0.0.0:3000 -``` \ No newline at end of file +cp -r out/* ../../litellm/proxy/_experimental/out/ +``` + +Then restart the proxy and access the UI at `http://localhost:4000/ui` + +## 4. Submitting a PR + +1. Create a new branch for your changes: +```bash +git checkout -b feat/your-feature-name +``` + +2. Stage and commit your changes: +```bash +git add . +git commit -m "feat: description of your changes" +``` + +3. Push to your fork: +```bash +git push origin feat/your-feature-name +``` + +4. Create a Pull Request on GitHub following the [PR template](https://github.com/BerriAI/litellm/blob/main/.github/pull_request_template.md) diff --git a/docs/my-website/docs/providers/openai/text_to_speech.md b/docs/my-website/docs/providers/openai/text_to_speech.md index a4aeb9e525..f4507faa06 100644 --- a/docs/my-website/docs/providers/openai/text_to_speech.md +++ b/docs/my-website/docs/providers/openai/text_to_speech.md @@ -46,7 +46,7 @@ os.environ["OPENAI_API_KEY"] = "sk-.." async def test_async_speech(): speech_file_path = Path(__file__).parent / "speech.mp3" - response = await litellm.aspeech( + response = await aspeech( model="openai/tts-1", voice="alloy", input="the quick brown fox jumped over the lazy dogs", diff --git a/docs/my-website/docs/providers/stability.md b/docs/my-website/docs/providers/stability.md index 62a8ab43cd..c4bc5376d1 100644 --- a/docs/my-website/docs/providers/stability.md +++ b/docs/my-website/docs/providers/stability.md @@ -173,6 +173,14 @@ Stability AI returns images in base64 format. The response is OpenAI-compatible: Stability AI supports various image editing operations including inpainting, upscaling, outpainting, background removal, and more. +:::info Optional Parameters +**Important:** Different Stability models have different parameter requirements: +- Some models don't require a `prompt` (e.g., upscaling, background removal) +- The `style-transfer` model uses `init_image` and `style_image` instead of `image` +- The `outpaint` model requires numeric parameters (`left`, `right`, `up`, `down`) +LiteLLM automatically handles these differences for you. +::: + ### Usage - LiteLLM Python SDK #### Inpainting (Edit with Mask) @@ -217,11 +225,11 @@ response = image_edit( creativity=0.3, # 0-0.35, higher = more creative ) -# Fast upscaling - quick upscaling +# Fast upscaling - quick upscaling (no prompt needed) response = image_edit( model="stability/stable-fast-upscale-v1:0", image=open("low_res_image.png", "rb"), - prompt="Quickly upscale this image", + # No prompt required for fast upscale ) print(response) ``` @@ -259,7 +267,7 @@ os.environ['STABILITY_API_KEY'] = "your-api-key" response = image_edit( model="stability/stable-image-remove-background-v1:0", image=open("portrait.png", "rb"), - prompt="Remove the background", + # No prompt required for fast upscale ) print(response) ``` @@ -329,10 +337,29 @@ response = image_edit( model="stability/stable-image-erase-object-v1:0", image=open("scene.png", "rb"), mask=open("object_mask.png", "rb"), # Mask the object to erase - prompt="Remove the object", + # No prompt needed ) print(response) ``` +#### Style Transfer + +```python showLineNumbers +from litellm import image_edit +import os + +os.environ['STABILITY_API_KEY'] = "your-api-key" + +# Transfer style from one image to another +# Note: Uses init_image (via image param) and style_image +response = image_edit( + model="stability/stable-style-transfer-v1:0", + image=open("content_image.png", "rb"), # Maps to init_image + style_image=open("style_reference.png", "rb"), # Style to apply + fidelity=0.5, # 0-1, balance between content and style + # No prompt needed +) + +print(response) ### Supported Image Edit Models @@ -419,6 +446,23 @@ response = image_edit( ) print(response) ``` +# Fast upscale without prompt +response = image_edit( + model="bedrock/stability.stable-fast-upscale-v1:0", + image=open("low_res_image.png", "rb"), +) + +# Outpaint with numeric parameters +response = image_edit( + model="bedrock/stability.stable-outpaint-v1:0", + image=open("original_image.png", "rb"), + left=100, # Automatically converted to int + right=100, + up=50, + down=50, +) + +print(response) ### Supported Bedrock Stability Models diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index b941f21b33..8b4514f568 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -603,6 +603,7 @@ router_settings: | GCS_PATH_SERVICE_ACCOUNT | Path to the Google Cloud service account JSON file | GCS_FLUSH_INTERVAL | Flush interval for GCS logging (in seconds). Specify how often you want a log to be sent to GCS. **Default is 20 seconds** | GCS_BATCH_SIZE | Batch size for GCS logging. Specify after how many logs you want to flush to GCS. If `BATCH_SIZE` is set to 10, logs are flushed every 10 logs. **Default is 2048** +| GCS_USE_BATCHED_LOGGING | Enable batched logging for GCS. When enabled (default), multiple log payloads are combined into single GCS object uploads (NDJSON format), dramatically reducing API calls. When disabled, sends each log individually as separate GCS objects (legacy behavior). **Default is true** | GCS_PUBSUB_TOPIC_ID | PubSub Topic ID to send LiteLLM SpendLogs to. | GCS_PUBSUB_PROJECT_ID | PubSub Project ID to send LiteLLM SpendLogs to. | GENERIC_AUTHORIZATION_ENDPOINT | Authorization endpoint for generic OAuth providers diff --git a/docs/my-website/docs/proxy/deploy.md b/docs/my-website/docs/proxy/deploy.md index 5686e9fd83..7393e73ba8 100644 --- a/docs/my-website/docs/proxy/deploy.md +++ b/docs/my-website/docs/proxy/deploy.md @@ -4,6 +4,10 @@ import Image from '@theme/IdealImage'; # Docker, Helm, Terraform +:::info No Limits on LiteLLM OSS +There are **no limits** on the number of users, keys, or teams you can create on LiteLLM OSS. +::: + You can find the Dockerfile to build litellm proxy [here](https://github.com/BerriAI/litellm/blob/main/Dockerfile) > Note: Production requires at least 4 CPU cores and 8 GB RAM. diff --git a/docs/my-website/docs/text_to_speech.md b/docs/my-website/docs/text_to_speech.md index 77d15ccb3a..667ffc925c 100644 --- a/docs/my-website/docs/text_to_speech.md +++ b/docs/my-website/docs/text_to_speech.md @@ -46,7 +46,7 @@ os.environ["OPENAI_API_KEY"] = "sk-.." async def test_async_speech(): speech_file_path = Path(__file__).parent / "speech.mp3" - response = await litellm.aspeech( + response = await aspeech( model="openai/tts-1", voice="alloy", input="the quick brown fox jumped over the lazy dogs", diff --git a/docs/my-website/src/pages/token_usage.md b/docs/my-website/src/pages/token_usage.md index 028e010a96..61deb61c94 100644 --- a/docs/my-website/src/pages/token_usage.md +++ b/docs/my-website/src/pages/token_usage.md @@ -27,7 +27,7 @@ from litellm import cost_per_token prompt_tokens = 5 completion_tokens = 10 -prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(model="gpt-3.5-turbo", prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)) +prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(model="gpt-3.5-turbo", prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar) ``` diff --git a/litellm/a2a_protocol/main.py b/litellm/a2a_protocol/main.py index 167aad7959..2d36dbeacd 100644 --- a/litellm/a2a_protocol/main.py +++ b/litellm/a2a_protocol/main.py @@ -113,7 +113,9 @@ def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str: litellm_logging_obj.model = model litellm_logging_obj.custom_llm_provider = custom_llm_provider litellm_logging_obj.model_call_details["model"] = model - litellm_logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider + litellm_logging_obj.model_call_details[ + "custom_llm_provider" + ] = custom_llm_provider return agent_name @@ -197,7 +199,11 @@ async def asend_message( ) # Extract params from request - params = request.params.model_dump(mode="json") if hasattr(request.params, "model_dump") else dict(request.params) + params = ( + request.params.model_dump(mode="json") + if hasattr(request.params, "model_dump") + else dict(request.params) + ) response_dict = await A2ACompletionBridgeHandler.handle_non_streaming( request_id=str(request.id), @@ -216,7 +222,9 @@ async def asend_message( # Create A2A client if not provided but api_base is available if a2a_client is None: if api_base is None: - raise ValueError("Either a2a_client or api_base is required for standard A2A flow") + raise ValueError( + "Either a2a_client or api_base is required for standard A2A flow" + ) a2a_client = await create_a2a_client(base_url=api_base) # Type assertion: a2a_client is guaranteed to be non-None here @@ -235,7 +243,11 @@ async def asend_message( # Calculate token usage from request and response response_dict = a2a_response.model_dump(mode="json", exclude_none=True) - prompt_tokens, completion_tokens, _ = A2ARequestUtils.calculate_usage_from_request_response( + ( + prompt_tokens, + completion_tokens, + _, + ) = A2ARequestUtils.calculate_usage_from_request_response( request=request, response_dict=response_dict, ) @@ -280,7 +292,9 @@ def send_message( if loop is not None: return asend_message(a2a_client=a2a_client, request=request, **kwargs) else: - return asyncio.run(asend_message(a2a_client=a2a_client, request=request, **kwargs)) + return asyncio.run( + asend_message(a2a_client=a2a_client, request=request, **kwargs) + ) async def asend_message_streaming( @@ -347,7 +361,11 @@ async def asend_message_streaming( ) # Extract params from request - params = request.params.model_dump(mode="json") if hasattr(request.params, "model_dump") else dict(request.params) + params = ( + request.params.model_dump(mode="json") + if hasattr(request.params, "model_dump") + else dict(request.params) + ) async for chunk in A2ACompletionBridgeHandler.handle_streaming( request_id=str(request.id), @@ -365,7 +383,9 @@ async def asend_message_streaming( # Create A2A client if not provided but api_base is available if a2a_client is None: if api_base is None: - raise ValueError("Either a2a_client or api_base is required for standard A2A flow") + raise ValueError( + "Either a2a_client or api_base is required for standard A2A flow" + ) a2a_client = await create_a2a_client(base_url=api_base) # Type assertion: a2a_client is guaranteed to be non-None here @@ -378,7 +398,9 @@ async def asend_message_streaming( stream = a2a_client.send_message_streaming(request) # Build logging object for streaming completion callbacks - agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(a2a_client, "agent_card", None) + agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr( + a2a_client, "agent_card", None + ) agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown" model = f"a2a_agent/{agent_name}" @@ -456,7 +478,7 @@ async def create_a2a_client( if not A2A_SDK_AVAILABLE: raise ImportError( "The 'a2a' package is required for A2A agent invocation. " - "Install it with: pip install a2a" + "Install it with: pip install a2a-sdk" ) verbose_logger.info(f"Creating A2A client for {base_url}") @@ -512,7 +534,7 @@ async def aget_agent_card( if not A2A_SDK_AVAILABLE: raise ImportError( "The 'a2a' package is required for A2A agent invocation. " - "Install it with: pip install a2a" + "Install it with: pip install a2a-sdk" ) verbose_logger.info(f"Fetching agent card from {base_url}") @@ -534,5 +556,3 @@ async def aget_agent_card( f"Fetched agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}" ) return agent_card - - diff --git a/litellm/images/main.py b/litellm/images/main.py index 1b09c20d35..6c4c502a7b 100644 --- a/litellm/images/main.py +++ b/litellm/images/main.py @@ -714,8 +714,8 @@ def image_variation( @client def image_edit( # noqa: PLR0915 - image: Union[FileTypes, List[FileTypes]], - prompt: str, + image: Optional[Union[FileTypes, List[FileTypes]]] = None, + prompt: Optional[str]= None, model: Optional[str] = None, mask: Optional[str] = None, n: Optional[int] = None, @@ -766,7 +766,7 @@ def image_edit( # noqa: PLR0915 _is_async = kwargs.pop("async_call", False) is True # add images / or return a single image - images = image if isinstance(image, list) else [image] + images = image if isinstance(image, list) else ([image] if image is not None else []) headers_from_kwargs = kwargs.get("headers") merged_extra_headers: Dict[str, Any] = {} diff --git a/litellm/integrations/gcs_bucket/gcs_bucket.py b/litellm/integrations/gcs_bucket/gcs_bucket.py index 9190f921d5..3cb6290553 100644 --- a/litellm/integrations/gcs_bucket/gcs_bucket.py +++ b/litellm/integrations/gcs_bucket/gcs_bucket.py @@ -1,9 +1,11 @@ import asyncio +import hashlib import json import os +import time from litellm._uuid import uuid from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from urllib.parse import quote from litellm._logging import verbose_logger @@ -26,19 +28,21 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): super().__init__(bucket_name=bucket_name) - # Init Batch logging settings - self.log_queue: List[GCSLogQueueItem] = [] self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE)) self.flush_interval = int( os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS) ) - asyncio.create_task(self.periodic_flush()) + self.use_batched_logging = ( + os.getenv("GCS_USE_BATCHED_LOGGING", str(GCS_DEFAULT_USE_BATCHED_LOGGING).lower()).lower() == "true" + ) self.flush_lock = asyncio.Lock() super().__init__( flush_lock=self.flush_lock, batch_size=self.batch_size, flush_interval=self.flush_interval, ) + self.log_queue: asyncio.Queue[GCSLogQueueItem] = asyncio.Queue() # type: ignore[assignment] + asyncio.create_task(self.periodic_flush()) AdditionalLoggingUtils.__init__(self) if premium_user is not True: @@ -65,8 +69,7 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): ) if logging_payload is None: raise ValueError("standard_logging_object not found in kwargs") - # Add to logging queue - this will be flushed periodically - self.log_queue.append( + await self.log_queue.put( GCSLogQueueItem( payload=logging_payload, kwargs=kwargs, response_obj=response_obj ) @@ -89,7 +92,9 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): if logging_payload is None: raise ValueError("standard_logging_object not found in kwargs") # Add to logging queue - this will be flushed periodically - self.log_queue.append( + # Use asyncio.Queue.put() for thread-safe concurrent access + # If queue is full, this will block until space is available (backpressure) + await self.log_queue.put( GCSLogQueueItem( payload=logging_payload, kwargs=kwargs, response_obj=response_obj ) @@ -98,28 +103,98 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): except Exception as e: verbose_logger.exception(f"GCS Bucket logging error: {str(e)}") - async def async_send_batch(self): + def _drain_queue_batch(self) -> List[GCSLogQueueItem]: """ - Process queued logs in batch - sends logs to GCS Bucket - - - GCS Bucket does not have a Batch endpoint to batch upload logs - - Instead, we - - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds - - during async_send_batch, we make 1 POST request per log to GCS Bucket - + Drain items from the queue (non-blocking), respecting batch_size limit. + + This prevents unbounded queue growth when processing is slower than log accumulation. + + Returns: + List of items to process, up to batch_size items """ - if not self.log_queue: - return + items_to_process: List[GCSLogQueueItem] = [] + while len(items_to_process) < self.batch_size: + try: + items_to_process.append(self.log_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return items_to_process - for log_item in self.log_queue: - logging_payload = log_item["payload"] - kwargs = log_item["kwargs"] - response_obj = log_item.get("response_obj", None) or {} + def _generate_batch_object_name(self, date_str: str, batch_id: str) -> str: + """ + Generate object name for a batched log file. + Format: {date}/batch-{batch_id}.ndjson + """ + return f"{date_str}/batch-{batch_id}.ndjson" + def _get_config_key(self, kwargs: Dict[str, Any]) -> str: + """ + Extract a synchronous grouping key from kwargs to group items by GCS config. + This allows us to batch items with the same bucket/credentials together. + + Returns a string key that uniquely identifies the GCS config combination. + This key may contain sensitive information (bucket names, paths) - use _sanitize_config_key() + for logging purposes. + """ + standard_callback_dynamic_params = kwargs.get("standard_callback_dynamic_params", None) or {} + + bucket_name = standard_callback_dynamic_params.get("gcs_bucket_name", None) or self.BUCKET_NAME or "default" + path_service_account = standard_callback_dynamic_params.get("gcs_path_service_account", None) or self.path_service_account_json or "default" + + return f"{bucket_name}|{path_service_account}" + + def _sanitize_config_key(self, config_key: str) -> str: + """ + Create a sanitized version of the config key for logging. + Uses a hash to avoid exposing sensitive bucket names or service account paths. + + Returns a short hash prefix for safe logging. + """ + hash_obj = hashlib.sha256(config_key.encode('utf-8')) + return f"config-{hash_obj.hexdigest()[:8]}" + + def _group_items_by_config(self, items: List[GCSLogQueueItem]) -> Dict[str, List[GCSLogQueueItem]]: + """ + Group items by their GCS config (bucket + credentials). + This ensures items with different configs are processed separately. + + Returns a dict mapping config_key -> list of items with that config. + """ + grouped: Dict[str, List[GCSLogQueueItem]] = {} + for item in items: + config_key = self._get_config_key(item["kwargs"]) + if config_key not in grouped: + grouped[config_key] = [] + grouped[config_key].append(item) + return grouped + + def _combine_payloads_to_ndjson(self, items: List[GCSLogQueueItem]) -> str: + """ + Combine multiple log payloads into newline-delimited JSON (NDJSON) format. + Each line is a valid JSON object representing one log entry. + """ + lines = [] + for item in items: + logging_payload = item["payload"] + json_line = json.dumps(logging_payload, default=str, ensure_ascii=False) + lines.append(json_line) + return "\n".join(lines) + + async def _send_grouped_batch(self, items: List[GCSLogQueueItem], config_key: str) -> Tuple[int, int]: + """ + Send a batch of items that share the same GCS config. + + Returns: + (success_count, error_count) + """ + if not items: + return (0, 0) + + first_kwargs = items[0]["kwargs"] + + try: gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( - kwargs + first_kwargs ) headers = await self.construct_request_headers( @@ -127,24 +202,92 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): service_account_json=gcs_logging_config["path_service_account"], ) bucket_name = gcs_logging_config["bucket_name"] - object_name = self._get_object_name(kwargs, logging_payload, response_obj) + + current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc)) + batch_id = f"{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}" + object_name = self._generate_batch_object_name(current_date, batch_id) + combined_payload = self._combine_payloads_to_ndjson(items) + + await self._log_json_data_on_gcs( + headers=headers, + bucket_name=bucket_name, + object_name=object_name, + logging_payload=combined_payload, + ) + + success_count = len(items) + error_count = 0 + return (success_count, error_count) + + except Exception as e: + success_count = 0 + error_count = len(items) + verbose_logger.exception( + f"GCS Bucket error logging batch payload to GCS bucket: {str(e)}" + ) + return (success_count, error_count) - try: - await self._log_json_data_on_gcs( - headers=headers, - bucket_name=bucket_name, - object_name=object_name, - logging_payload=logging_payload, - ) - except Exception as e: - # don't let one log item fail the entire batch - verbose_logger.exception( - f"GCS Bucket error logging payload to GCS bucket: {str(e)}" - ) - pass + async def _send_individual_logs(self, items: List[GCSLogQueueItem]) -> None: + """ + Send each log individually as separate GCS objects (legacy behavior). + This is used when GCS_USE_BATCHED_LOGGING is disabled. + """ + for item in items: + await self._send_single_log_item(item) - # Clear the queue after processing - self.log_queue.clear() + async def _send_single_log_item(self, item: GCSLogQueueItem) -> None: + """ + Send a single log item to GCS as an individual object. + """ + try: + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + item["kwargs"] + ) + + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] + + object_name = self._get_object_name( + kwargs=item["kwargs"], + logging_payload=item["payload"], + response_obj=item["response_obj"], + ) + + await self._log_json_data_on_gcs( + headers=headers, + bucket_name=bucket_name, + object_name=object_name, + logging_payload=item["payload"], + ) + except Exception as e: + verbose_logger.exception( + f"GCS Bucket error logging individual payload to GCS bucket: {str(e)}" + ) + + async def async_send_batch(self): + """ + Process queued logs - sends logs to GCS Bucket. + + If `GCS_USE_BATCHED_LOGGING` is enabled (default), batches multiple log payloads + into single GCS object uploads (NDJSON format), dramatically reducing API calls. + + If disabled, sends each log individually as separate GCS objects (legacy behavior). + """ + items_to_process = self._drain_queue_batch() + + if not items_to_process: + return + + if self.use_batched_logging: + grouped_items = self._group_items_by_config(items_to_process) + + for config_key, group_items in grouped_items.items(): + await self._send_grouped_batch(group_items, config_key) + else: + await self._send_individual_logs(items_to_process) def _get_object_name( self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any @@ -186,7 +329,6 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): "start_time_utc is required for getting a payload from GCS Bucket" ) - # Try current day, next day, and previous day dates_to_try = [ start_time_utc, start_time_utc + timedelta(days=1), @@ -230,5 +372,23 @@ class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils): def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str: return datetime_obj.strftime("%Y-%m-%d") + async def flush_queue(self): + """ + Override flush_queue to work with asyncio.Queue. + """ + await self.async_send_batch() + self.last_flush_time = time.time() + + async def periodic_flush(self): + """ + Override periodic_flush to work with asyncio.Queue. + """ + while True: + await asyncio.sleep(self.flush_interval) + verbose_logger.debug( + f"GCS Bucket periodic flush after {self.flush_interval} seconds" + ) + await self.flush_queue() + async def async_health_check(self) -> IntegrationHealthCheckStatus: raise NotImplementedError("GCS Bucket does not support health check") diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index a223925d59..9fbccac68d 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -987,7 +987,10 @@ class OpenTelemetry(CustomLogger): # TODO: Refactor to use the proper OTEL Logs API instead of directly creating SDK LogRecords from opentelemetry._logs import SeverityNumber, get_logger, get_logger_provider - from opentelemetry.sdk._logs import LogRecord as SdkLogRecord + try: + from opentelemetry.sdk._logs import LogRecord as SdkLogRecord # OTEL < 1.39.0 + except ImportError: + from opentelemetry.sdk._logs._internal import LogRecord as SdkLogRecord # OTEL >= 1.39.0 otel_logger = get_logger(LITELLM_LOGGER_NAME) diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index 43ed23587d..21a79af2bd 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -4410,9 +4410,10 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]: defs = parameters.pop("$defs", {}) defs_copy = copy.deepcopy(defs) - # flatten the defs - for _, value in defs_copy.items(): - unpack_defs(value, defs_copy) + # Expand $ref references in parameters using the definitions + # Note: We don't pre-flatten defs as that causes exponential memory growth + # with circular references (see issue #19098). unpack_defs handles nested + # refs recursively and correctly detects/skips circular references. unpack_defs(parameters, defs_copy) tool_input_schema = BedrockToolInputSchemaBlock( json=BedrockToolJsonSchemaBlock( diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 5b1b663e85..86378b97d2 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -934,8 +934,15 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): ) return tools - def _ensure_context_management_beta_header(self, headers: dict) -> None: - beta_value = ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value + def _ensure_beta_header(self, headers: dict, beta_value: str) -> None: + """ + Ensure a beta header value is present in the anthropic-beta header. + Merges with existing values instead of overriding them. + + Args: + headers: Dictionary of headers to update + beta_value: The beta header value to add + """ existing_beta = headers.get("anthropic-beta") if existing_beta is None: headers["anthropic-beta"] = beta_value @@ -944,6 +951,10 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): if beta_value not in existing_values: headers["anthropic-beta"] = f"{existing_beta}, {beta_value}" + def _ensure_context_management_beta_header(self, headers: dict) -> None: + beta_value = ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value + self._ensure_beta_header(headers, beta_value) + def update_headers_with_optional_anthropic_beta( self, headers: dict, optional_params: dict ) -> dict: @@ -960,20 +971,20 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): if tool.get("type", None) and tool.get("type").startswith( ANTHROPIC_HOSTED_TOOLS.WEB_FETCH.value ): - headers["anthropic-beta"] = ( - ANTHROPIC_BETA_HEADER_VALUES.WEB_FETCH_2025_09_10.value + self._ensure_beta_header( + headers, ANTHROPIC_BETA_HEADER_VALUES.WEB_FETCH_2025_09_10.value ) elif tool.get("type", None) and tool.get("type").startswith( ANTHROPIC_HOSTED_TOOLS.MEMORY.value ): - headers["anthropic-beta"] = ( - ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value + self._ensure_beta_header( + headers, ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value ) if optional_params.get("context_management") is not None: self._ensure_context_management_beta_header(headers) if optional_params.get("output_format") is not None: - headers["anthropic-beta"] = ( - ANTHROPIC_BETA_HEADER_VALUES.STRUCTURED_OUTPUT_2025_09_25.value + self._ensure_beta_header( + headers, ANTHROPIC_BETA_HEADER_VALUES.STRUCTURED_OUTPUT_2025_09_25.value ) return headers diff --git a/litellm/llms/azure_ai/image_edit/flux2_transformation.py b/litellm/llms/azure_ai/image_edit/flux2_transformation.py index 87bae59ba0..77d46ff917 100644 --- a/litellm/llms/azure_ai/image_edit/flux2_transformation.py +++ b/litellm/llms/azure_ai/image_edit/flux2_transformation.py @@ -88,7 +88,7 @@ class AzureFoundryFlux2ImageEditConfig(OpenAIImageEditConfig): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, @@ -102,6 +102,9 @@ class AzureFoundryFlux2ImageEditConfig(OpenAIImageEditConfig): if prompt is None: raise ValueError("FLUX 2 image edit requires a prompt.") + if image is None: + raise ValueError("FLUX 2 image edit requires an image.") + image_b64 = self._convert_image_to_base64(image) # Build request body with required params diff --git a/litellm/llms/base_llm/image_edit/transformation.py b/litellm/llms/base_llm/image_edit/transformation.py index cc72348037..b088cdf37f 100644 --- a/litellm/llms/base_llm/image_edit/transformation.py +++ b/litellm/llms/base_llm/image_edit/transformation.py @@ -93,7 +93,7 @@ class BaseImageEditConfig(ABC): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, diff --git a/litellm/llms/bedrock/image_edit/handler.py b/litellm/llms/bedrock/image_edit/handler.py index 0f1dcff629..ef441fa503 100644 --- a/litellm/llms/bedrock/image_edit/handler.py +++ b/litellm/llms/bedrock/image_edit/handler.py @@ -62,7 +62,7 @@ class BedrockImageEdit(BaseAWSLLM): self, model: str, image: list, - prompt: str, + prompt: Optional[str], model_response: ImageResponse, optional_params: dict, logging_obj: LitellmLogging, @@ -127,7 +127,7 @@ class BedrockImageEdit(BaseAWSLLM): timeout: Optional[Union[float, httpx.Timeout]], model: str, logging_obj: LitellmLogging, - prompt: str, + prompt: Optional[str], model_response: ImageResponse, client: Optional[AsyncHTTPHandler] = None, ) -> ImageResponse: @@ -163,7 +163,7 @@ class BedrockImageEdit(BaseAWSLLM): self, model: str, image: list, - prompt: str, + prompt: Optional[str], optional_params: dict, api_base: Optional[str], extra_headers: Optional[dict], @@ -176,7 +176,7 @@ class BedrockImageEdit(BaseAWSLLM): Args: model (str): The model to use for the image edit image (list): The images to edit - prompt (str): The prompt for the edit + prompt (Optional[str]): The prompt for the edit optional_params (dict): The optional parameters for the image edit api_base (Optional[str]): The base URL for the Bedrock API extra_headers (Optional[dict]): The extra headers to include in the request @@ -248,7 +248,7 @@ class BedrockImageEdit(BaseAWSLLM): self, model: str, image: list, - prompt: str, + prompt: Optional[str], optional_params: dict, ) -> dict: """ @@ -276,7 +276,7 @@ class BedrockImageEdit(BaseAWSLLM): model_response: ImageResponse, model: str, logging_obj: LitellmLogging, - prompt: str, + prompt: Optional[str], response: httpx.Response, data: dict, ) -> ImageResponse: diff --git a/litellm/llms/bedrock/image_edit/stability_transformation.py b/litellm/llms/bedrock/image_edit/stability_transformation.py index e8b7781298..fc14b571a8 100644 --- a/litellm/llms/bedrock/image_edit/stability_transformation.py +++ b/litellm/llms/bedrock/image_edit/stability_transformation.py @@ -150,11 +150,11 @@ class BedrockStabilityImageEditConfig(BaseImageEditConfig): return mapped_params - def transform_image_edit_request( + def transform_image_edit_request( #noqa: PLR0915 self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, @@ -164,32 +164,38 @@ class BedrockStabilityImageEditConfig(BaseImageEditConfig): Returns the request body dict that will be JSON-encoded by the handler. """ - if prompt is None: - raise ValueError("Bedrock Stability image edit requires a prompt.") - # Build Bedrock Stability request data: Dict[str, Any] = { - "prompt": prompt, "output_format": "png", # Default to PNG } - # Convert image to base64 - image_b64: str - if hasattr(image, 'read') and callable(getattr(image, 'read', None)): - # File-like object (e.g., BufferedReader from open()) - image_bytes = image.read() # type: ignore - image_b64 = base64.b64encode(image_bytes).decode('utf-8') # type: ignore - elif isinstance(image, bytes): - # Raw bytes - image_b64 = base64.b64encode(image).decode('utf-8') - elif isinstance(image, str): - # Already a base64 string - image_b64 = image - else: - # Try to handle as bytes - image_b64 = base64.b64encode(bytes(image)).decode('utf-8') # type: ignore + # Add prompt only if provided (some models don't require it) + if prompt is not None and prompt != "": + data["prompt"] = prompt + + # Convert image to base64 if provided + if image is not None: + image_b64: str + if hasattr(image, 'read') and callable(getattr(image, 'read', None)): + # File-like object (e.g., BufferedReader from open()) + image_bytes = image.read() # type: ignore + image_b64 = base64.b64encode(image_bytes).decode('utf-8') # type: ignore + elif isinstance(image, bytes): + # Raw bytes + image_b64 = base64.b64encode(image).decode('utf-8') + elif isinstance(image, str): + # Already a base64 string + image_b64 = image + else: + # Try to handle as bytes + image_b64 = base64.b64encode(bytes(image)).decode('utf-8') # type: ignore - data["image"] = image_b64 + # For style-transfer models, map image to init_image + model_lower = model.lower() + if "style-transfer" in model_lower: + data["init_image"] = image_b64 + else: + data["image"] = image_b64 # Add optional params (already mapped in map_openai_params) for key, value in image_edit_optional_request_params.items(): # type: ignore @@ -221,30 +227,43 @@ class BedrockStabilityImageEditConfig(BaseImageEditConfig): file_b64 = str(file_bytes) data[key] = file_b64 continue - - # Supported text fields - if key in [ - "negative_prompt", - "aspect_ratio", - "seed", - "output_format", - "model", - "mode", + + # Numeric fields that need to be converted to int/float + numeric_int_fields = ["left", "right", "up", "down", "seed"] + numeric_float_fields = [ "strength", - "style_preset", "creativity", "control_strength", "grow_mask", - "left", - "right", - "up", - "down", - "select_prompt", - "search_prompt", "fidelity", "composition_fidelity", "style_strength", "change_strength", + ] + + if key in numeric_int_fields: + # Convert to int (these are pixel values for outpaint) + try: + data[key] = int(value) # type: ignore + except (ValueError, TypeError): + data[key] = value # type: ignore + elif key in numeric_float_fields: + # Convert to float + try: + data[key] = float(value) # type: ignore + except (ValueError, TypeError): + data[key] = value # type: ignore + + # Supported text fields + elif key in [ + "negative_prompt", + "aspect_ratio", + "output_format", + "model", + "mode", + "style_preset", + "select_prompt", + "search_prompt", ]: data[key] = value # type: ignore diff --git a/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py b/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py index fa5002fcad..293ee1caaf 100644 --- a/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py +++ b/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py @@ -50,6 +50,12 @@ class AmazonAnthropicClaudeMessagesConfig( DEFAULT_BEDROCK_ANTHROPIC_API_VERSION = "bedrock-2023-05-31" + # Beta header patterns that are not supported by Bedrock Invoke API + # These will be filtered out to prevent 400 "invalid beta flag" errors + UNSUPPORTED_BEDROCK_INVOKE_BETA_PATTERNS = [ + "advanced-tool-use", # Bedrock Invoke doesn't support advanced-tool-use beta headers + ] + def __init__(self, **kwargs): BaseAnthropicMessagesConfig.__init__(self, **kwargs) AmazonInvokeConfig.__init__(self, **kwargs) @@ -114,7 +120,7 @@ class AmazonAnthropicClaudeMessagesConfig( """ Remove `ttl` field from cache_control in messages. Bedrock doesn't support the ttl field in cache_control. - + Args: anthropic_messages_request: The request dictionary to modify in-place """ @@ -129,6 +135,75 @@ class AmazonAnthropicClaudeMessagesConfig( if isinstance(cache_control, dict) and "ttl" in cache_control: cache_control.pop("ttl", None) + def _supports_extended_thinking_on_bedrock(self, model: str) -> bool: + """ + Check if the model supports extended thinking beta headers on Bedrock. + + On 3rd-party platforms (e.g., Amazon Bedrock), extended thinking is only + supported on: Claude Opus 4.5, Claude Opus 4.1, Opus 4, or Sonnet 4. + + Ref: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking + + Args: + model: The model name + + Returns: + True if the model supports extended thinking on Bedrock + """ + model_lower = model.lower() + + # Supported models on Bedrock for extended thinking + supported_patterns = [ + "opus-4.5", "opus_4.5", "opus-4-5", "opus_4_5", # Opus 4.5 + "opus-4.1", "opus_4.1", "opus-4-1", "opus_4_1", # Opus 4.1 + "opus-4", "opus_4", # Opus 4 + "sonnet-4", "sonnet_4", # Sonnet 4 + ] + + return any(pattern in model_lower for pattern in supported_patterns) + + def _filter_unsupported_beta_headers_for_bedrock( + self, model: str, beta_set: set + ) -> None: + """ + Remove beta headers that are not supported on Bedrock for the given model. + + Extended thinking beta headers are only supported on specific Claude 4+ models. + Advanced tool use headers are not supported on Bedrock Invoke API. + This prevents 400 "invalid beta flag" errors on Bedrock. + + Note: Bedrock Invoke API fails with a 400 error when unsupported beta headers + are sent, returning: {"message":"invalid beta flag"} + + Args: + model: The model name + beta_set: The set of beta headers to filter in-place + """ + beta_headers_to_remove = set() + + # 1. Filter out beta headers that are universally unsupported on Bedrock Invoke + for beta in beta_set: + for unsupported_pattern in self.UNSUPPORTED_BEDROCK_INVOKE_BETA_PATTERNS: + if unsupported_pattern in beta.lower(): + beta_headers_to_remove.add(beta) + break + + # 2. Filter out extended thinking headers for models that don't support them + extended_thinking_patterns = [ + "extended-thinking", + "interleaved-thinking", + ] + if not self._supports_extended_thinking_on_bedrock(model): + for beta in beta_set: + for pattern in extended_thinking_patterns: + if pattern in beta.lower(): + beta_headers_to_remove.add(beta) + break + + # Remove all filtered headers + for beta in beta_headers_to_remove: + beta_set.discard(beta) + def _get_tool_search_beta_header_for_bedrock( self, model: str, @@ -139,15 +214,15 @@ class AmazonAnthropicClaudeMessagesConfig( ) -> None: """ Adjust tool search beta header for Bedrock. - + Bedrock requires a different beta header for tool search on Opus 4 models when tool search is used without programmatic tool calling or input examples. - + Note: On Amazon Bedrock, server-side tool search is only supported on Claude Opus 4 with the `tool-search-tool-2025-10-19` beta header. - + Ref: https://platform.claude.com/docs/en/agents-and-tools/tool-use/tool-search-tool - + Args: model: The model name tool_search_used: Whether tool search is used @@ -228,6 +303,12 @@ class AmazonAnthropicClaudeMessagesConfig( beta_set=beta_set, ) + # Filter out unsupported beta headers for Bedrock (e.g., advanced-tool-use, extended-thinking on non-Opus/Sonnet 4 models) + self._filter_unsupported_beta_headers_for_bedrock( + model=model, + beta_set=beta_set, + ) + if beta_set: anthropic_messages_request["anthropic_beta"] = list(beta_set) diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index ab1e735fca..6a87967c3a 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -3080,10 +3080,8 @@ class BaseLLMHTTPHandler: transformed_request, bytes ): # Handle traditional file uploads - # Ensure transformed_request is a string for httpx compatibility - if isinstance(transformed_request, bytes): - transformed_request = transformed_request.decode("utf-8") - + # Note: transformed_request can be bytes (for binary files like PDFs) + # or str (for text files like JSONL). httpx handles both correctly. # Use the HTTP method specified by the provider config http_method = provider_config.file_upload_http_method.upper() if http_method == "PUT": diff --git a/litellm/llms/gemini/image_edit/transformation.py b/litellm/llms/gemini/image_edit/transformation.py index 0015155b47..1654113821 100644 --- a/litellm/llms/gemini/image_edit/transformation.py +++ b/litellm/llms/gemini/image_edit/transformation.py @@ -81,21 +81,23 @@ class GeminiImageEditConfig(BaseImageEditConfig): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict[str, Any], litellm_params: GenericLiteLLMParams, headers: dict, ) -> Tuple[Dict[str, Any], Optional[RequestFiles]]: - inline_parts = self._prepare_inline_image_parts(image) + inline_parts = self._prepare_inline_image_parts(image) if image else [] if not inline_parts: raise ValueError("Gemini image edit requires at least one image.") - if prompt is None: - raise ValueError("Gemini image edit requires a prompt.") + # Build parts list with image and prompt (if provided) + parts = inline_parts.copy() + if prompt is not None and prompt != "": + parts.append({"text": prompt}) contents = [ { - "parts": inline_parts + [{"text": prompt}], + "parts": parts, } ] diff --git a/litellm/llms/openai/image_edit/dalle2_transformation.py b/litellm/llms/openai/image_edit/dalle2_transformation.py index 13531546d2..fd697b210e 100644 --- a/litellm/llms/openai/image_edit/dalle2_transformation.py +++ b/litellm/llms/openai/image_edit/dalle2_transformation.py @@ -31,7 +31,7 @@ class DallE2ImageEditConfig(OpenAIImageEditConfig): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, @@ -40,18 +40,20 @@ class DallE2ImageEditConfig(OpenAIImageEditConfig): Transform image edit request for DALL-E-2. DALL-E-2 only accepts a single image with field name "image" (not "image[]"). - """ - if prompt is None: - raise ValueError("DALL-E-2 image edit requires a prompt.") - - request = ImageEditRequestParams( - model=model, - image=image, - prompt=prompt, + """ + request_params = { + "model": model, **image_edit_optional_request_params, - ) + } + if image is not None: + request_params["image"] = image + if prompt is not None: + request_params["prompt"] = prompt + + request = ImageEditRequestParams(**request_params) request_dict = cast(Dict, request) + ######################################################### # Separate images and masks as `files` and send other parameters as `data` ######################################################### diff --git a/litellm/llms/openai/image_edit/transformation.py b/litellm/llms/openai/image_edit/transformation.py index 9edad9ee2c..a1e5375d09 100644 --- a/litellm/llms/openai/image_edit/transformation.py +++ b/litellm/llms/openai/image_edit/transformation.py @@ -80,7 +80,7 @@ class OpenAIImageEditConfig(BaseImageEditConfig): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, @@ -91,15 +91,17 @@ class OpenAIImageEditConfig(BaseImageEditConfig): Handles multipart/form-data for images. Uses "image[]" field name to support multiple images (e.g., for gpt-image-1). """ - if prompt is None: - raise ValueError("OpenAI image edit requires a prompt.") - - request = ImageEditRequestParams( - model=model, - image=image, - prompt=prompt, + # Build request params, only including non-None values + request_params = { + "model": model, **image_edit_optional_request_params, - ) + } + if image is not None: + request_params["image"] = image + if prompt is not None: + request_params["prompt"] = prompt + + request = ImageEditRequestParams(**request_params) request_dict = cast(Dict, request) ######################################################### diff --git a/litellm/llms/recraft/image_edit/transformation.py b/litellm/llms/recraft/image_edit/transformation.py index 9bf46704ed..d2a5623681 100644 --- a/litellm/llms/recraft/image_edit/transformation.py +++ b/litellm/llms/recraft/image_edit/transformation.py @@ -102,7 +102,7 @@ class RecraftImageEditConfig(BaseImageEditConfig): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, @@ -114,15 +114,15 @@ class RecraftImageEditConfig(BaseImageEditConfig): https://www.recraft.ai/docs#image-to-image """ - if prompt is None: - raise ValueError("Recraft image edit requires a prompt.") - - request_body: RecraftImageEditRequestParams = RecraftImageEditRequestParams( - model=model, - prompt=prompt, - strength=image_edit_optional_request_params.pop("strength", self.DEFAULT_STRENGTH), + request_params = { + "model": model, + "strength": image_edit_optional_request_params.pop("strength", self.DEFAULT_STRENGTH), **image_edit_optional_request_params, - ) + } + if prompt is not None: + request_params["prompt"] = prompt + + request_body = RecraftImageEditRequestParams(**request_params) request_dict = cast(Dict, request_body) ######################################################### # Reuse OpenAI logic: Separate images as `files` and send other parameters as `data` diff --git a/litellm/llms/replicate/chat/handler.py b/litellm/llms/replicate/chat/handler.py index 4c75db5abc..c37473b318 100644 --- a/litellm/llms/replicate/chat/handler.py +++ b/litellm/llms/replicate/chat/handler.py @@ -83,19 +83,27 @@ async def async_handle_prediction_response_streaming( await asyncio.sleep( REPLICATE_POLLING_DELAY_SECONDS ) # prevent being rate limited by replicate - print_verbose(f"replicate: polling endpoint: {prediction_url}") response = await http_client.get(prediction_url, headers=headers) if response.status_code == 200: response_data = response.json() - status = response_data["status"] - if "output" in response_data: + status = response_data.get("status", "") + # Check that "output" exists and is not None or empty + output_present = "output" in response_data and response_data["output"] is not None + if output_present: try: - output_string = "".join(response_data["output"]) + # If output is None or not a list, treat as empty string + if isinstance(response_data["output"], list): + output_string = "".join(response_data["output"]) + elif response_data["output"] is None: + output_string = "" + else: + # fallback for other types; convert to string safely + output_string = str(response_data["output"]) except Exception: raise ReplicateError( status_code=422, message="Unable to parse response. Got={}".format( - response_data["output"] + response_data.get("output", None) ), headers=response.headers, ) @@ -103,7 +111,7 @@ async def async_handle_prediction_response_streaming( print_verbose(f"New chunk: {new_output}") yield {"output": new_output, "status": status} previous_output = output_string - status = response_data["status"] + status = response_data.get("status", "") if status == "failed": replicate_error = response_data.get("error", "") raise ReplicateError( diff --git a/litellm/llms/stability/image_edit/transformations.py b/litellm/llms/stability/image_edit/transformations.py index 013e3f27a0..53bdc825dd 100644 --- a/litellm/llms/stability/image_edit/transformations.py +++ b/litellm/llms/stability/image_edit/transformations.py @@ -171,7 +171,7 @@ class StabilityImageEditConfig(BaseImageEditConfig): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict, litellm_params: GenericLiteLLMParams, headers: dict, @@ -190,11 +190,14 @@ class StabilityImageEditConfig(BaseImageEditConfig): } # Add prompt only if provided (some Stability endpoints don't require it) - if prompt is not None: + if prompt is not None and prompt != "": data["prompt"] = prompt # Handle image parameter - could be a single file or list image_file = image[0] if isinstance(image, list) else image # type: ignore - files: Dict[str, Any] = {"image": image_file} + files: Dict[str, Any] = {} + if image is not None: + image_file = image[0] if isinstance(image, list) else image # type: ignore + files["image"] = image_file # Add optional params (already mapped in map_openai_params) for key, value in image_edit_optional_request_params.items(): # type: ignore diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index 5aa7662f17..5b09bcc028 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -453,9 +453,10 @@ def _build_vertex_schema(parameters: dict, add_property_ordering: bool = False): valid_schema_fields = set(get_type_hints(Schema).keys()) defs = parameters.pop("$defs", {}) - # flatten the defs - for name, value in defs.items(): - unpack_defs(value, defs) + # Expand $ref references in parameters using the definitions + # Note: We don't pre-flatten defs as that causes exponential memory growth + # with circular references (see issue #19098). unpack_defs handles nested + # refs recursively and correctly detects/skips circular references. unpack_defs(parameters, defs) # 5. Nullable fields: diff --git a/litellm/llms/vertex_ai/image_edit/vertex_gemini_transformation.py b/litellm/llms/vertex_ai/image_edit/vertex_gemini_transformation.py index 154d5669eb..8fcd285824 100644 --- a/litellm/llms/vertex_ai/image_edit/vertex_gemini_transformation.py +++ b/litellm/llms/vertex_ai/image_edit/vertex_gemini_transformation.py @@ -152,22 +152,24 @@ class VertexAIGeminiImageEditConfig(BaseImageEditConfig, VertexLLM): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict[str, Any], litellm_params: GenericLiteLLMParams, headers: dict, ) -> Tuple[Dict[str, Any], Optional[RequestFiles]]: - inline_parts = self._prepare_inline_image_parts(image) + inline_parts = self._prepare_inline_image_parts(image) if image else [] if not inline_parts: raise ValueError("Vertex AI Gemini image edit requires at least one image.") - if prompt is None: - raise ValueError("Vertex AI Gemini image edit requires a prompt.") + # Build parts list with image and prompt (if provided) + parts = inline_parts.copy() + if prompt is not None and prompt != "": + parts.append({"text": prompt}) # Correct format for Vertex AI Gemini image editing contents = { "role": "USER", - "parts": inline_parts + [{"text": prompt}] + "parts": parts } request_body: Dict[str, Any] = {"contents": contents} diff --git a/litellm/llms/vertex_ai/image_edit/vertex_imagen_transformation.py b/litellm/llms/vertex_ai/image_edit/vertex_imagen_transformation.py index 337a4bd4dd..b58825e1fa 100644 --- a/litellm/llms/vertex_ai/image_edit/vertex_imagen_transformation.py +++ b/litellm/llms/vertex_ai/image_edit/vertex_imagen_transformation.py @@ -144,7 +144,7 @@ class VertexAIImagenImageEditConfig(BaseImageEditConfig, VertexLLM): self, model: str, prompt: Optional[str], - image: FileTypes, + image: Optional[FileTypes], image_edit_optional_request_params: Dict[str, Any], litellm_params: GenericLiteLLMParams, headers: dict, diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 470d598a25..87f3566ae8 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -7857,6 +7857,24 @@ "supports_tool_choice": true, "supports_vision": true }, + "dall-e-2": { + "input_cost_per_image": 0.02, + "litellm_provider": "openai", + "mode": "image_generation", + "supported_endpoints": [ + "/v1/images/generations", + "/v1/images/edits", + "/v1/images/variations" + ] + }, + "dall-e-3": { + "input_cost_per_image": 0.04, + "litellm_provider": "openai", + "mode": "image_generation", + "supported_endpoints": [ + "/v1/images/generations" + ] + }, "deepseek-chat": { "cache_read_input_token_cost": 2.8e-08, "input_cost_per_token": 2.8e-07, @@ -18808,13 +18826,14 @@ "supports_tool_choice": true }, "groq/openai/gpt-oss-120b": { + "cache_read_input_token_cost": 7.5e-08, "input_cost_per_token": 1.5e-07, "litellm_provider": "groq", "max_input_tokens": 131072, "max_output_tokens": 32766, "max_tokens": 32766, "mode": "chat", - "output_cost_per_token": 7.5e-07, + "output_cost_per_token": 6e-07, "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_reasoning": true, @@ -18823,13 +18842,14 @@ "supports_web_search": true }, "groq/openai/gpt-oss-20b": { - "input_cost_per_token": 1e-07, + "cache_read_input_token_cost": 3.75e-08, + "input_cost_per_token": 7.5e-08, "litellm_provider": "groq", "max_input_tokens": 131072, "max_output_tokens": 32768, "max_tokens": 32768, "mode": "chat", - "output_cost_per_token": 5e-07, + "output_cost_per_token": 3e-07, "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_reasoning": true, diff --git a/litellm/proxy/agent_endpoints/a2a_endpoints.py b/litellm/proxy/agent_endpoints/a2a_endpoints.py index c2d53b40b7..a21f3291d3 100644 --- a/litellm/proxy/agent_endpoints/a2a_endpoints.py +++ b/litellm/proxy/agent_endpoints/a2a_endpoints.py @@ -55,9 +55,30 @@ async def _handle_stream_message( proxy_server_request: Optional[dict] = None, ) -> StreamingResponse: """Handle message/stream method via SDK functions.""" - from a2a.types import MessageSendParams, SendStreamingMessageRequest - from litellm.a2a_protocol import asend_message_streaming + from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE + + # Check is handled in invoke_agent_a2a, but if called directly: + if not A2A_SDK_AVAILABLE: + # Return a streaming response that yields an error + async def _error_stream(): + yield json.dumps( + { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32603, + "message": "Server error: 'a2a' package not installed", + }, + } + ) + "\n" + + return StreamingResponse(_error_stream(), media_type="application/x-ndjson") + + from a2a.types import ( + MessageSendParams, + SendStreamingMessageRequest, + ) async def stream_response(): try: @@ -75,16 +96,20 @@ async def _handle_stream_message( ): # Chunk may be dict or object depending on bridge vs standard path if hasattr(chunk, "model_dump"): - yield json.dumps(chunk.model_dump(mode="json", exclude_none=True)) + "\n" + yield json.dumps( + chunk.model_dump(mode="json", exclude_none=True) + ) + "\n" else: yield json.dumps(chunk) + "\n" except Exception as e: verbose_proxy_logger.exception(f"Error streaming A2A response: {e}") - yield json.dumps({ - "jsonrpc": "2.0", - "id": request_id, - "error": {"code": -32603, "message": f"Streaming error: {str(e)}"}, - }) + "\n" + yield json.dumps( + { + "jsonrpc": "2.0", + "id": request_id, + "error": {"code": -32603, "message": f"Streaming error: {str(e)}"}, + } + ) + "\n" return StreamingResponse(stream_response(), media_type="application/x-ndjson") @@ -169,9 +194,8 @@ async def invoke_agent_a2a( - message/send: Send a message and get a response - message/stream: Send a message and stream the response """ - from a2a.types import MessageSendParams, SendMessageRequest - from litellm.a2a_protocol import asend_message + from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE from litellm.proxy.agent_endpoints.auth.agent_permission_handler import ( AgentRequestHandler, ) @@ -189,16 +213,28 @@ async def invoke_agent_a2a( # Validate JSON-RPC format if body.get("jsonrpc") != "2.0": - return _jsonrpc_error(body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'") + return _jsonrpc_error( + body.get("id"), -32600, "Invalid Request: jsonrpc must be '2.0'" + ) request_id = body.get("id") method = body.get("method") params = body.get("params", {}) + if not A2A_SDK_AVAILABLE: + return _jsonrpc_error( + request_id, + -32603, + "Server error: 'a2a' package not installed. Please install 'a2a-sdk'.", + 500, + ) + # Find the agent agent = _get_agent(agent_id) if agent is None: - return _jsonrpc_error(request_id, -32000, f"Agent '{agent_id}' not found", 404) + return _jsonrpc_error( + request_id, -32000, f"Agent '{agent_id}' not found", 404 + ) is_allowed = await AgentRequestHandler.is_agent_allowed( agent_id=agent.agent_id, @@ -213,23 +249,29 @@ async def invoke_agent_a2a( # Get backend URL and agent name agent_url = agent.agent_card_params.get("url") agent_name = agent.agent_card_params.get("name", agent_id) - + # Get litellm_params (may include custom_llm_provider for completion bridge) litellm_params = agent.litellm_params or {} custom_llm_provider = litellm_params.get("custom_llm_provider") - + # URL is required unless using completion bridge with a provider that derives endpoint from model # (e.g., bedrock/agentcore derives endpoint from ARN in model string) if not agent_url and not custom_llm_provider: - return _jsonrpc_error(request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500) + return _jsonrpc_error( + request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500 + ) - verbose_proxy_logger.info(f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}") + verbose_proxy_logger.info( + f"Proxying A2A request to agent '{agent_id}' at {agent_url or 'completion-bridge'}" + ) # Set up data dict for litellm processing - body.update({ - "model": f"a2a_agent/{agent_name}", - "custom_llm_provider": "a2a_agent", - }) + body.update( + { + "model": f"a2a_agent/{agent_name}", + "custom_llm_provider": "a2a_agent", + } + ) # Add litellm data (user_api_key, user_id, team_id, etc.) data = await add_litellm_data_to_request( @@ -243,6 +285,8 @@ async def invoke_agent_a2a( # Route through SDK functions if method == "message/send": + from a2a.types import MessageSendParams, SendMessageRequest + a2a_request = SendMessageRequest( id=request_id, params=MessageSendParams(**params), @@ -255,7 +299,9 @@ async def invoke_agent_a2a( metadata=data.get("metadata", {}), proxy_server_request=data.get("proxy_server_request"), ) - return JSONResponse(content=response.model_dump(mode="json", exclude_none=True)) + return JSONResponse( + content=response.model_dump(mode="json", exclude_none=True) + ) elif method == "message/stream": return await _handle_stream_message( diff --git a/litellm/proxy/anthropic_endpoints/claude_code_endpoints/__init__.py b/litellm/proxy/anthropic_endpoints/claude_code_endpoints/__init__.py new file mode 100644 index 0000000000..0d1a5a2083 --- /dev/null +++ b/litellm/proxy/anthropic_endpoints/claude_code_endpoints/__init__.py @@ -0,0 +1,11 @@ +""" +Claude Code Endpoints + +Provides endpoints for Claude Code plugin marketplace integration. +""" + +from litellm.proxy.anthropic_endpoints.claude_code_endpoints.claude_code_marketplace import ( + router as claude_code_marketplace_router, +) + +__all__ = ["claude_code_marketplace_router"] diff --git a/litellm/proxy/anthropic_endpoints/claude_code_endpoints/claude_code_marketplace.py b/litellm/proxy/anthropic_endpoints/claude_code_endpoints/claude_code_marketplace.py new file mode 100644 index 0000000000..7c212020a3 --- /dev/null +++ b/litellm/proxy/anthropic_endpoints/claude_code_endpoints/claude_code_marketplace.py @@ -0,0 +1,533 @@ +""" +CLAUDE CODE MARKETPLACE + +Provides a registry/discovery layer for Claude Code plugins. +Plugins are stored as metadata + git source references in LiteLLM database. +Actual plugin files are hosted on GitHub/GitLab/Bitbucket. + +Endpoints: +/claude-code/marketplace.json - GET - List plugins for Claude Code discovery +/claude-code/plugins - POST - Register a plugin +/claude-code/plugins - GET - List plugins (admin) +/claude-code/plugins/{name} - GET - Get plugin details +/claude-code/plugins/{name}/enable - POST - Enable a plugin +/claude-code/plugins/{name}/disable - POST - Disable a plugin +/claude-code/plugins/{name} - DELETE - Delete a plugin +""" + +import json +import re +from datetime import datetime, timezone +from typing import Any, Dict + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.types.proxy.claude_code_endpoints import ( + ListPluginsResponse, + PluginListItem, + RegisterPluginRequest, +) + +router = APIRouter() + + +async def _get_prisma_client(): + """Get the prisma client from proxy_server.""" + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + return prisma_client + + +@router.get( + "/claude-code/marketplace.json", + tags=["Claude Code Marketplace"], +) +async def get_marketplace(): + """ + Serve marketplace.json for Claude Code plugin discovery. + + This endpoint is accessed by Claude Code CLI when users run: + - claude plugin marketplace add + - claude plugin install @ + + Returns: + Marketplace catalog with list of available plugins and their git sources. + + Example: + ```bash + claude plugin marketplace add http://localhost:4000/claude-code/marketplace.json + claude plugin install my-plugin@litellm + ``` + """ + try: + prisma_client = await _get_prisma_client() + + plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many( + where={"enabled": True} + ) + + plugin_list = [] + for plugin in plugins: + try: + manifest = json.loads(plugin.manifest_json) + except json.JSONDecodeError: + verbose_proxy_logger.warning( + f"Plugin {plugin.name} has invalid manifest JSON, skipping" + ) + continue + + # Source must be specified for URL-based marketplaces + if "source" not in manifest: + verbose_proxy_logger.warning( + f"Plugin {plugin.name} has no source field, skipping" + ) + continue + + entry: Dict[str, Any] = { + "name": plugin.name, + "source": manifest["source"], + } + + if plugin.version: + entry["version"] = plugin.version + if plugin.description: + entry["description"] = plugin.description + if "author" in manifest: + entry["author"] = manifest["author"] + if "homepage" in manifest: + entry["homepage"] = manifest["homepage"] + if "keywords" in manifest: + entry["keywords"] = manifest["keywords"] + if "category" in manifest: + entry["category"] = manifest["category"] + + plugin_list.append(entry) + + marketplace = { + "name": "litellm", + "owner": {"name": "LiteLLM", "email": "support@litellm.ai"}, + "plugins": plugin_list, + } + + return JSONResponse(content=marketplace) + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error generating marketplace: {e}") + raise HTTPException( + status_code=500, + detail={"error": f"Failed to generate marketplace: {str(e)}"}, + ) + + +@router.post( + "/claude-code/plugins", + tags=["Claude Code Marketplace"], + dependencies=[Depends(user_api_key_auth)], +) +async def register_plugin( + request: RegisterPluginRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Register a plugin in the LiteLLM marketplace. + + LiteLLM acts as a registry/discovery layer. Plugins are hosted on + GitHub/GitLab/Bitbucket. Claude Code will clone from the git source + when users install. + + Parameters: + - name: Plugin name (kebab-case) + - source: Git source reference (github or url format) + - version: Semantic version (optional) + - description: Plugin description (optional) + - author: Author information (optional) + - homepage: Plugin homepage URL (optional) + - keywords: Search keywords (optional) + - category: Plugin category (optional) + + Returns: + Registration status and plugin information. + + Example: + ```bash + curl -X POST http://localhost:4000/claude-code/plugins \\ + -H "Authorization: Bearer sk-..." \\ + -H "Content-Type: application/json" \\ + -d '{ + "name": "my-plugin", + "source": {"source": "github", "repo": "org/my-plugin"}, + "version": "1.0.0", + "description": "My awesome plugin" + }' + ``` + """ + try: + prisma_client = await _get_prisma_client() + + # Validate name format + if not re.match(r"^[a-z0-9-]+$", request.name): + raise HTTPException( + status_code=400, + detail={ + "error": "Plugin name must be kebab-case (lowercase letters, numbers, hyphens)" + }, + ) + + # Validate source format + source = request.source + source_type = source.get("source") + + if source_type == "github": + if "repo" not in source: + raise HTTPException( + status_code=400, + detail={ + "error": "GitHub source must include 'repo' field (e.g., 'org/repo')" + }, + ) + elif source_type == "url": + if "url" not in source: + raise HTTPException( + status_code=400, + detail={ + "error": "URL source must include 'url' field (e.g., 'https://github.com/org/repo.git')" + }, + ) + else: + raise HTTPException( + status_code=400, + detail={"error": "source.source must be 'github' or 'url'"}, + ) + + # Build manifest for storage + manifest: Dict[str, Any] = { + "name": request.name, + "source": request.source, + } + if request.version: + manifest["version"] = request.version + if request.description: + manifest["description"] = request.description + if request.author: + manifest["author"] = request.author.model_dump(exclude_none=True) + if request.homepage: + manifest["homepage"] = request.homepage + if request.keywords: + manifest["keywords"] = request.keywords + if request.category: + manifest["category"] = request.category + + # Check if plugin exists + existing = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + where={"name": request.name} + ) + + if existing: + plugin = await prisma_client.db.litellm_claudecodeplugintable.update( + where={"name": request.name}, + data={ + "version": request.version, + "description": request.description, + "manifest_json": json.dumps(manifest), + "files_json": "{}", + "updated_at": datetime.now(timezone.utc), + }, + ) + action = "updated" + else: + plugin = await prisma_client.db.litellm_claudecodeplugintable.create( + data={ + "name": request.name, + "version": request.version, + "description": request.description, + "manifest_json": json.dumps(manifest), + "files_json": "{}", + "enabled": True, + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + "created_by": user_api_key_dict.user_id, + } + ) + action = "created" + + verbose_proxy_logger.info(f"Plugin {request.name} {action} successfully") + + return { + "status": "success", + "action": action, + "plugin": { + "id": plugin.id, + "name": plugin.name, + "version": plugin.version, + "description": plugin.description, + "source": request.source, + "enabled": plugin.enabled, + }, + } + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error registering plugin: {e}") + raise HTTPException( + status_code=500, + detail={"error": f"Registration failed: {str(e)}"}, + ) + + +@router.get( + "/claude-code/plugins", + tags=["Claude Code Marketplace"], + dependencies=[Depends(user_api_key_auth)], + response_model=ListPluginsResponse, +) +async def list_plugins( + enabled_only: bool = False, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + List all plugins in the marketplace. + + Parameters: + - enabled_only: If true, only return enabled plugins + + Returns: + List of plugins with their metadata. + """ + try: + prisma_client = await _get_prisma_client() + + where = {"enabled": True} if enabled_only else {} + plugins = await prisma_client.db.litellm_claudecodeplugintable.find_many( + where=where, + order_by={"created_at": "desc"}, + ) + + return ListPluginsResponse( + plugins=[ + PluginListItem( + id=p.id, + name=p.name, + version=p.version, + description=p.description, + enabled=p.enabled, + created_at=p.created_at.isoformat() if p.created_at else None, + updated_at=p.updated_at.isoformat() if p.updated_at else None, + ) + for p in plugins + ], + count=len(plugins), + ) + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error listing plugins: {e}") + raise HTTPException( + status_code=500, + detail={"error": str(e)}, + ) + + +@router.get( + "/claude-code/plugins/{plugin_name}", + tags=["Claude Code Marketplace"], + dependencies=[Depends(user_api_key_auth)], +) +async def get_plugin( + plugin_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get details of a specific plugin. + + Parameters: + - plugin_name: The name of the plugin + + Returns: + Plugin details including source and metadata. + """ + try: + prisma_client = await _get_prisma_client() + + plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + where={"name": plugin_name} + ) + + if not plugin: + raise HTTPException( + status_code=404, + detail={"error": f"Plugin '{plugin_name}' not found"}, + ) + + manifest = json.loads(plugin.manifest_json) if plugin.manifest_json else {} + + return { + "id": plugin.id, + "name": plugin.name, + "version": plugin.version, + "description": plugin.description, + "source": manifest.get("source"), + "author": manifest.get("author"), + "homepage": manifest.get("homepage"), + "keywords": manifest.get("keywords"), + "category": manifest.get("category"), + "enabled": plugin.enabled, + "created_at": plugin.created_at.isoformat() if plugin.created_at else None, + "updated_at": plugin.updated_at.isoformat() if plugin.updated_at else None, + "created_by": plugin.created_by, + } + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error getting plugin: {e}") + raise HTTPException( + status_code=500, + detail={"error": str(e)}, + ) + + +@router.post( + "/claude-code/plugins/{plugin_name}/enable", + tags=["Claude Code Marketplace"], + dependencies=[Depends(user_api_key_auth)], +) +async def enable_plugin( + plugin_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Enable a disabled plugin. + + Parameters: + - plugin_name: The name of the plugin to enable + """ + try: + prisma_client = await _get_prisma_client() + + plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + where={"name": plugin_name} + ) + if not plugin: + raise HTTPException( + status_code=404, + detail={"error": f"Plugin '{plugin_name}' not found"}, + ) + + await prisma_client.db.litellm_claudecodeplugintable.update( + where={"name": plugin_name}, + data={"enabled": True, "updated_at": datetime.now(timezone.utc)}, + ) + + verbose_proxy_logger.info(f"Plugin {plugin_name} enabled") + return {"status": "success", "message": f"Plugin '{plugin_name}' enabled"} + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error enabling plugin: {e}") + raise HTTPException( + status_code=500, + detail={"error": str(e)}, + ) + + +@router.post( + "/claude-code/plugins/{plugin_name}/disable", + tags=["Claude Code Marketplace"], + dependencies=[Depends(user_api_key_auth)], +) +async def disable_plugin( + plugin_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Disable a plugin without deleting it. + + Parameters: + - plugin_name: The name of the plugin to disable + """ + try: + prisma_client = await _get_prisma_client() + + plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + where={"name": plugin_name} + ) + if not plugin: + raise HTTPException( + status_code=404, + detail={"error": f"Plugin '{plugin_name}' not found"}, + ) + + await prisma_client.db.litellm_claudecodeplugintable.update( + where={"name": plugin_name}, + data={"enabled": False, "updated_at": datetime.now(timezone.utc)}, + ) + + verbose_proxy_logger.info(f"Plugin {plugin_name} disabled") + return {"status": "success", "message": f"Plugin '{plugin_name}' disabled"} + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error disabling plugin: {e}") + raise HTTPException( + status_code=500, + detail={"error": str(e)}, + ) + + +@router.delete( + "/claude-code/plugins/{plugin_name}", + tags=["Claude Code Marketplace"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_plugin( + plugin_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Delete a plugin from the marketplace. + + Parameters: + - plugin_name: The name of the plugin to delete + """ + try: + prisma_client = await _get_prisma_client() + + plugin = await prisma_client.db.litellm_claudecodeplugintable.find_unique( + where={"name": plugin_name} + ) + if not plugin: + raise HTTPException( + status_code=404, + detail={"error": f"Plugin '{plugin_name}' not found"}, + ) + + await prisma_client.db.litellm_claudecodeplugintable.delete( + where={"name": plugin_name} + ) + + verbose_proxy_logger.info(f"Plugin {plugin_name} deleted") + return {"status": "success", "message": f"Plugin '{plugin_name}' deleted"} + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error deleting plugin: {e}") + raise HTTPException( + status_code=500, + detail={"error": str(e)}, + ) diff --git a/litellm/proxy/image_endpoints/endpoints.py b/litellm/proxy/image_endpoints/endpoints.py index a1453e10db..4a2c05f859 100644 --- a/litellm/proxy/image_endpoints/endpoints.py +++ b/litellm/proxy/image_endpoints/endpoints.py @@ -244,8 +244,10 @@ async def image_edit_api( if mask is None and mask_array is not None: mask = mask_array - if image is None: - raise HTTPException(status_code=422, detail="Field required: image") + # if image is None: + # raise HTTPException(status_code=422, detail="Field required: image") + # Note: Image is optional for some models (e.g., Bedrock Stability style-transfer) + # The validation will be done at the model level if image is truly required from litellm.proxy.proxy_server import ( _read_request_body, @@ -272,6 +274,10 @@ async def image_edit_api( data["image"] = image_files if mask_files: data["mask"] = mask_files + + # Ensure prompt exists in data (default to None for models that don't require it) + if "prompt" not in data: + data["prompt"] = None data["model"] = ( model diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py index 7e3f582081..2c6b378ae3 100644 --- a/litellm/proxy/openai_files_endpoints/files_endpoints.py +++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py @@ -146,8 +146,8 @@ async def route_create_file( Priority: 1. If target_storage is specified and not "default" -> use storage backend 2. If model parameter provided -> use model credentials and encode ID - 3. If enable_loadbalancing_on_batch_endpoints -> deprecated loadbalancing - 4. If target_model_names_list -> managed files (requires DB) + 3. If target_model_names_list -> managed files (requires DB, supports loadbalancing) + 4. If enable_loadbalancing_on_batch_endpoints -> deprecated loadbalancing 5. Else -> use custom_llm_provider with files_settings """ @@ -202,18 +202,9 @@ async def route_create_file( return response - # EXISTING: Deprecated loadbalancing approach - if ( - litellm.enable_loadbalancing_on_batch_endpoints is True - and is_router_model - and router_model is not None - ): - response = await _deprecated_loadbalanced_create_file( - llm_router=llm_router, - router_model=router_model, - _create_file_request=_create_file_request, - ) - elif target_model_names_list: + # Handle managed files (supports loadbalancing via llm_router.acreate_file) + # Priority: Check for managed files BEFORE deprecated loadbalancing + if target_model_names_list: managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files") if managed_files_obj is None: raise ProxyException( @@ -236,6 +227,7 @@ async def route_create_file( param="None", code=500, ) + # Managed files internally calls llm_router.acreate_file() which includes loadbalancing response = await managed_files_obj.acreate_file( llm_router=llm_router, create_file_request=_create_file_request, @@ -243,6 +235,17 @@ async def route_create_file( litellm_parent_otel_span=user_api_key_dict.parent_otel_span, user_api_key_dict=user_api_key_dict, ) + # EXISTING: Deprecated loadbalancing approach (for backwards compatibility when not using managed files) + elif ( + litellm.enable_loadbalancing_on_batch_endpoints is True + and is_router_model + and router_model is not None + ): + response = await _deprecated_loadbalanced_create_file( + llm_router=llm_router, + router_model=router_model, + _create_file_request=_create_file_request, + ) else: # get configs for custom_llm_provider llm_provider_config = get_files_provider_config( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index cf852805f8..646a062b72 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -2,7 +2,7 @@ model_list: - model_name: gemini/* litellm_params: model: gemini/* - - model_name: claude-sonnet-4-5-20250929 + - model_name: -claude-sonnet-4-5-20250929 litellm_params: model: bedrock/invoke/us.anthropic.claude-sonnet-4-5-20250929-v1:0 model_info: @@ -40,7 +40,7 @@ model_list: model_info: litellm_provider: bedrock_converse mode: chat - - model_name: azure-claude-opus-4-5 + - model_name: claude-sonnet-4-5-20250929 litellm_params: model: azure_ai/claude-opus-4-5 api_base: https://krish-mh44t553-eastus2.services.ai.azure.com diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f4e68c481a..c3a4de314e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -207,6 +207,9 @@ from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_rout from litellm.proxy.anthropic_endpoints.skills_endpoints import ( router as anthropic_skills_router, ) +from litellm.proxy.anthropic_endpoints.claude_code_endpoints import ( + claude_code_marketplace_router, +) from litellm.proxy.auth.auth_checks import ( ExperimentalUIJWTToken, get_team_object, @@ -10499,6 +10502,7 @@ app.include_router(llm_passthrough_router) app.include_router(mcp_management_router) app.include_router(anthropic_router) app.include_router(anthropic_skills_router) +app.include_router(claude_code_marketplace_router) app.include_router(google_router) app.include_router(langfuse_router) app.include_router(pass_through_router) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 71b398c59a..22888f6d3a 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -863,3 +863,20 @@ model LiteLLM_SkillsTable { updated_at DateTime @default(now()) @updatedAt updated_by String? } + +// Claude Code Marketplace - stores plugins for Claude Code integration +model LiteLLM_ClaudeCodePluginTable { + id String @id @default(uuid()) + name String @unique // Plugin name (kebab-case) + version String? // Semantic version + description String? // Plugin description + manifest_json String // Full plugin.json as JSON string + files_json String // All files as JSON: {"path": "content"} + enabled Boolean @default(true) + created_at DateTime @default(now()) + updated_at DateTime @default(now()) @updatedAt + created_by String? + + @@index([name]) + @@map("litellm_claudecodeplugin") +} diff --git a/litellm/proxy/vector_store_endpoints/management_endpoints.py b/litellm/proxy/vector_store_endpoints/management_endpoints.py index 661f94e5f0..bc61a60fe5 100644 --- a/litellm/proxy/vector_store_endpoints/management_endpoints.py +++ b/litellm/proxy/vector_store_endpoints/management_endpoints.py @@ -245,6 +245,7 @@ async def list_vector_stores( """ List all available vector stores with optional filtering and pagination. Combines both in-memory vector stores and those stored in the database. + Database is the source of truth - deleted stores are removed from memory, updated stores sync to memory. Parameters: - page: int - Page number for pagination (default: 1) @@ -252,29 +253,65 @@ async def list_vector_stores( """ from litellm.proxy.proxy_server import prisma_client - seen_vector_store_ids = set() + vector_store_map: Dict[str, LiteLLM_ManagedVectorStore] = {} + db_vector_store_ids: set = set() try: - # Get in-memory vector stores - in_memory_vector_stores: List[LiteLLM_ManagedVectorStore] = [] + # Get vector stores from database first (source of truth) + vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db( + prisma_client=prisma_client + ) + + # Build map from database vector stores + for vector_store in vector_stores_from_db: + vector_store_id = vector_store.get("vector_store_id", None) + if vector_store_id: + vector_store_map[vector_store_id] = vector_store + db_vector_store_ids.add(vector_store_id) + + # Process in-memory vector stores if litellm.vector_store_registry is not None: in_memory_vector_stores = copy.deepcopy( litellm.vector_store_registry.vector_stores ) + + vector_stores_to_delete_from_memory: List[str] = [] + + for vector_store in in_memory_vector_stores: + vector_store_id = vector_store.get("vector_store_id", None) + if not vector_store_id: + continue + + # If vector store is in memory but NOT in database, it was deleted + if vector_store_id not in db_vector_store_ids: + verbose_proxy_logger.info( + f"Vector store {vector_store_id} exists in memory but not in database - marking for deletion from cache" + ) + vector_stores_to_delete_from_memory.append(vector_store_id) + # If not in our map yet, add it (only in-memory, not in DB) + elif vector_store_id not in vector_store_map: + vector_store_map[vector_store_id] = vector_store + + # Synchronize in-memory registry with database + # 1. Remove deleted vector stores from memory + for vs_id in vector_stores_to_delete_from_memory: + litellm.vector_store_registry.delete_vector_store_from_registry( + vector_store_id=vs_id + ) + verbose_proxy_logger.debug( + f"Removed deleted vector store {vs_id} from in-memory registry" + ) + + # 2. Update in-memory registry with database versions (for updates) + for vector_store in vector_stores_from_db: + vector_store_id = vector_store.get("vector_store_id", None) + if vector_store_id: + litellm.vector_store_registry.update_vector_store_in_registry( + vector_store_id=vector_store_id, + updated_data=vector_store + ) - # Get vector stores from database - vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db( - prisma_client=prisma_client - ) - - # Combine in-memory and database vector stores - combined_vector_stores: List[LiteLLM_ManagedVectorStore] = [] - for vector_store in in_memory_vector_stores + vector_stores_from_db: - vector_store_id = vector_store.get("vector_store_id", None) - if vector_store_id not in seen_vector_store_ids: - combined_vector_stores.append(vector_store) - seen_vector_store_ids.add(vector_store_id) - + combined_vector_stores = list(vector_store_map.values()) total_count = len(combined_vector_stores) total_pages = (total_count + page_size - 1) // page_size @@ -303,7 +340,7 @@ async def delete_vector_store( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - Delete a vector store. + Delete a vector store from both database and in-memory registry. Parameters: - vector_store_id: str - ID of the vector store to delete @@ -314,31 +351,53 @@ async def delete_vector_store( raise HTTPException(status_code=500, detail="Database not connected") try: - # Check if vector store exists + # Check if vector store exists in database or in-memory registry + db_vector_store_exists = False + memory_vector_store_exists = False + existing_vector_store = ( await prisma_client.db.litellm_managedvectorstorestable.find_unique( where={"vector_store_id": data.vector_store_id} ) ) - if existing_vector_store is None: + if existing_vector_store is not None: + db_vector_store_exists = True + + # Check in-memory registry + if litellm.vector_store_registry is not None: + memory_vector_store = litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry( + vector_store_id=data.vector_store_id + ) + if memory_vector_store is not None: + memory_vector_store_exists = True + + # If not found in either location, raise 404 + if not db_vector_store_exists and not memory_vector_store_exists: raise HTTPException( status_code=404, detail=f"Vector store with ID {data.vector_store_id} not found", ) - # Delete vector store - await prisma_client.db.litellm_managedvectorstorestable.delete( - where={"vector_store_id": data.vector_store_id} - ) + # Delete from database if exists + if db_vector_store_exists: + await prisma_client.db.litellm_managedvectorstorestable.delete( + where={"vector_store_id": data.vector_store_id} + ) - # Delete vector store from registry - if litellm.vector_store_registry is not None: + # Delete from in-memory registry if exists + if memory_vector_store_exists and litellm.vector_store_registry is not None: litellm.vector_store_registry.delete_vector_store_from_registry( vector_store_id=data.vector_store_id ) - return {"message": f"Vector store {data.vector_store_id} deleted successfully"} + return { + "status": "success", + "message": f"Vector store {data.vector_store_id} deleted successfully" + } + except HTTPException: + raise except Exception as e: + verbose_proxy_logger.exception(f"Error deleting vector store: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @@ -415,8 +474,12 @@ async def update_vector_store( data: VectorStoreUpdateRequest, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): - """Update vector store details""" + """ + Update vector store details in both database and in-memory registry. + The updated data is immediately synchronized to the in-memory registry. + """ from litellm.proxy.proxy_server import prisma_client + from litellm.types.router import GenericLiteLLMParams if prisma_client is None: raise HTTPException(status_code=500, detail="Database not connected") @@ -424,11 +487,36 @@ async def update_vector_store( try: update_data = data.model_dump(exclude_unset=True) vector_store_id = update_data.pop("vector_store_id") + + # Handle metadata serialization if update_data.get("vector_store_metadata") is not None: update_data["vector_store_metadata"] = safe_dumps( update_data["vector_store_metadata"] ) + + # Handle litellm_params if provided + if "litellm_params" in update_data: + _input_litellm_params: dict = update_data.get("litellm_params", {}) or {} + + # Auto-resolve embedding config if embedding model is provided but config is not + embedding_model = _input_litellm_params.get("litellm_embedding_model") + if embedding_model and not _input_litellm_params.get("litellm_embedding_config"): + resolved_config = await _resolve_embedding_config_from_db( + embedding_model=embedding_model, + prisma_client=prisma_client + ) + if resolved_config: + _input_litellm_params["litellm_embedding_config"] = resolved_config + verbose_proxy_logger.info( + f"Auto-resolved embedding config for model {embedding_model}" + ) + + litellm_params_dict = GenericLiteLLMParams( + **_input_litellm_params + ).model_dump(exclude_none=True) + update_data["litellm_params"] = safe_dumps(litellm_params_dict) + # Update in database updated = await prisma_client.db.litellm_managedvectorstorestable.update( where={"vector_store_id": vector_store_id}, data=update_data, @@ -436,13 +524,21 @@ async def update_vector_store( updated_vs = LiteLLM_ManagedVectorStore(**updated.model_dump()) + # Immediately update in-memory registry to keep it in sync if litellm.vector_store_registry is not None: litellm.vector_store_registry.update_vector_store_in_registry( vector_store_id=vector_store_id, updated_data=updated_vs, ) + verbose_proxy_logger.debug( + f"Updated vector store {vector_store_id} in both database and in-memory registry" + ) - return {"vector_store": updated_vs} + return { + "status": "success", + "message": f"Vector store {vector_store_id} updated successfully", + "vector_store": updated_vs + } except Exception as e: verbose_proxy_logger.exception(f"Error updating vector store: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/litellm/responses/streaming_iterator.py b/litellm/responses/streaming_iterator.py index 0b838f916e..7abd4f90f2 100644 --- a/litellm/responses/streaming_iterator.py +++ b/litellm/responses/streaming_iterator.py @@ -55,6 +55,7 @@ class BaseResponsesAPIStreamingIterator: self.responses_api_provider_config = responses_api_provider_config self.completed_response: Optional[ResponsesAPIStreamingResponse] = None self.start_time = getattr(logging_obj, "start_time", datetime.now()) + self._failure_handled = False # Track if failure handler has been called # track request context for hooks self.litellm_metadata = litellm_metadata @@ -169,7 +170,8 @@ class BaseResponsesAPIStreamingIterator: # If we can't parse the chunk, continue return None except Exception as e: - # Ensure failures trigger failure hooks + # Trigger failure hooks before re-raising + # This ensures failures are logged even when _process_chunk is called directly self._handle_failure(e) raise @@ -287,7 +289,13 @@ class BaseResponsesAPIStreamingIterator: def _handle_failure(self, exception: Exception): """ Trigger failure handlers before bubbling the exception. + Only calls handlers once even if called multiple times. """ + # Prevent double-calling failure handlers + if self._failure_handled: + return + self._failure_handled = True + traceback_exception = traceback.format_exc() try: run_async_function( @@ -383,11 +391,20 @@ class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator): def _handle_logging_completed_response(self): """Handle logging for completed responses in async context""" - # Create a deep copy for logging to avoid modifying the response object that will be returned to the user + # Create a copy for logging to avoid modifying the response object that will be returned to the user # The logging handlers may transform usage from Responses API format (input_tokens/output_tokens) # to chat completion format (prompt_tokens/completion_tokens) for internal logging - import copy - logging_response = copy.deepcopy(self.completed_response) + # Use model_dump + model_validate instead of deepcopy to avoid pickle errors with + # Pydantic ValidatorIterator when response contains tool_choice with allowed_tools (fixes #17192) + logging_response = self.completed_response + if self.completed_response is not None and hasattr(self.completed_response, 'model_dump'): + try: + logging_response = type(self.completed_response).model_validate( + self.completed_response.model_dump() + ) + except Exception: + # Fallback to original if serialization fails + pass asyncio.create_task( self.logging_obj.async_success_handler( @@ -469,11 +486,20 @@ class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator): def _handle_logging_completed_response(self): """Handle logging for completed responses in sync context""" - # Create a deep copy for logging to avoid modifying the response object that will be returned to the user + # Create a copy for logging to avoid modifying the response object that will be returned to the user # The logging handlers may transform usage from Responses API format (input_tokens/output_tokens) # to chat completion format (prompt_tokens/completion_tokens) for internal logging - import copy - logging_response = copy.deepcopy(self.completed_response) + # Use model_dump + model_validate instead of deepcopy to avoid pickle errors with + # Pydantic ValidatorIterator when response contains tool_choice with allowed_tools (fixes #17192) + logging_response = self.completed_response + if self.completed_response is not None and hasattr(self.completed_response, 'model_dump'): + try: + logging_response = type(self.completed_response).model_validate( + self.completed_response.model_dump() + ) + except Exception: + # Fallback to original if serialization fails + pass run_async_function( async_function=self.logging_obj.async_success_handler, diff --git a/litellm/types/integrations/gcs_bucket.py b/litellm/types/integrations/gcs_bucket.py index 2be2acab2f..b297246f4f 100644 --- a/litellm/types/integrations/gcs_bucket.py +++ b/litellm/types/integrations/gcs_bucket.py @@ -12,6 +12,7 @@ else: GCS_DEFAULT_BATCH_SIZE = 2048 GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20 +GCS_DEFAULT_USE_BATCHED_LOGGING = True class GCSLoggingConfig(TypedDict): diff --git a/litellm/types/proxy/claude_code_endpoints.py b/litellm/types/proxy/claude_code_endpoints.py new file mode 100644 index 0000000000..663b182b80 --- /dev/null +++ b/litellm/types/proxy/claude_code_endpoints.py @@ -0,0 +1,116 @@ +""" +Claude Code Marketplace endpoint types for LiteLLM Proxy +""" + +from typing import Dict, List, Optional + +from pydantic import BaseModel, Field + + +class PluginAuthor(BaseModel): + """Plugin author information.""" + + name: str = Field(..., description="Author name") + email: Optional[str] = Field(None, description="Author email") + + +class PluginOwner(BaseModel): + """Marketplace owner information.""" + + name: str = Field(..., description="Owner name") + email: Optional[str] = Field(None, description="Owner email") + + +class RegisterPluginRequest(BaseModel): + """ + Request body for registering a plugin in the marketplace. + + LiteLLM acts as a registry/discovery layer. Plugins are hosted on + GitHub/GitLab/Bitbucket and referenced by their git source. + """ + + name: str = Field( + ..., + description="Plugin name (kebab-case, e.g., 'my-plugin')", + pattern=r"^[a-z0-9-]+$", + ) + source: Dict[str, str] = Field( + ..., + description=( + "Git source reference. Supported formats:\n" + "- GitHub: {'source': 'github', 'repo': 'org/repo'}\n" + "- Git URL: {'source': 'url', 'url': 'https://github.com/org/repo.git'}" + ), + ) + version: Optional[str] = Field("1.0.0", description="Semantic version") + description: Optional[str] = Field(None, description="Plugin description") + author: Optional[PluginAuthor] = Field(None, description="Plugin author") + homepage: Optional[str] = Field(None, description="Plugin homepage URL") + keywords: Optional[List[str]] = Field(None, description="Search keywords") + category: Optional[str] = Field(None, description="Plugin category") + + +class PluginResponse(BaseModel): + """Plugin information in API responses.""" + + id: str = Field(..., description="Plugin unique ID") + name: str = Field(..., description="Plugin name") + version: Optional[str] = Field(None, description="Plugin version") + description: Optional[str] = Field(None, description="Plugin description") + source: Dict[str, str] = Field(..., description="Git source reference") + enabled: bool = Field(..., description="Whether plugin is enabled") + + +class RegisterPluginResponse(BaseModel): + """Response from plugin registration.""" + + status: str = Field(..., description="Operation status") + action: str = Field(..., description="Action taken (created/updated)") + plugin: PluginResponse = Field(..., description="Plugin information") + + +class PluginListItem(BaseModel): + """Plugin item in list responses.""" + + id: str + name: str + version: Optional[str] + description: Optional[str] + enabled: bool + created_at: Optional[str] + updated_at: Optional[str] + + +class ListPluginsResponse(BaseModel): + """Response from listing plugins.""" + + plugins: List[PluginListItem] + count: int + + +class MarketplacePluginEntry(BaseModel): + """Plugin entry in marketplace.json.""" + + name: str + source: Dict[str, str] + version: Optional[str] = None + description: Optional[str] = None + author: Optional[PluginAuthor] = None + homepage: Optional[str] = None + keywords: Optional[List[str]] = None + category: Optional[str] = None + + +class MarketplaceResponse(BaseModel): + """ + Marketplace catalog response. + + This format is consumed by Claude Code CLI. + See: https://docs.anthropic.com/en/docs/claude-code/plugins + """ + + name: str = Field(..., description="Marketplace identifier") + owner: PluginOwner = Field(..., description="Marketplace owner") + plugins: List[MarketplacePluginEntry] = Field( + default_factory=list, description="Available plugins" + ) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 470d598a25..87f3566ae8 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -7857,6 +7857,24 @@ "supports_tool_choice": true, "supports_vision": true }, + "dall-e-2": { + "input_cost_per_image": 0.02, + "litellm_provider": "openai", + "mode": "image_generation", + "supported_endpoints": [ + "/v1/images/generations", + "/v1/images/edits", + "/v1/images/variations" + ] + }, + "dall-e-3": { + "input_cost_per_image": 0.04, + "litellm_provider": "openai", + "mode": "image_generation", + "supported_endpoints": [ + "/v1/images/generations" + ] + }, "deepseek-chat": { "cache_read_input_token_cost": 2.8e-08, "input_cost_per_token": 2.8e-07, @@ -18808,13 +18826,14 @@ "supports_tool_choice": true }, "groq/openai/gpt-oss-120b": { + "cache_read_input_token_cost": 7.5e-08, "input_cost_per_token": 1.5e-07, "litellm_provider": "groq", "max_input_tokens": 131072, "max_output_tokens": 32766, "max_tokens": 32766, "mode": "chat", - "output_cost_per_token": 7.5e-07, + "output_cost_per_token": 6e-07, "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_reasoning": true, @@ -18823,13 +18842,14 @@ "supports_web_search": true }, "groq/openai/gpt-oss-20b": { - "input_cost_per_token": 1e-07, + "cache_read_input_token_cost": 3.75e-08, + "input_cost_per_token": 7.5e-08, "litellm_provider": "groq", "max_input_tokens": 131072, "max_output_tokens": 32768, "max_tokens": 32768, "mode": "chat", - "output_cost_per_token": 5e-07, + "output_cost_per_token": 3e-07, "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_reasoning": true, diff --git a/poetry.lock b/poetry.lock index 5a3b6b8487..ac7076ea01 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,37 @@ # This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +[[package]] +name = "a2a-sdk" +version = "0.3.22" +description = "A2A Python SDK" +optional = true +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version >= \"3.10\" and extra == \"extra-proxy\"" +files = [ + {file = "a2a_sdk-0.3.22-py3-none-any.whl", hash = "sha256:b98701135bb90b0ff85d35f31533b6b7a299bf810658c1c65f3814a6c15ea385"}, + {file = "a2a_sdk-0.3.22.tar.gz", hash = "sha256:77a5694bfc4f26679c11b70c7f1062522206d430b34bc1215cfbb1eba67b7e7d"}, +] + +[package.dependencies] +google-api-core = ">=1.26.0" +httpx = ">=0.28.1" +httpx-sse = ">=0.4.0" +protobuf = ">=5.29.5" +pydantic = ">=2.11.3" + +[package.extras] +all = ["cryptography (>=43.0.0)", "fastapi (>=0.115.2)", "grpcio (>=1.60)", "grpcio-reflection (>=1.7.0)", "grpcio-tools (>=1.60)", "opentelemetry-api (>=1.33.0)", "opentelemetry-sdk (>=1.33.0)", "pyjwt (>=2.0.0)", "sqlalchemy[aiomysql,asyncio] (>=2.0.0)", "sqlalchemy[aiosqlite,asyncio] (>=2.0.0)", "sqlalchemy[asyncio,postgresql-asyncpg] (>=2.0.0)", "sse-starlette", "starlette"] +encryption = ["cryptography (>=43.0.0)"] +grpc = ["grpcio (>=1.60)", "grpcio-reflection (>=1.7.0)", "grpcio-tools (>=1.60)"] +http-server = ["fastapi (>=0.115.2)", "sse-starlette", "starlette"] +mysql = ["sqlalchemy[aiomysql,asyncio] (>=2.0.0)"] +postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg] (>=2.0.0)"] +signing = ["pyjwt (>=2.0.0)"] +sql = ["sqlalchemy[aiomysql,asyncio] (>=2.0.0)", "sqlalchemy[aiosqlite,asyncio] (>=2.0.0)", "sqlalchemy[asyncio,postgresql-asyncpg] (>=2.0.0)"] +sqlite = ["sqlalchemy[aiosqlite,asyncio] (>=2.0.0)"] +telemetry = ["opentelemetry-api (>=1.33.0)", "opentelemetry-sdk (>=1.33.0)"] + [[package]] name = "aiofiles" version = "24.1.0" @@ -1268,25 +1300,6 @@ dev = ["autoflake", "black", "build", "databricks-connect", "httpx", "ipython", notebook = ["ipython (>=8,<10)", "ipywidgets (>=8,<9)"] openai = ["httpx", "langchain-openai ; python_version > \"3.7\"", "openai"] -[[package]] -name = "deprecated" -version = "1.3.1" -description = "Python @deprecated decorator to deprecate old python classes, functions or methods." -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" -groups = ["main", "dev", "proxy-dev"] -files = [ - {file = "deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f"}, - {file = "deprecated-1.3.1.tar.gz", hash = "sha256:b1b50e0ff0c1fddaa5708a2c6b0a6588bb09b892825ab2b214ac9ea9d92a5223"}, -] -markers = {main = "python_version >= \"3.10\""} - -[package.dependencies] -wrapt = ">=1.10,<3" - -[package.extras] -dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools ; python_version >= \"3.12\"", "tox"] - [[package]] name = "diskcache" version = "5.6.3" @@ -2509,7 +2522,7 @@ description = "Consume Server-Sent Event (SSE) messages with HTTPX." optional = true python-versions = ">=3.9" groups = ["main"] -markers = "python_version >= \"3.10\" and extra == \"proxy\"" +markers = "python_version >= \"3.10\" and (extra == \"proxy\" or extra == \"extra-proxy\")" files = [ {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, @@ -4024,143 +4037,153 @@ voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"] [[package]] name = "opentelemetry-api" -version = "1.25.0" +version = "1.39.1" description = "OpenTelemetry Python API" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main", "dev", "proxy-dev"] files = [ - {file = "opentelemetry_api-1.25.0-py3-none-any.whl", hash = "sha256:757fa1aa020a0f8fa139f8959e53dec2051cc26b832e76fa839a6d76ecefd737"}, - {file = "opentelemetry_api-1.25.0.tar.gz", hash = "sha256:77c4985f62f2614e42ce77ee4c9da5fa5f0bc1e1821085e9a47533a9323ae869"}, + {file = "opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950"}, + {file = "opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c"}, ] markers = {main = "python_version >= \"3.10\""} [package.dependencies] -deprecated = ">=1.2.6" -importlib-metadata = ">=6.0,<=7.1" +importlib-metadata = ">=6.0,<8.8.0" +typing-extensions = ">=4.5.0" [[package]] name = "opentelemetry-exporter-otlp" -version = "1.25.0" +version = "1.39.1" description = "OpenTelemetry Collector Exporters" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["dev", "proxy-dev"] files = [ - {file = "opentelemetry_exporter_otlp-1.25.0-py3-none-any.whl", hash = "sha256:d67a831757014a3bc3174e4cd629ae1493b7ba8d189e8a007003cacb9f1a6b60"}, - {file = "opentelemetry_exporter_otlp-1.25.0.tar.gz", hash = "sha256:ce03199c1680a845f82e12c0a6a8f61036048c07ec7a0bd943142aca8fa6ced0"}, + {file = "opentelemetry_exporter_otlp-1.39.1-py3-none-any.whl", hash = "sha256:68ae69775291f04f000eb4b698ff16ff685fdebe5cb52871bc4e87938a7b00fe"}, + {file = "opentelemetry_exporter_otlp-1.39.1.tar.gz", hash = "sha256:7cf7470e9fd0060c8a38a23e4f695ac686c06a48ad97f8d4867bc9b420180b9c"}, ] [package.dependencies] -opentelemetry-exporter-otlp-proto-grpc = "1.25.0" -opentelemetry-exporter-otlp-proto-http = "1.25.0" +opentelemetry-exporter-otlp-proto-grpc = "1.39.1" +opentelemetry-exporter-otlp-proto-http = "1.39.1" [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.25.0" +version = "1.39.1" description = "OpenTelemetry Protobuf encoding" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["dev", "proxy-dev"] files = [ - {file = "opentelemetry_exporter_otlp_proto_common-1.25.0-py3-none-any.whl", hash = "sha256:15637b7d580c2675f70246563363775b4e6de947871e01d0f4e3881d1848d693"}, - {file = "opentelemetry_exporter_otlp_proto_common-1.25.0.tar.gz", hash = "sha256:c93f4e30da4eee02bacd1e004eb82ce4da143a2f8e15b987a9f603e0a85407d3"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.39.1-py3-none-any.whl", hash = "sha256:08f8a5862d64cc3435105686d0216c1365dc5701f86844a8cd56597d0c764fde"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.39.1.tar.gz", hash = "sha256:763370d4737a59741c89a67b50f9e39271639ee4afc999dadfe768541c027464"}, ] [package.dependencies] -opentelemetry-proto = "1.25.0" +opentelemetry-proto = "1.39.1" [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.25.0" +version = "1.39.1" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["dev", "proxy-dev"] files = [ - {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0-py3-none-any.whl", hash = "sha256:3131028f0c0a155a64c430ca600fd658e8e37043cb13209f0109db5c1a3e4eb4"}, - {file = "opentelemetry_exporter_otlp_proto_grpc-1.25.0.tar.gz", hash = "sha256:c0b1661415acec5af87625587efa1ccab68b873745ca0ee96b69bb1042087eac"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.39.1-py3-none-any.whl", hash = "sha256:fa1c136a05c7e9b4c09f739469cbdb927ea20b34088ab1d959a849b5cc589c18"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.39.1.tar.gz", hash = "sha256:772eb1c9287485d625e4dbe9c879898e5253fea111d9181140f51291b5fec3ad"}, ] [package.dependencies] -deprecated = ">=1.2.6" -googleapis-common-protos = ">=1.52,<2.0" -grpcio = ">=1.0.0,<2.0.0" +googleapis-common-protos = ">=1.57,<2.0" +grpcio = [ + {version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""}, + {version = ">=1.66.2,<2.0.0", markers = "python_version >= \"3.13\""}, +] opentelemetry-api = ">=1.15,<2.0" -opentelemetry-exporter-otlp-proto-common = "1.25.0" -opentelemetry-proto = "1.25.0" -opentelemetry-sdk = ">=1.25.0,<1.26.0" +opentelemetry-exporter-otlp-proto-common = "1.39.1" +opentelemetry-proto = "1.39.1" +opentelemetry-sdk = ">=1.39.1,<1.40.0" +typing-extensions = ">=4.6.0" + +[package.extras] +gcp-auth = ["opentelemetry-exporter-credential-provider-gcp (>=0.59b0)"] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.25.0" +version = "1.39.1" description = "OpenTelemetry Collector Protobuf over HTTP Exporter" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["dev", "proxy-dev"] files = [ - {file = "opentelemetry_exporter_otlp_proto_http-1.25.0-py3-none-any.whl", hash = "sha256:2eca686ee11b27acd28198b3ea5e5863a53d1266b91cda47c839d95d5e0541a6"}, - {file = "opentelemetry_exporter_otlp_proto_http-1.25.0.tar.gz", hash = "sha256:9f8723859e37c75183ea7afa73a3542f01d0fd274a5b97487ea24cb683d7d684"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.39.1-py3-none-any.whl", hash = "sha256:d9f5207183dd752a412c4cd564ca8875ececba13be6e9c6c370ffb752fd59985"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.39.1.tar.gz", hash = "sha256:31bdab9745c709ce90a49a0624c2bd445d31a28ba34275951a6a362d16a0b9cb"}, ] [package.dependencies] -deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" opentelemetry-api = ">=1.15,<2.0" -opentelemetry-exporter-otlp-proto-common = "1.25.0" -opentelemetry-proto = "1.25.0" -opentelemetry-sdk = ">=1.25.0,<1.26.0" +opentelemetry-exporter-otlp-proto-common = "1.39.1" +opentelemetry-proto = "1.39.1" +opentelemetry-sdk = ">=1.39.1,<1.40.0" requests = ">=2.7,<3.0" +typing-extensions = ">=4.5.0" + +[package.extras] +gcp-auth = ["opentelemetry-exporter-credential-provider-gcp (>=0.59b0)"] [[package]] name = "opentelemetry-proto" -version = "1.25.0" +version = "1.39.1" description = "OpenTelemetry Python Proto" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main", "dev", "proxy-dev"] files = [ - {file = "opentelemetry_proto-1.25.0-py3-none-any.whl", hash = "sha256:f07e3341c78d835d9b86665903b199893befa5e98866f63d22b00d0b7ca4972f"}, - {file = "opentelemetry_proto-1.25.0.tar.gz", hash = "sha256:35b6ef9dc4a9f7853ecc5006738ad40443701e52c26099e197895cbda8b815a3"}, + {file = "opentelemetry_proto-1.39.1-py3-none-any.whl", hash = "sha256:22cdc78efd3b3765d09e68bfbd010d4fc254c9818afd0b6b423387d9dee46007"}, + {file = "opentelemetry_proto-1.39.1.tar.gz", hash = "sha256:6c8e05144fc0d3ed4d22c2289c6b126e03bcd0e6a7da0f16cedd2e1c2772e2c8"}, ] markers = {main = "python_version >= \"3.10\" and extra == \"mlflow\""} [package.dependencies] -protobuf = ">=3.19,<5.0" +protobuf = ">=5.0,<7.0" [[package]] name = "opentelemetry-sdk" -version = "1.25.0" +version = "1.39.1" description = "OpenTelemetry Python SDK" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main", "dev", "proxy-dev"] files = [ - {file = "opentelemetry_sdk-1.25.0-py3-none-any.whl", hash = "sha256:d97ff7ec4b351692e9d5a15af570c693b8715ad78b8aafbec5c7100fe966b4c9"}, - {file = "opentelemetry_sdk-1.25.0.tar.gz", hash = "sha256:ce7fc319c57707ef5bf8b74fb9f8ebdb8bfafbe11898410e0d2a761d08a98ec7"}, + {file = "opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c"}, + {file = "opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6"}, ] markers = {main = "python_version >= \"3.10\""} [package.dependencies] -opentelemetry-api = "1.25.0" -opentelemetry-semantic-conventions = "0.46b0" -typing-extensions = ">=3.7.4" +opentelemetry-api = "1.39.1" +opentelemetry-semantic-conventions = "0.60b1" +typing-extensions = ">=4.5.0" [[package]] name = "opentelemetry-semantic-conventions" -version = "0.46b0" +version = "0.60b1" description = "OpenTelemetry Semantic Conventions" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main", "dev", "proxy-dev"] files = [ - {file = "opentelemetry_semantic_conventions-0.46b0-py3-none-any.whl", hash = "sha256:6daef4ef9fa51d51855d9f8e0ccd3a1bd59e0e545abe99ac6203804e36ab3e07"}, - {file = "opentelemetry_semantic_conventions-0.46b0.tar.gz", hash = "sha256:fbc982ecbb6a6e90869b15c1673be90bd18c8a56ff1cffc0864e38e2edffaefa"}, + {file = "opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb"}, + {file = "opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953"}, ] markers = {main = "python_version >= \"3.10\""} [package.dependencies] -opentelemetry-api = "1.25.0" +opentelemetry-api = "1.39.1" +typing-extensions = ">=4.5.0" [[package]] name = "orjson" @@ -4816,23 +4839,23 @@ testing = ["google-api-core (>=1.31.5)"] [[package]] name = "protobuf" -version = "4.25.8" +version = "5.29.5" description = "" optional = false python-versions = ">=3.8" groups = ["main", "dev", "proxy-dev"] files = [ - {file = "protobuf-4.25.8-cp310-abi3-win32.whl", hash = "sha256:504435d831565f7cfac9f0714440028907f1975e4bed228e58e72ecfff58a1e0"}, - {file = "protobuf-4.25.8-cp310-abi3-win_amd64.whl", hash = "sha256:bd551eb1fe1d7e92c1af1d75bdfa572eff1ab0e5bf1736716814cdccdb2360f9"}, - {file = "protobuf-4.25.8-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:ca809b42f4444f144f2115c4c1a747b9a404d590f18f37e9402422033e464e0f"}, - {file = "protobuf-4.25.8-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:9ad7ef62d92baf5a8654fbb88dac7fa5594cfa70fd3440488a5ca3bfc6d795a7"}, - {file = "protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:83e6e54e93d2b696a92cad6e6efc924f3850f82b52e1563778dfab8b355101b0"}, - {file = "protobuf-4.25.8-cp38-cp38-win32.whl", hash = "sha256:27d498ffd1f21fb81d987a041c32d07857d1d107909f5134ba3350e1ce80a4af"}, - {file = "protobuf-4.25.8-cp38-cp38-win_amd64.whl", hash = "sha256:d552c53d0415449c8d17ced5c341caba0d89dbf433698e1436c8fa0aae7808a3"}, - {file = "protobuf-4.25.8-cp39-cp39-win32.whl", hash = "sha256:077ff8badf2acf8bc474406706ad890466274191a48d0abd3bd6987107c9cde5"}, - {file = "protobuf-4.25.8-cp39-cp39-win_amd64.whl", hash = "sha256:f4510b93a3bec6eba8fd8f1093e9d7fb0d4a24d1a81377c10c0e5bbfe9e4ed24"}, - {file = "protobuf-4.25.8-py3-none-any.whl", hash = "sha256:15a0af558aa3b13efef102ae6e4f3efac06f1eea11afb3a57db2901447d9fb59"}, - {file = "protobuf-4.25.8.tar.gz", hash = "sha256:6135cf8affe1fc6f76cced2641e4ea8d3e59518d1f24ae41ba97bcad82d397cd"}, + {file = "protobuf-5.29.5-cp310-abi3-win32.whl", hash = "sha256:3f1c6468a2cfd102ff4703976138844f78ebd1fb45f49011afc5139e9e283079"}, + {file = "protobuf-5.29.5-cp310-abi3-win_amd64.whl", hash = "sha256:3f76e3a3675b4a4d867b52e4a5f5b78a2ef9565549d4037e06cf7b0942b1d3fc"}, + {file = "protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e38c5add5a311f2a6eb0340716ef9b039c1dfa428b28f25a7838ac329204a671"}, + {file = "protobuf-5.29.5-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:fa18533a299d7ab6c55a238bf8629311439995f2e7eca5caaff08663606e9015"}, + {file = "protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:63848923da3325e1bf7e9003d680ce6e14b07e55d0473253a690c3a8b8fd6e61"}, + {file = "protobuf-5.29.5-cp38-cp38-win32.whl", hash = "sha256:ef91363ad4faba7b25d844ef1ada59ff1604184c0bcd8b39b8a6bef15e1af238"}, + {file = "protobuf-5.29.5-cp38-cp38-win_amd64.whl", hash = "sha256:7318608d56b6402d2ea7704ff1e1e4597bee46d760e7e4dd42a3d45e24b87f2e"}, + {file = "protobuf-5.29.5-cp39-cp39-win32.whl", hash = "sha256:6f642dc9a61782fa72b90878af134c5afe1917c89a568cd3476d758d3c3a0736"}, + {file = "protobuf-5.29.5-cp39-cp39-win_amd64.whl", hash = "sha256:470f3af547ef17847a28e1f47200a1cbf0ba3ff57b7de50d22776607cd2ea353"}, + {file = "protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5"}, + {file = "protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84"}, ] markers = {main = "python_version >= \"3.10\" and (extra == \"mlflow\" or extra == \"extra-proxy\") or extra == \"extra-proxy\""} @@ -7675,7 +7698,7 @@ version = "1.17.3" description = "Module for decorators, wrappers and monkey patching." optional = false python-versions = ">=3.8" -groups = ["main", "dev", "proxy-dev"] +groups = ["dev"] files = [ {file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88bbae4d40d5a46142e70d58bf664a89b6b4befaea7b2ecc14e03cedb8e06c04"}, {file = "wrapt-1.17.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6b13af258d6a9ad602d57d889f83b9d5543acd471eee12eb51f5b01f8eb1bc2"}, @@ -7759,7 +7782,6 @@ files = [ {file = "wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22"}, {file = "wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0"}, ] -markers = {main = "python_version >= \"3.10\""} [[package]] name = "wsproto" @@ -7960,7 +7982,7 @@ type = ["pytest-mypy"] [extras] caching = ["diskcache"] -extra-proxy = ["azure-identity", "azure-keyvault-secrets", "google-cloud-iam", "google-cloud-kms", "prisma", "redisvl", "resend"] +extra-proxy = ["a2a-sdk", "azure-identity", "azure-keyvault-secrets", "google-cloud-iam", "google-cloud-kms", "prisma", "redisvl", "resend"] mlflow = ["mlflow"] proxy = ["PyJWT", "apscheduler", "azure-identity", "azure-storage-blob", "backoff", "boto3", "cryptography", "fastapi", "fastapi-sso", "gunicorn", "litellm-enterprise", "litellm-proxy-extras", "mcp", "orjson", "polars", "pynacl", "python-multipart", "pyyaml", "rich", "rq", "soundfile", "uvicorn", "uvloop", "websockets"] semantic-router = ["semantic-router"] @@ -7969,4 +7991,4 @@ utils = ["numpydoc"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "7c9f917a46adc0d0b57dbc48cbdc3622551aa733d86909da1be87773c2857694" +content-hash = "3a929b2e1dc2b85edcf78f93b0c15eda2bf0cdf8d3e0e30778fc63178c650e40" diff --git a/pyproject.toml b/pyproject.toml index d80d442be2..c694bfddcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ websockets = {version = "^15.0.1", optional = true} boto3 = {version = "1.36.0", optional = true} redisvl = {version = "^0.4.1", optional = true, markers = "python_version >= '3.9' and python_version < '3.14'"} mcp = {version = ">=1.25.0,<2.0.0", optional = true, python = ">=3.10"} +a2a-sdk = {version = "^0.3.22", optional = true, python = ">=3.10"} litellm-proxy-extras = {version = "0.4.23", optional = true} rich = {version = "13.7.1", optional = true} litellm-enterprise = {version = "0.1.27", optional = true} @@ -111,7 +112,8 @@ extra_proxy = [ "google-cloud-kms", "google-cloud-iam", "resend", - "redisvl" + "redisvl", + "a2a-sdk" ] utils = [ @@ -147,9 +149,9 @@ types-requests = "*" types-setuptools = "*" types-redis = "*" types-PyYAML = "*" -opentelemetry-api = "1.25.0" -opentelemetry-sdk = "1.25.0" -opentelemetry-exporter-otlp = "1.25.0" +opentelemetry-api = "^1.28.0" +opentelemetry-sdk = "^1.28.0" +opentelemetry-exporter-otlp = "^1.28.0" langfuse = "^2.45.0" fastapi-offline = "^1.7.3" @@ -157,9 +159,9 @@ fastapi-offline = "^1.7.3" prisma = "0.11.0" hypercorn = "^0.15.0" prometheus-client = "0.20.0" -opentelemetry-api = "1.25.0" -opentelemetry-sdk = "1.25.0" -opentelemetry-exporter-otlp = "1.25.0" +opentelemetry-api = "^1.28.0" +opentelemetry-sdk = "^1.28.0" +opentelemetry-exporter-otlp = "^1.28.0" azure-identity = {version = "^1.15.0", python = ">=3.9"} [build-system] diff --git a/requirements.txt b/requirements.txt index a49a94ca27..68b85c3935 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,12 +16,12 @@ prisma==0.11.0 # for db nodejs-wheel-binaries==24.12.0 ## required by prisma for migrations, prevents runtime download (updated from nodejs-bin for security fixes) mangum==0.17.0 # for aws lambda functions pynacl==1.6.2 # for encrypting keys -google-cloud-aiplatform==1.47.0 # for vertex ai calls +google-cloud-aiplatform==1.133.0 # for vertex ai calls google-cloud-iam==2.19.1 # for GCP IAM Redis authentication -google-genai==1.22.0 +google-genai==1.37.0 anthropic[vertex]==0.54.0 mcp==1.25.0 ; python_version >= "3.10" # for MCP server -google-generativeai==0.5.0 # for vertex ai calls +# google-generativeai removed - deprecated, replaced by google-genai (line 21) async_generator==1.10.0 # for async ollama calls langfuse==2.59.7 # for langfuse self-hosted logging prometheus_client==0.20.0 # for /metrics endpoint on proxy @@ -38,9 +38,10 @@ azure-ai-contentsafety==1.0.0 # for azure content safety azure-identity==1.16.1 ; python_version >= "3.9" # for azure content safety azure-keyvault==4.2.0 # for azure KMS integration azure-storage-file-datalake==12.20.0 # for azure buck storage logging -opentelemetry-api==1.25.0 -opentelemetry-sdk==1.25.0 -opentelemetry-exporter-otlp==1.25.0 +opentelemetry-api==1.28.0 +opentelemetry-sdk==1.28.0 +opentelemetry-exporter-otlp==1.28.0 +a2a-sdk>=0.3.22 ; python_version >= "3.10" # grpcio: 1.68.0-1.68.1 has reconnect bug (#38290), 1.75+ has Python 3.14 wheels + fix grpcio>=1.62.3,!=1.68.*,!=1.69.*,!=1.70.*,!=1.71.0,!=1.71.1,!=1.72.0,!=1.72.1,!=1.73.0; python_version < "3.14" grpcio>=1.75.0; python_version >= "3.14" diff --git a/scripts/health_check/health_check_client.py b/scripts/health_check/health_check_client.py new file mode 100644 index 0000000000..337c754a75 --- /dev/null +++ b/scripts/health_check/health_check_client.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 +""" +LiteLLM Health Check Client + +A sentinel health check tool that tests all configured models on a LiteLLM proxy. +Similar to HRT's health check system, this script: +- Can read models from YAML config file (like HRT) or fetch from proxy API +- Sends a simple test request to each model concurrently +- Reports health status for each model +- Supports both chat/completion and embedding models +""" + +import asyncio +import json +import os +import sys +import time +from typing import Dict, List, Optional, Tuple + +import httpx +import yaml + + +class LiteLLMHealthCheckClient: + """Client for health checking LiteLLM proxy models.""" + + def __init__( + self, + base_url: str, + api_key: str, + timeout: int = 120, # Match Go implementation's 120s timeout + completion_prompt: str = "Say this is a test", # Match Go implementation + embedding_text: str = "This is a test for vectorization.", # Match Go implementation + ): + """ + Initialize the health check client. + + Args: + base_url: Base URL of the LiteLLM proxy (e.g., https://litellm.example.com) + api_key: API key for authentication + timeout: Request timeout in seconds (default: 120, matching Go implementation) + completion_prompt: Test prompt for chat/completion models + embedding_text: Test text for embedding models + """ + self.base_url = base_url.rstrip("/") + self.api_key = api_key + self.timeout = timeout + self.completion_prompt = completion_prompt + self.embedding_text = embedding_text + self.headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + def load_models_from_yaml(self, yaml_path: str) -> List[Dict]: + """ + Load models from a YAML config file (similar to Go implementation). + + Args: + yaml_path: Path to the YAML config file + + Returns: + List of model dictionaries with 'id' and 'mode' keys + """ + try: + with open(yaml_path, "r") as f: + config = yaml.safe_load(f) + + model_list = config.get("model_list", []) + models = [] + + for entry in model_list: + model_name = entry.get("model_name", "") + litellm_params = entry.get("litellm_params", {}) + model_info = litellm_params.get("model_info", {}) + mode = model_info.get("mode", "") + + # Use model_name as the ID (this is what gets sent to the API) + models.append( + { + "id": model_name, + "mode": mode.lower() if mode else "", + "provider": model_info.get("provider", ""), + } + ) + + return models + except Exception as e: + print(f"Error loading models from YAML file {yaml_path}: {e}", file=sys.stderr) + return [] + + async def fetch_models(self, client: httpx.AsyncClient) -> List[Dict]: + """ + Fetch all available models from the proxy API. + + Returns: + List of model dictionaries with 'id' and 'mode' keys + """ + try: + # Try /v1/models first (OpenAI-compatible endpoint) + response = await client.get( + f"{self.base_url}/v1/models", + headers=self.headers, + timeout=self.timeout, + ) + response.raise_for_status() + data = response.json() + models_data = data.get("data", []) + models = [] + for m in models_data: + models.append({"id": m["id"], "mode": "", "provider": ""}) + return models + except Exception as e: + print(f"Error fetching models from /v1/models: {e}", file=sys.stderr) + # Fallback to /model/info endpoint which has more details + try: + response = await client.get( + f"{self.base_url}/model/info", + headers=self.headers, + timeout=self.timeout, + ) + response.raise_for_status() + data = response.json() + if isinstance(data, dict) and "data" in data: + models_data = data["data"] + elif isinstance(data, list): + models_data = data + else: + models_data = [] + + models = [] + for m in models_data: + model_info = m.get("model_info", {}) + mode = model_info.get("mode", "") + models.append( + { + "id": m.get("model_name", m.get("id", "unknown")), + "mode": mode.lower() if mode else "", + "provider": model_info.get("provider", ""), + } + ) + return models + except Exception as e2: + print(f"Error fetching models from /model/info: {e2}", file=sys.stderr) + return [] + + async def check_model_health( + self, client: httpx.AsyncClient, model: Dict + ) -> Tuple[str, Dict]: + """ + Check health of a single model by sending a test request. + + Args: + client: HTTP client + model: Model dictionary with 'id' and 'mode' keys + + Returns: + Tuple of (model_id, result_dict) + """ + model_id = model["id"] + mode = model.get("mode", "") + + start_time = time.time() + result = { + "model": model_id, + "healthy": False, + "error": None, + "response_time_ms": None, + "mode": mode, + } + + try: + # Determine if this is an embedding model + # Check mode first (from config), then fall back to name-based detection + is_embedding = ( + mode == "embedding" + or any( + keyword in model_id.lower() + for keyword in ["embedding", "embed", "text-embedding"] + ) + ) + + if is_embedding: + # Test embedding endpoint (matching Go implementation) + embedding_response = await client.post( + f"{self.base_url}/v1/embeddings", + headers=self.headers, + json={ + "model": model_id, + "input": self.embedding_text, + }, + timeout=self.timeout, + ) + embedding_response.raise_for_status() + embedding_data = embedding_response.json() + dimensions = 0 + if "data" in embedding_data and len(embedding_data["data"]) > 0: + dimensions = len(embedding_data["data"][0].get("embedding", [])) + + result["healthy"] = True + result["mode"] = "embedding" + result["dimensions"] = dimensions + else: + # Test chat completion endpoint (matching Go implementation) + completion_response = await client.post( + f"{self.base_url}/v1/chat/completions", + headers=self.headers, + json={ + "model": model_id, + "messages": [ + {"role": "user", "content": self.completion_prompt} + ], + "max_tokens": 10, # Minimal tokens for health check + }, + timeout=self.timeout, + ) + completion_response.raise_for_status() + completion_data = completion_response.json() + response_text = "" + if "choices" in completion_data and len(completion_data["choices"]) > 0: + response_text = ( + completion_data["choices"][0] + .get("message", {}) + .get("content", "") + ) + + result["healthy"] = True + result["mode"] = "chat" + result["response_text"] = response_text[:100] # Truncate for display + + elapsed_ms = (time.time() - start_time) * 1000 + result["response_time_ms"] = round(elapsed_ms, 2) + + except httpx.HTTPStatusError as e: + result["error"] = f"HTTP {e.response.status_code}: {e.response.text[:200]}" + except httpx.TimeoutException: + result["error"] = f"Request timeout after {self.timeout}s" + except Exception as e: + result["error"] = str(e)[:200] + + return model_id, result + + async def run_health_checks( + self, + models: Optional[List[Dict]] = None, + models_only: Optional[List[str]] = None, + ) -> Dict[str, Dict]: + """ + Run health checks on all models concurrently. + + Args: + models: Optional list of models to check. If None, fetches from proxy. + models_only: Optional list of model IDs to check. If set, only these + models are health-checked (must exist in the models list). + + Returns: + Dictionary mapping model_id to health check result + """ + async with httpx.AsyncClient() as client: + if models is None: + models = await self.fetch_models(client) + + if not models: + print("No models found to health check", file=sys.stderr) + return {} + + if models_only: + allowlist = {m.strip() for m in models_only if m and m.strip()} + models = [m for m in models if m.get("id") in allowlist] + print( + f"Filtering to only check {len(models)} models: {', '.join(sorted(allowlist))}", + file=sys.stderr, + ) + if not models: + print( + "No models matched LITELLM_MODELS_ONLY filter", + file=sys.stderr, + ) + return {} + + print(f"Running health checks on {len(models)} models...", file=sys.stderr) + + # Run all health checks concurrently + tasks = [self.check_model_health(client, model) for model in models] + results_list = await asyncio.gather(*tasks, return_exceptions=True) + + # Convert to dictionary format + results = {} + for result in results_list: + if isinstance(result, Exception): + print( + f"Exception in health check task: {result}", file=sys.stderr + ) + continue + # Type narrowing: after checking it's not an Exception, it's a Tuple + if isinstance(result, tuple) and len(result) == 2: + model_id, result_dict = result + results[model_id] = result_dict + + return results + + def print_results(self, results: Dict[str, Dict], json_output: bool = False): + """ + Print health check results. + + Args: + results: Dictionary of health check results + json_output: If True, output as JSON + """ + if json_output: + print(json.dumps(results, indent=2)) + return + + healthy_count = sum(1 for r in results.values() if r.get("healthy")) + unhealthy_count = len(results) - healthy_count + + # Print detailed results for each model (matching Go output format) + print(f"\n{'='*60}", file=sys.stderr) + print(f"Starting health check queries\n", file=sys.stderr) + + for model_id, result in results.items(): + if result.get("healthy"): + if result.get("mode") == "embedding": + dimensions = result.get("dimensions", 0) + print( + f"---- {model_id} ----\n✅ Success. " + f"Generated embedding vector with {dimensions} dimensions.\n\n", + file=sys.stderr, + ) + else: + response_text = result.get("response_text", "") + print( + f"---- {model_id} ----\n✅ Success. " + f"Response:\n{response_text}\n\n", + file=sys.stderr, + ) + else: + error = result.get("error", "Unknown error") + print(f"---- {model_id} ----\n❌ ERROR: {error}\n\n", file=sys.stderr) + + print(f"{'='*60}", file=sys.stderr) + print(f"Health Check Summary", file=sys.stderr) + print(f"{'='*60}", file=sys.stderr) + print(f"Total models: {len(results)}", file=sys.stderr) + print(f"Healthy: {healthy_count}", file=sys.stderr) + print(f"Unhealthy: {unhealthy_count}", file=sys.stderr) + print(f"{'='*60}\n", file=sys.stderr) + + # Exit with non-zero code if any models are unhealthy + if unhealthy_count > 0: + sys.exit(1) + else: + sys.exit(0) + + +async def main(): + """Main entry point.""" + base_url = os.environ.get("LITELLM_BASE_URL", "http://localhost:4000") + api_key = os.environ.get("LITELLM_API_KEY", "sk-1234") + yaml_path = os.environ.get("LITELLM_MODELS_YAML") + + if not base_url: + print("Error: LITELLM_BASE_URL environment variable not set", file=sys.stderr) + sys.exit(1) + + if not api_key: + print("Error: LITELLM_API_KEY environment variable not set", file=sys.stderr) + sys.exit(1) + + timeout = int(os.environ.get("LITELLM_TIMEOUT", "120")) # Match Go's 120s default + completion_prompt = os.environ.get( + "LITELLM_COMPLETION_PROMPT", "Say this is a test" + ) + embedding_text = os.environ.get( + "LITELLM_EMBEDDING_TEXT", "This is a test for vectorization." + ) + json_output = os.environ.get("LITELLM_JSON_OUTPUT", "").lower() == "true" + # Optional: only health-check these model IDs (comma-separated). E.g.: + # LITELLM_MODELS_ONLY=claude-3.7-sonnet,claude-3.5-sonnet,claude-4.5-haiku + models_only_raw = os.environ.get("LITELLM_MODELS_ONLY", "") + models_only = [m.strip() for m in models_only_raw.split(",") if m.strip()] or None + + client = LiteLLMHealthCheckClient( + base_url=base_url, + api_key=api_key, + timeout=timeout, + completion_prompt=completion_prompt, + embedding_text=embedding_text, + ) + + # Load models from YAML if provided, otherwise fetch from API + models = None + if yaml_path: + models = client.load_models_from_yaml(yaml_path) + if models: + print( + f"Successfully loaded {len(models)} models from {yaml_path}", + file=sys.stderr, + ) + + results = await client.run_health_checks(models=models, models_only=models_only) + client.print_results(results, json_output=json_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/health_check/health_check_client_README.md b/scripts/health_check/health_check_client_README.md new file mode 100644 index 0000000000..e3132499e0 --- /dev/null +++ b/scripts/health_check/health_check_client_README.md @@ -0,0 +1,246 @@ +# LiteLLM Health Check Client + +A health check tool for testing all configured models on a LiteLLM proxy. Tests each model with completion/embedding requests and reports health status, errors, and response times. + +## Features + +- **YAML Config Support**: Reads models from YAML config file OR fetches from proxy API +- **Smart Mode Detection**: Detects embedding vs chat models from config or model name +- **Concurrent Testing**: Tests all models concurrently using asyncio +- **Containerized**: Docker image for easy deployment +- **Parallel Execution**: Supports parallel execution for stress testing +- **Configurable**: Customizable timeouts (default 120s) and test prompts + +## Quick Start + +### As a Python Script + +**Option 1: Fetch models from proxy API** +```bash +export LITELLM_BASE_URL="https://litellm.example.com" +export LITELLM_API_KEY="your-api-key" +python scripts/health_check/health_check_client.py +``` + +**Option 2: Use YAML config file** +```bash +export LITELLM_BASE_URL="https://litellm.example.com" +export LITELLM_API_KEY="your-api-key" +export LITELLM_MODELS_YAML="/path/to/config.yaml" +python scripts/health_check/health_check_client.py +``` + +### As a Docker Container + +1. Build the Docker image: + +```bash +docker build -f docker/Dockerfile.health_check -t litellm/litellm-health-check:latest . +``` + +2. Run a single health check: + +```bash +docker run --rm \ + -e LITELLM_BASE_URL="https://litellm.example.com" \ + -e LITELLM_API_KEY="your-api-key" \ + litellm/litellm-health-check:latest +``` + +### Parallel Execution (Stress Testing) + +Run multiple health check containers in parallel: + +**PowerShell:** +```powershell +$env:LITELLM_BASE_URL="https://litellm.example.com" +$env:LITELLM_API_KEY="your-api-key" +.\scripts\health_check\run_parallel_health_checks.ps1 16 +``` + +**Bash/Shell:** +```bash +export LITELLM_BASE_URL="https://litellm.example.com" +export LITELLM_API_KEY="your-api-key" +./scripts/health_check/run_parallel_health_checks.sh 16 +``` + + +## Configuration + +### Environment Variables + +- `LITELLM_BASE_URL` (required): Base URL of the LiteLLM proxy + - Example: `https://litellm.example.com` +- `LITELLM_API_KEY` (required): API key for authentication +- `LITELLM_MODELS_YAML` (optional): Path to YAML config file with model_list + - If provided, reads models from YAML instead of fetching from API + - Example: `/path/to/config.yaml` +- `LITELLM_TIMEOUT` (optional): Request timeout in seconds (default: 120) +- `LITELLM_COMPLETION_PROMPT` (optional): Test prompt for chat/completion models (default: "Say this is a test") +- `LITELLM_EMBEDDING_TEXT` (optional): Test text for embedding models (default: "This is a test for vectorization.") +- `LITELLM_JSON_OUTPUT` (optional): Output results as JSON (default: false) + +## Output + +### Standard Output (Human-Readable) + +Example output format: + +``` +============================================================ +Starting health check queries + +---- gpt-4o ---- +✅ Success. Response: +This is a test + +---- text-embedding-3-small ---- +✅ Success. Generated embedding vector with 1536 dimensions. + +---- gpt-5-codex ---- +❌ ERROR: HTTP 503: Service unavailable + +============================================================ +Health Check Summary +============================================================ +Total models: 47 +Healthy: 45 +Unhealthy: 2 +============================================================ +``` + +Exit code: `0` if all models are healthy, `1` if any models are unhealthy. + +### JSON Output + +When `LITELLM_JSON_OUTPUT=true`, outputs JSON: + +```json +{ + "gpt-4o": { + "model": "gpt-4o", + "healthy": true, + "error": null, + "response_time_ms": 245.67, + "mode": "chat", + "response_text": "This is a test" + }, + "text-embedding-3-small": { + "model": "text-embedding-3-small", + "healthy": true, + "error": null, + "response_time_ms": 123.45, + "mode": "embedding", + "dimensions": 1536 + } +} +``` + +## How It Works + +1. **Model Discovery**: + - If `LITELLM_MODELS_YAML` is set: Reads models from YAML config file + - Otherwise: Queries `/v1/models` (OpenAI-compatible) or `/model/info` to get all configured models +2. **Mode Detection**: + - Checks `mode` field from YAML config, or falls back to model name patterns (embedding, embed, text-embedding) +3. **Concurrent Testing**: + - Chat models: `POST /v1/chat/completions` with configurable prompt (default: "Say this is a test") + - Embedding models: `POST /v1/embeddings` with configurable text (default: "This is a test for vectorization.") +4. **Reporting**: Health status, errors, response times, and response details are reported + +## Use Cases + +### 1. Regular Health Monitoring + +Run as a cron job or scheduled task: + +```bash +# Cron job: Run every 5 minutes +*/5 * * * * /path/to/health_check.sh +``` + +### 2. Load/Stress Testing + +Run multiple health checks in parallel: + +**PowerShell:** +```powershell +.\scripts\health_check\run_parallel_health_checks.ps1 16 +``` + +### 3. CI/CD Integration + +Add to your deployment pipeline: + +```yaml +# GitHub Actions example +- name: Health Check + run: | + docker run --rm \ + -e LITELLM_BASE_URL="${{ secrets.LITELLM_BASE_URL }}" \ + -e LITELLM_API_KEY="${{ secrets.LITELLM_API_KEY }}" \ + litellm/litellm-health-check:latest +``` + +### 4. Kubernetes Deployment + +Deploy as a CronJob: + +```yaml +apiVersion: batch/v1 +kind: CronJob +metadata: + name: litellm-health-check +spec: + schedule: "*/5 * * * *" # Every 5 minutes + jobTemplate: + spec: + template: + spec: + containers: + - name: health-check + image: litellm/litellm-health-check:latest + env: + - name: LITELLM_BASE_URL + value: "https://litellm.example.com" + - name: LITELLM_API_KEY + valueFrom: + secretKeyRef: + name: litellm-secrets + key: api-key + restartPolicy: OnFailure +``` + +## Troubleshooting + +### No Models Found + +- Verify `LITELLM_BASE_URL` is correct +- Check that the API key has permissions to list models +- Ensure the proxy is running and accessible +- If using YAML, verify `LITELLM_MODELS_YAML` path is correct + +### Timeout Errors + +- Increase `LITELLM_TIMEOUT` for slower models (default is 120s) +- Check network connectivity to the proxy +- Verify proxy isn't overloaded + +### Authentication Errors + +- Verify `LITELLM_API_KEY` is correct +- Check API key has not expired +- Ensure the key has necessary permissions + +## Dependencies + +- Python 3.11+ +- httpx (for async HTTP requests) +- pyyaml (for YAML config file support) +- Docker or Podman (for containerized execution) +- PowerShell (for parallel execution script on Windows) + +## License + +Same as LiteLLM project. diff --git a/scripts/health_check/health_check_requirements.txt b/scripts/health_check/health_check_requirements.txt new file mode 100644 index 0000000000..c9d2650c88 --- /dev/null +++ b/scripts/health_check/health_check_requirements.txt @@ -0,0 +1,2 @@ +httpx>=0.24.0 +pyyaml>=6.0 diff --git a/scripts/health_check/run_parallel_health_checks.ps1 b/scripts/health_check/run_parallel_health_checks.ps1 new file mode 100644 index 0000000000..856e7f20ec --- /dev/null +++ b/scripts/health_check/run_parallel_health_checks.ps1 @@ -0,0 +1,69 @@ +# Parallel LiteLLM Health Check Runner (PowerShell version) +# +# This script runs multiple health check containers in parallel. +# +# Usage: +# $env:LITELLM_BASE_URL="https://litellm.example.com" +# $env:LITELLM_API_KEY="your-api-key" +# .\run_parallel_health_checks.ps1 [num_parallel_jobs] [image_name] +# +# Defaults: +# - num_parallel_jobs: 16 +# - image_name: litellm/litellm-health-check:latest + +param( + [int]$NumParallelJobs = 16, + [string]$ImageName = "litellm/litellm-health-check:latest", + [string]$ContainerRuntime = "docker" +) + +# Set defaults for environment variables if not provided +if (-not $env:LITELLM_BASE_URL) { + $env:LITELLM_BASE_URL = "https://litellm-perf-cache-and-router.onrender.com" + Write-Warning "LITELLM_BASE_URL not set, using default: $env:LITELLM_BASE_URL" +} + +if (-not $env:LITELLM_API_KEY) { + $env:LITELLM_API_KEY = "sk-1234" + Write-Warning "LITELLM_API_KEY not set, using default: $env:LITELLM_API_KEY" +} + +# Check if container runtime is available +$runtimeExists = Get-Command $ContainerRuntime -ErrorAction SilentlyContinue +if (-not $runtimeExists) { + Write-Error "Error: $ContainerRuntime is not installed" + exit 1 +} + +Write-Host "Running $NumParallelJobs parallel health check containers..." -ForegroundColor Yellow +Write-Host "Using image: $ImageName" -ForegroundColor Yellow +Write-Host "Container runtime: $ContainerRuntime" -ForegroundColor Yellow +Write-Host "LiteLLM Base URL: $env:LITELLM_BASE_URL" -ForegroundColor Cyan +Write-Host "" +Write-Host "NOTE: This will run continuously. Press Ctrl+C to stop." -ForegroundColor Red +Write-Host "" +Write-Host "Troubleshooting:" -ForegroundColor Yellow +Write-Host " - If you see 'All connection attempts failed', check:" -ForegroundColor Yellow +Write-Host " 1. Is the LiteLLM proxy running on the expected port?" -ForegroundColor Yellow +Write-Host " 2. Set LITELLM_BASE_URL to the correct URL (e.g., http://host.docker.internal:PORT)" -ForegroundColor Yellow +Write-Host " 3. On Linux, you may need to use the host IP instead of host.docker.internal" -ForegroundColor Yellow +Write-Host "" + +# Run parallel health checks +# This creates an infinite loop that keeps spawning containers +# Each container tests all models, then exits, and a new one starts +while ($true) { + # Start up to NumParallelJobs containers in parallel + 1..$NumParallelJobs | ForEach-Object -Parallel { + $runtime = $using:ContainerRuntime + $imageName = $using:ImageName + $baseUrl = $env:LITELLM_BASE_URL + $apiKey = $env:LITELLM_API_KEY + + & $runtime run --rm ` + -e LITELLM_BASE_URL="$baseUrl" ` + -e LITELLM_API_KEY="$apiKey" ` + -e LITELLM_JSON_OUTPUT="true" ` + $imageName + } -ThrottleLimit $NumParallelJobs +} diff --git a/scripts/health_check/run_parallel_health_checks.sh b/scripts/health_check/run_parallel_health_checks.sh new file mode 100644 index 0000000000..9b6c5d9f39 --- /dev/null +++ b/scripts/health_check/run_parallel_health_checks.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# Parallel LiteLLM Health Check Runner (Bash version) +# +# This script runs multiple health check containers in parallel. +# +# Usage: +# export LITELLM_BASE_URL="https://litellm.example.com" +# export LITELLM_API_KEY="your-api-key" +# ./run_parallel_health_checks.sh [num_parallel_jobs] [image_name] [container_runtime] +# +# Defaults: +# - num_parallel_jobs: 16 +# - image_name: litellm/litellm-health-check:latest +# - container_runtime: docker + +set -e + +# Default values +NUM_PARALLEL_JOBS="${1:-16}" +IMAGE_NAME="${2:-litellm/litellm-health-check:latest}" +CONTAINER_RUNTIME="${3:-docker}" + +# Set defaults for environment variables if not provided +if [ -z "$LITELLM_BASE_URL" ]; then + export LITELLM_BASE_URL="https://litellm-perf-cache-and-router.onrender.com" + echo "Warning: LITELLM_BASE_URL not set, using default: $LITELLM_BASE_URL" >&2 +fi + +if [ -z "$LITELLM_API_KEY" ]; then + export LITELLM_API_KEY="sk-1234" + echo "Warning: LITELLM_API_KEY not set, using default: $LITELLM_API_KEY" >&2 +fi + +# Check if container runtime is available +if ! command -v "$CONTAINER_RUNTIME" &> /dev/null; then + echo "Error: $CONTAINER_RUNTIME is not installed" >&2 + exit 1 +fi + +# Print configuration +echo "Running $NUM_PARALLEL_JOBS parallel health check containers..." +echo "Using image: $IMAGE_NAME" +echo "Container runtime: $CONTAINER_RUNTIME" +echo "LiteLLM Base URL: $LITELLM_BASE_URL" +echo "" +echo "NOTE: This will run continuously. Press Ctrl+C to stop." +echo "" +echo "Troubleshooting:" +echo " - If you see 'All connection attempts failed', check:" +echo " 1. Is the LiteLLM proxy running on the expected port?" +echo " 2. Set LITELLM_BASE_URL to the correct URL (e.g., http://host.docker.internal:PORT)" +echo " 3. On Linux, you may need to use the host IP instead of host.docker.internal" +echo "" + +# Function to run a single health check container +run_health_check() { + "$CONTAINER_RUNTIME" run --rm \ + -e LITELLM_BASE_URL="$LITELLM_BASE_URL" \ + -e LITELLM_API_KEY="$LITELLM_API_KEY" \ + -e LITELLM_JSON_OUTPUT="true" \ + "$IMAGE_NAME" +} + +# Run parallel health checks +# This creates an infinite loop that keeps spawning containers +# Each container tests all models, then exits, and a new one starts +while true; do + # Start containers in parallel using background jobs + pids=() + for ((i=1; i<=NUM_PARALLEL_JOBS; i++)); do + run_health_check & + pids+=($!) + done + + # Wait for all background jobs to complete + for pid in "${pids[@]}"; do + wait "$pid" 2>/dev/null || true + done +done diff --git a/tests/code_coverage_tests/liccheck.ini b/tests/code_coverage_tests/liccheck.ini index feb182921d..ea87b56ff3 100644 --- a/tests/code_coverage_tests/liccheck.ini +++ b/tests/code_coverage_tests/liccheck.ini @@ -89,6 +89,7 @@ tokenizers: >=0.20.2 # Apache 2.0 License jinja2: >=3.1.4 # BSD 3-Clause License litellm-proxy-extras: >=0.1.1 # MIT License litellm-enterprise: >=0.1.1 # LiteLLM Enterprise License +a2a-sdk: >=0.3.22 # Apache 2.0 license anyio: >=4.5.0 # Unknown license httpx-aiohttp: >=0.1.4 # Unknown license backoff: >=2.2.1 # Unknown license diff --git a/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py b/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py index 1103eaf92b..b81161881e 100644 --- a/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py +++ b/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py @@ -231,7 +231,7 @@ class TestBaseResponsesAPIStreamingIterator: mock_logging_obj = Mock(spec=LiteLLMLoggingObj) mock_logging_obj.model_call_details = {"litellm_params": {}} mock_config = Mock(spec=BaseResponsesAPIConfig) - + # Create the iterator instance iterator = BaseResponsesAPIStreamingIterator( response=mock_response, @@ -239,11 +239,73 @@ class TestBaseResponsesAPIStreamingIterator: responses_api_provider_config=mock_config, logging_obj=mock_logging_obj ) - + # Test with empty chunk result = iterator._process_chunk("") assert result is None - + # Test with None chunk result = iterator._process_chunk(None) - assert result is None \ No newline at end of file + assert result is None + + def test_handle_logging_completed_response_with_unpickleable_objects(self): + """ + Test that _handle_logging_completed_response handles responses containing + objects that cannot be pickled (like Pydantic ValidatorIterator). + + This test verifies the fix for issue #17192 where streaming with tool_choice + containing allowed_tools would fail with: + "cannot pickle 'pydantic_core._pydantic_core.ValidatorIterator' object" + + The fix uses model_dump + model_validate instead of copy.deepcopy. + """ + import asyncio + from litellm.responses.streaming_iterator import ResponsesAPIStreamingIterator + + # Mock dependencies + mock_response = Mock() + mock_response.headers = {} + mock_response.aiter_lines = Mock() + mock_logging_obj = Mock(spec=LiteLLMLoggingObj) + mock_logging_obj.model_call_details = {"litellm_params": {}} + mock_logging_obj.async_success_handler = Mock() + mock_logging_obj.success_handler = Mock() + mock_config = Mock(spec=BaseResponsesAPIConfig) + + # Create the iterator instance + iterator = ResponsesAPIStreamingIterator( + response=mock_response, + model="gpt-4", + responses_api_provider_config=mock_config, + logging_obj=mock_logging_obj, + litellm_metadata={"model_info": {"id": "model_123"}}, + custom_llm_provider="openai" + ) + + # Create a ResponseCompletedEvent with tool_choice that has model_dump + mock_completed_response = Mock() + mock_completed_response.model_dump.return_value = { + "type": "response.completed", + "response": { + "id": "resp_123", + "output": [{"type": "function_call", "name": "search_web"}], + "tool_choice": {"type": "function", "name": "search_web"} + } + } + # model_validate should return a new mock (the copy) + type(mock_completed_response).model_validate = Mock(return_value=Mock()) + + iterator.completed_response = mock_completed_response + + # This should NOT raise an exception + # Previously it would fail with: TypeError: cannot pickle 'ValidatorIterator' + # Mock asyncio.create_task and executor.submit since we're not in async context + with patch('asyncio.create_task') as mock_create_task, \ + patch('litellm.responses.streaming_iterator.executor') as mock_executor: + try: + iterator._handle_logging_completed_response() + except TypeError as e: + if "pickle" in str(e): + pytest.fail(f"_handle_logging_completed_response failed with pickle error: {e}") + raise + diff --git a/tests/pass_through_unit_tests/test_bedrock_tool_use_beta_header.py b/tests/pass_through_unit_tests/test_bedrock_tool_use_beta_header.py new file mode 100644 index 0000000000..635ace016f --- /dev/null +++ b/tests/pass_through_unit_tests/test_bedrock_tool_use_beta_header.py @@ -0,0 +1,69 @@ +""" +Simple E2E test for Bedrock with advanced-tool-use beta header. + +Tests that LiteLLM correctly filters out the advanced-tool-use-2025-11-20 beta header +for Bedrock Invoke API, which doesn't support it and returns a 400 "invalid beta flag" error. +""" +import os +import sys +import pytest + +sys.path.insert(0, os.path.abspath("../..")) + +import litellm + + +@pytest.mark.asyncio +async def test_bedrock_sonnet_4_5_with_advanced_tool_use_beta_header(): + """ + Simple E2E test: Call Bedrock Sonnet 4.5 with advanced-tool-use beta header. + + This should work without throwing "invalid beta flag" error because LiteLLM + filters out the advanced-tool-use beta header for Bedrock Invoke API. + """ + litellm._turn_on_debug() + response = await litellm.anthropic.messages.acreate( + model="bedrock/invoke/us.anthropic.claude-sonnet-4-5-20250929-v1:0", + messages=[{"role": "user", "content": "What is 2+2?"}], + max_tokens=100, + provider_specific_header={ + "custom_llm_provider": "bedrock", + "extra_headers": { + "anthropic-beta": "advanced-tool-use-2025-11-20", + }, + }, + ) + + # Verify response + assert response is not None + assert "content" in response + print(f"✅ Test passed! Response: {response}") + + +@pytest.mark.asyncio +async def test_bedrock_claude_3_5_with_advanced_tool_use_beta_header_filtered(): + """ + Simple E2E test: Call Bedrock Claude 3.5 with advanced-tool-use beta header. + + This should work because the beta header is filtered out by LiteLLM before + sending the request to Bedrock Invoke API. + """ + + response = await litellm.anthropic.messages.acreate( + model="bedrock/invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0", + messages=[{"role": "user", "content": "What is 2+2?"}], + max_tokens=100, + provider_specific_header={ + "custom_llm_provider": "bedrock", + "extra_headers": { + "anthropic-beta": "advanced-tool-use-2025-11-20", + }, + }, + ) + + # Verify response + assert response is not None + assert "content" in response + print(f"✅ Test passed! Claude 3.5 response (beta header filtered): {response}") + + diff --git a/tests/pass_through_unit_tests/test_claude_code_marketplace.py b/tests/pass_through_unit_tests/test_claude_code_marketplace.py new file mode 100644 index 0000000000..b4ba30e9c7 --- /dev/null +++ b/tests/pass_through_unit_tests/test_claude_code_marketplace.py @@ -0,0 +1,145 @@ +""" +Tests for Claude Code Marketplace endpoints. + +Tests: +1. Register a plugin +2. Get marketplace.json (list enabled plugins) +""" + +import os +import sys +import time + +import pytest + +sys.path.insert(0, os.path.abspath("../..")) + +import litellm +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.proxy_server import LitellmUserRoles +from litellm.proxy.utils import PrismaClient, ProxyLogging +from litellm.caching.caching import DualCache +from litellm.types.proxy.claude_code_endpoints import RegisterPluginRequest + +# Import the functions we're testing +from litellm.proxy.anthropic_endpoints.claude_code_endpoints.claude_code_marketplace import ( + register_plugin, + get_marketplace, +) + +proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + +@pytest.fixture +def prisma_client(): + from litellm.proxy.proxy_cli import append_query_params + + params = {"connection_limit": 100, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + + litellm.proxy.proxy_server.litellm_proxy_budget_name = ( + f"litellm-proxy-budget-{time.time()}" + ) + + return prisma_client + + +@pytest.mark.asyncio +async def test_register_plugin(prisma_client): + """Test registering a plugin in the marketplace.""" + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + + await litellm.proxy.proxy_server.prisma_client.connect() + + # Create a unique plugin name for this test + plugin_name = f"test-plugin-{int(time.time())}" + + request = RegisterPluginRequest( + name=plugin_name, + source={"source": "github", "repo": "test-org/test-repo"}, + version="1.0.0", + description="Test plugin for unit tests", + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="test-user", + ) + + response = await register_plugin( + request=request, + user_api_key_dict=user_api_key_dict, + ) + + assert response["status"] == "success" + assert response["action"] == "created" + assert response["plugin"]["name"] == plugin_name + assert response["plugin"]["version"] == "1.0.0" + assert response["plugin"]["enabled"] is True + + # Cleanup - delete the plugin + await prisma_client.db.litellm_claudecodeplugintable.delete( + where={"name": plugin_name} + ) + + +@pytest.mark.asyncio +async def test_get_marketplace(prisma_client): + """Test getting marketplace.json with registered plugins.""" + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + + await litellm.proxy.proxy_server.prisma_client.connect() + + # First register a plugin + plugin_name = f"test-marketplace-plugin-{int(time.time())}" + + request = RegisterPluginRequest( + name=plugin_name, + source={"source": "github", "repo": "test-org/marketplace-test"}, + version="2.0.0", + description="Test plugin for marketplace test", + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="test-user", + ) + + await register_plugin( + request=request, + user_api_key_dict=user_api_key_dict, + ) + + # Now get the marketplace + response = await get_marketplace() + + # Response is a JSONResponse, get the body + import json + body = json.loads(response.body.decode()) + + assert body["name"] == "litellm" + assert "plugins" in body + + # Find our plugin in the list + our_plugin = next( + (p for p in body["plugins"] if p["name"] == plugin_name), + None + ) + assert our_plugin is not None + assert our_plugin["source"] == {"source": "github", "repo": "test-org/marketplace-test"} + assert our_plugin["version"] == "2.0.0" + + # Cleanup + await prisma_client.db.litellm_claudecodeplugintable.delete( + where={"name": plugin_name} + ) diff --git a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py index 42a2b5d097..a22fe13798 100644 --- a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py +++ b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py @@ -1392,3 +1392,134 @@ def test_anthropic_messages_pt_server_tool_use_passthrough(): b for b in assistant_msg["content"] if b.get("type") == "text" ) assert text_block["text"] == "I found the time tool. How can I help you?" + + +def test_bedrock_tools_unpack_defs_no_oom_with_nested_refs(): + """ + Regression test for issue #19098: unpack_defs() causes OOM with nested tool schemas. + + The old implementation had a "flatten defs" loop that would pre-expand each def + using unpack_defs(), but since defs often reference each other, each subsequent + call would copy already-expanded content, causing exponential memory growth. + + This test creates a schema with multiple nested $defs that reference each other + to verify the fix prevents memory explosion while still correctly resolving refs. + """ + import sys + import copy + + from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_tools_pt + + # Schema with multiple nested $defs that reference each other + # This pattern would cause OOM with the old "flatten defs" loop + complex_nested_schema = { + "type": "object", + "properties": { + "query": {"$ref": "#/$defs/Expression"}, + }, + "$defs": { + "Expression": { + "type": "object", + "properties": { + "type": {"type": "string", "enum": ["and", "or", "not", "comparison"]}, + "left": {"$ref": "#/$defs/Operand"}, + "right": {"$ref": "#/$defs/Operand"}, + "operator": {"$ref": "#/$defs/Operator"}, + }, + }, + "Operand": { + "type": "object", + "anyOf": [ + {"$ref": "#/$defs/Literal"}, + {"$ref": "#/$defs/FieldRef"}, + {"$ref": "#/$defs/Expression"}, # Circular: Operand -> Expression -> Operand + ], + }, + "Literal": { + "type": "object", + "properties": { + "type": {"type": "string", "const": "literal"}, + "value": {"$ref": "#/$defs/LiteralValue"}, + }, + }, + "LiteralValue": { + "oneOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + {"type": "null"}, + ], + }, + "FieldRef": { + "type": "object", + "properties": { + "type": {"type": "string", "const": "field"}, + "name": {"type": "string"}, + "table": {"$ref": "#/$defs/TableRef"}, + }, + }, + "TableRef": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "alias": {"type": "string"}, + }, + }, + "Operator": { + "type": "string", + "enum": ["=", "!=", "<", ">", "<=", ">=", "LIKE", "IN"], + }, + }, + } + + tools = [ + { + "type": "function", + "function": { + "name": "execute_query", + "description": "Execute a query with complex expressions", + "parameters": complex_nested_schema, + }, + } + ] + + # Measure initial size + def get_size(obj, seen=None): + size = sys.getsizeof(obj) + if seen is None: + seen = set() + obj_id = id(obj) + if obj_id in seen: + return 0 + seen.add(obj_id) + if isinstance(obj, dict): + size += sum([get_size(v, seen) for v in obj.values()]) + size += sum([get_size(k, seen) for k in obj.keys()]) + elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)): + size += sum([get_size(i, seen) for i in obj]) + return size + + initial_size = get_size(tools) + + # Process through _bedrock_tools_pt - this should complete without OOM + tools_copy = copy.deepcopy(tools) + result = _bedrock_tools_pt(tools=tools_copy) + + final_size = get_size(result) + + # The expansion factor should be reasonable (< 100x), not exponential (35000x as in #19098) + expansion_factor = final_size / initial_size + assert expansion_factor < 100, ( + f"Memory expansion factor {expansion_factor:.1f}x is too high. " + f"Initial: {initial_size} bytes, Final: {final_size} bytes" + ) + + # Verify the result is valid Bedrock tools format + assert isinstance(result, list) + assert len(result) == 1 + assert "toolSpec" in result[0] + assert result[0]["toolSpec"]["name"] == "execute_query" + + # Verify $defs have been removed (Bedrock doesn't support them) + tool_schema = result[0]["toolSpec"].get("inputSchema", {}).get("json", {}) + assert "$defs" not in tool_schema, "$defs should be removed after expansion" diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py index e7a123aa8c..7e96c4634f 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py @@ -681,6 +681,73 @@ def test_anthropic_chat_headers_add_context_management_beta(): assert headers["anthropic-beta"] == "context-management-2025-06-27" +def test_anthropic_beta_header_merging_with_output_format(): + """ + Test that anthropic-beta headers from extra_headers are merged with + output_format beta headers instead of being overridden. + + This is a regression test for: https://github.com/BerriAI/litellm/issues/... + When using response_format with a Pydantic model AND extra_headers with + anthropic-beta (e.g., for context-1m extension), both beta headers should + be present in the final request. + """ + config = AnthropicConfig() + + # Simulate headers that already have the context-1m beta header from extra_headers + headers = {"anthropic-beta": "context-1m-2025-08-07"} + + # Simulate output_format being set (happens when using response_format with Sonnet 4.5) + optional_params = { + "output_format": { + "type": "json_schema", + "schema": {"type": "object", "properties": {}} + } + } + + result_headers = config.update_headers_with_optional_anthropic_beta( + headers, optional_params + ) + + # Both beta headers should be present + beta_value = result_headers["anthropic-beta"] + assert "context-1m-2025-08-07" in beta_value, \ + f"User's context-1m beta header missing from: {beta_value}" + assert "structured-outputs-2025-11-13" in beta_value, \ + f"Structured output beta header missing from: {beta_value}" + + +def test_anthropic_beta_header_merging_with_multiple_features(): + """ + Test that multiple beta headers can be merged when using multiple features. + """ + config = AnthropicConfig() + + # Start with a user-provided beta header + headers = {"anthropic-beta": "context-1m-2025-08-07"} + + # Use multiple features that require beta headers + optional_params = { + "output_format": { + "type": "json_schema", + "schema": {"type": "object", "properties": {}} + }, + "context_management": _sample_context_management_payload(), + "tools": [{"type": "web_fetch_20250910", "name": "web_fetch"}] + } + + result_headers = config.update_headers_with_optional_anthropic_beta( + headers, optional_params + ) + + beta_value = result_headers["anthropic-beta"] + + # All beta headers should be present + assert "context-1m-2025-08-07" in beta_value + assert "structured-outputs-2025-11-13" in beta_value + assert "context-management-2025-06-27" in beta_value + assert "web-fetch-2025-09-10" in beta_value + + def test_anthropic_chat_transform_request_includes_context_management(): config = AnthropicConfig() headers = {} diff --git a/tests/test_litellm/llms/vertex_ai/files/test_vertex_ai_binary_file_upload.py b/tests/test_litellm/llms/vertex_ai/files/test_vertex_ai_binary_file_upload.py new file mode 100644 index 0000000000..ceea3d0b16 --- /dev/null +++ b/tests/test_litellm/llms/vertex_ai/files/test_vertex_ai_binary_file_upload.py @@ -0,0 +1,260 @@ +""" +Test Vertex AI binary file upload functionality + +This test ensures that binary files (like PDFs, images) are correctly handled +during upload without attempting UTF-8 decoding, which would cause errors. + +Regression test for: UTF-8 codec error when uploading binary files +""" + +import io +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx + +from litellm.llms.custom_httpx.llm_http_handler import AsyncHTTPHandler +from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig +from litellm.types.llms.openai import CreateFileRequest + + +class TestVertexAIBinaryFileUpload: + """Test binary file upload handling for Vertex AI""" + + def setup_method(self): + """Setup test method""" + self.http_handler = AsyncHTTPHandler() + self.vertex_config = VertexAIFilesConfig() + + @pytest.mark.asyncio + async def test_pdf_file_upload_bytes_handling(self): + """ + Test that PDF binary data is correctly handled without UTF-8 decoding. + + This is a regression test for the error: + 'utf-8' codec can't decode byte 0xc4 in position 10: invalid continuation byte + """ + # Create mock PDF binary data (with non-UTF-8 bytes) + # PDF files start with %PDF- and contain binary data + mock_pdf_content = b"%PDF-1.4\n%\xc4\xe5\xf2\xe5\xeb\xa7\xf3\xa0\xd0\xc4\xc6\n" + mock_pdf_content += b"\x00\x01\x02\x03\xff\xfe\xfd" * 100 # Add more binary data + + # Create file object + file_obj = io.BytesIO(mock_pdf_content) + file_obj.name = "test_document.pdf" + + # Create file request + create_file_data: CreateFileRequest = { + "file": file_obj, + "purpose": "user_data", + } + + # Transform the request + transformed_request = self.vertex_config.transform_create_file_request( + model="vertex_ai/gemini-flash", + create_file_data=create_file_data, + optional_params={}, + litellm_params={}, + ) + + # Verify the transformation returns bytes (not string) + assert isinstance(transformed_request, bytes), ( + f"Expected bytes for binary file, got {type(transformed_request)}" + ) + + # Verify the bytes match the original content + assert transformed_request == mock_pdf_content, ( + "Transformed request should preserve binary content exactly" + ) + + # Verify that the bytes contain non-UTF-8 characters + # This should raise UnicodeDecodeError if we try to decode + with pytest.raises(UnicodeDecodeError): + transformed_request.decode("utf-8") + + @pytest.mark.asyncio + async def test_image_file_upload_bytes_handling(self): + """Test that image binary data (PNG) is correctly handled""" + # Create mock PNG binary data (PNG signature + some binary data) + mock_png_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR" + mock_png_content += b"\x00\x01\x02\x03\xff\xfe\xfd" * 50 + + file_obj = io.BytesIO(mock_png_content) + file_obj.name = "test_image.png" + + create_file_data: CreateFileRequest = { + "file": file_obj, + "purpose": "user_data", + } + + transformed_request = self.vertex_config.transform_create_file_request( + model="vertex_ai/gemini-flash", + create_file_data=create_file_data, + optional_params={}, + litellm_params={}, + ) + + # Verify bytes are preserved + assert isinstance(transformed_request, bytes) + assert transformed_request == mock_png_content + + @pytest.mark.asyncio + async def test_http_handler_accepts_bytes_without_decoding(self): + """ + Test that httpx correctly accepts binary data without decoding. + + This test verifies that bytes can be passed to httpx's post/put methods + without needing UTF-8 decoding, which is the core of our fix. + """ + # Create mock binary data with non-UTF-8 bytes + mock_binary_data = b"\x00\x01\x02\x03\xff\xfe\xfd\xc4\xe5\xf2" + + # Test that httpx accepts bytes in the data parameter + # We're testing the behavior, not making an actual request + + # Verify that attempting to decode would fail (proving it's binary) + with pytest.raises(UnicodeDecodeError): + mock_binary_data.decode("utf-8") + + # Verify that httpx Request accepts bytes + try: + request = httpx.Request( + method="POST", + url="https://example.com/upload", + data=mock_binary_data, + headers={"Content-Type": "application/octet-stream"}, + ) + # If we get here, httpx accepts bytes - which is what we need + assert request.content == mock_binary_data + except Exception as e: + pytest.fail(f"httpx should accept bytes in data parameter: {e}") + + # Document the expected behavior + assert isinstance(mock_binary_data, bytes), ( + "Binary file data should remain as bytes" + ) + + @pytest.mark.asyncio + async def test_jsonl_file_upload_returns_string(self): + """ + Test that JSONL files (text) are correctly transformed to strings. + + This ensures we handle both binary and text files correctly. + """ + # Create mock JSONL content + mock_jsonl_content = ( + '{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", ' + '"body": {"model": "gemini-flash", "messages": [{"role": "user", "content": "Hello"}]}}\n' + ) + + file_obj = io.BytesIO(mock_jsonl_content.encode("utf-8")) + file_obj.name = "batch_requests.jsonl" + + create_file_data: CreateFileRequest = { + "file": file_obj, + "purpose": "batch", + } + + transformed_request = self.vertex_config.transform_create_file_request( + model="vertex_ai/gemini-flash", + create_file_data=create_file_data, + optional_params={}, + litellm_params={}, + ) + + # JSONL files should be transformed to string + assert isinstance(transformed_request, str), ( + f"Expected string for JSONL file, got {type(transformed_request)}" + ) + + @pytest.mark.asyncio + async def test_mixed_file_types_in_sequence(self): + """ + Test uploading different file types in sequence to ensure no state pollution. + """ + # Test 1: Upload binary file + binary_content = b"\x00\x01\x02\x03\xff\xfe\xfd" + binary_file = io.BytesIO(binary_content) + binary_file.name = "binary.dat" + + binary_request: CreateFileRequest = { + "file": binary_file, + "purpose": "user_data", + } + + result1 = self.vertex_config.transform_create_file_request( + model="vertex_ai/gemini-flash", + create_file_data=binary_request, + optional_params={}, + litellm_params={}, + ) + assert isinstance(result1, bytes) + + # Test 2: Upload JSONL file + jsonl_content = '{"test": "data"}\n' + jsonl_file = io.BytesIO(jsonl_content.encode("utf-8")) + jsonl_file.name = "batch.jsonl" + + jsonl_request: CreateFileRequest = { + "file": jsonl_file, + "purpose": "batch", + } + + result2 = self.vertex_config.transform_create_file_request( + model="vertex_ai/gemini-flash", + create_file_data=jsonl_request, + optional_params={}, + litellm_params={}, + ) + assert isinstance(result2, str) + + # Test 3: Upload another binary file + binary_content2 = b"\xc4\xe5\xf2\xe5\xeb" + binary_file2 = io.BytesIO(binary_content2) + binary_file2.name = "binary2.dat" + + binary_request2: CreateFileRequest = { + "file": binary_file2, + "purpose": "user_data", + } + + result3 = self.vertex_config.transform_create_file_request( + model="vertex_ai/gemini-flash", + create_file_data=binary_request2, + optional_params={}, + litellm_params={}, + ) + assert isinstance(result3, bytes) + + def test_bytes_type_preservation_documentation(self): + """ + Documentation test: Verify that bytes are the correct type for binary uploads. + + This test documents the expected behavior: + - Binary files (PDF, images, etc.) should remain as bytes + - Text files (JSONL) should be strings + - httpx accepts both bytes and strings in the 'data' parameter + - bytes should NEVER be decoded to UTF-8 for binary files + """ + # This is a documentation test - it always passes + # but serves as a reference for the expected behavior + + expected_behavior = { + "binary_files": { + "input_type": "bytes", + "output_type": "bytes", + "examples": ["PDF", "PNG", "JPEG", "binary data"], + "http_method": "POST or PUT", + "encoding": "none - preserve raw bytes", + }, + "text_files": { + "input_type": "str or bytes", + "output_type": "str", + "examples": ["JSONL", "CSV", "TXT"], + "http_method": "POST", + "encoding": "UTF-8", + }, + } + + assert expected_behavior["binary_files"]["encoding"] == "none - preserve raw bytes" + assert expected_behavior["text_files"]["encoding"] == "UTF-8" diff --git a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py index 061e27da91..9588c3b55c 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py @@ -49,18 +49,20 @@ async def test_invoke_agent_a2a_adds_litellm_data(): # Mock request mock_request = MagicMock() - mock_request.json = AsyncMock(return_value={ - "jsonrpc": "2.0", - "id": "test-id", - "method": "message/send", - "params": { - "message": { - "role": "user", - "parts": [{"kind": "text", "text": "Hello"}], - "messageId": "msg-123", - } - }, - }) + mock_request.json = AsyncMock( + return_value={ + "jsonrpc": "2.0", + "id": "test-id", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "msg-123", + } + }, + } + ) mock_user_api_key_dict = UserAPIKeyAuth( api_key="sk-test-key", @@ -77,40 +79,44 @@ async def test_invoke_agent_a2a_adds_litellm_data(): SendMessageRequest, SendStreamingMessageRequest, ) + # Real types available - use them - use_real_types = True + pass except ImportError: # Real types not available - create realistic mocks - use_real_types = False - + pass + def make_mock_pydantic_class(name): """Create a mock class that behaves like a Pydantic model.""" + class MockPydanticClass: def __init__(self, **kwargs): self.__dict__.update(kwargs) # Store kwargs for model_dump() if needed self._kwargs = kwargs - + def model_dump(self, mode="json", exclude_none=False): """Mock model_dump method.""" result = dict(self._kwargs) if exclude_none: result = {k: v for k, v in result.items() if v is not None} return result - + MockPydanticClass.__name__ = name return MockPydanticClass - + MessageSendParams = make_mock_pydantic_class("MessageSendParams") SendMessageRequest = make_mock_pydantic_class("SendMessageRequest") - SendStreamingMessageRequest = make_mock_pydantic_class("SendStreamingMessageRequest") - + SendStreamingMessageRequest = make_mock_pydantic_class( + "SendStreamingMessageRequest" + ) + # Create a mock module for a2a.types mock_a2a_types = MagicMock() mock_a2a_types.MessageSendParams = MessageSendParams mock_a2a_types.SendMessageRequest = SendMessageRequest mock_a2a_types.SendStreamingMessageRequest = SendStreamingMessageRequest - + # Patch at the source modules with patch( "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent", @@ -137,12 +143,15 @@ async def test_invoke_agent_a2a_adds_litellm_data(): ), patch.dict( sys.modules, {"a2a": MagicMock(), "a2a.types": mock_a2a_types}, + ), patch( + "litellm.a2a_protocol.main.A2A_SDK_AVAILABLE", + True, ): from litellm.proxy.agent_endpoints.a2a_endpoints import invoke_agent_a2a mock_fastapi_response = MagicMock() - result = await invoke_agent_a2a( + await invoke_agent_a2a( agent_id="test-agent", request=mock_request, fastapi_response=mock_fastapi_response, diff --git a/tests/test_litellm/proxy/openai_files_endpoint/test_files_endpoint.py b/tests/test_litellm/proxy/openai_files_endpoint/test_files_endpoint.py index 4651bf59b4..b86f927ea0 100644 --- a/tests/test_litellm/proxy/openai_files_endpoint/test_files_endpoint.py +++ b/tests/test_litellm/proxy/openai_files_endpoint/test_files_endpoint.py @@ -856,3 +856,97 @@ def test_create_file_without_expires_after(mocker: MockerFixture, monkeypatch, l result = response.json() assert result["id"] == "file-abc123" assert result["purpose"] == "fine-tune" + + +def test_managed_files_with_loadbalancing(mocker: MockerFixture, monkeypatch, llm_router: Router): + """ + Test that managed files work with loadbalancing when both target_model_names + and enable_loadbalancing_on_batch_endpoints are enabled. + + This ensures that the priority order is correct: + - managed files should take precedence over deprecated loadbalancing + - managed files internally use llm_router.acreate_file() which provides loadbalancing + """ + from litellm.llms.base_llm.files.transformation import BaseFileEndpoints + from litellm.types.llms.openai import OpenAIFileObject + + # Enable loadbalancing on batch endpoints + monkeypatch.setattr("litellm.enable_loadbalancing_on_batch_endpoints", True) + + proxy_logging_obj = ProxyLogging( + user_api_key_cache=DualCache(default_in_memory_ttl=1) + ) + proxy_logging_obj._add_proxy_hooks(llm_router) + + # Track calls to verify loadbalancing through router + router_acreate_file_calls = [] + + class ManagedFilesWithLoadbalancing(BaseFileEndpoints): + async def acreate_file(self, llm_router, create_file_request, target_model_names_list, litellm_parent_otel_span, user_api_key_dict): + # Verify we receive the target model names + assert len(target_model_names_list) > 0, "Should have target_model_names_list" + + # Simulate what managed files does - call llm_router.acreate_file for each model + # This is where loadbalancing happens internally + for model in target_model_names_list: + router_acreate_file_calls.append({ + "model": model, + "via_router": True + }) + + # Return a managed file ID (base64 encoded) + return OpenAIFileObject( + id="litellm_managed_file_abc123", + object="file", + bytes=100, + created_at=1234567890, + filename="batch_data.jsonl", + purpose="batch", + status="uploaded", + ) + + async def afile_retrieve(self, file_id, litellm_parent_otel_span, llm_router): + raise NotImplementedError("Not implemented for test") + + async def afile_list(self, purpose, litellm_parent_otel_span): + raise NotImplementedError("Not implemented for test") + + async def afile_delete(self, file_id, litellm_parent_otel_span, llm_router, **data): + raise NotImplementedError("Not implemented for test") + + async def afile_content(self, file_id, litellm_parent_otel_span, llm_router, **data): + raise NotImplementedError("Not implemented for test") + + proxy_logging_obj.proxy_hook_mapping["managed_files"] = ManagedFilesWithLoadbalancing() + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", llm_router) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", proxy_logging_obj + ) + + # Create batch file content + test_file_content = b'{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}}' + test_file = ("batch_data.jsonl", test_file_content, "application/jsonl") + + # Make request with both target_model_names AND enable_loadbalancing_on_batch_endpoints + response = client.post( + "/v1/files", + files={"file": test_file}, + data={ + "purpose": "batch", + "target_model_names": "azure-gpt-3-5-turbo,gpt-3.5-turbo", # Multiple models + }, + headers={"Authorization": "Bearer test-key"}, + ) + + # Verify success + assert response.status_code == 200 + result = response.json() + assert result["id"] == "litellm_managed_file_abc123" + assert result["purpose"] == "batch" + + # Verify that managed files was called (via router for loadbalancing) + # This proves that managed files took precedence over deprecated loadbalancing + assert len(router_acreate_file_calls) == 2, "Should have called router for both models" + assert router_acreate_file_calls[0]["model"] == "azure-gpt-3-5-turbo" + assert router_acreate_file_calls[1]["model"] == "gpt-3.5-turbo" + assert all(call["via_router"] for call in router_acreate_file_calls), "All calls should go through router" diff --git a/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py b/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py index 352e84719f..558fe18ae3 100644 --- a/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py +++ b/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py @@ -1051,6 +1051,189 @@ async def test_vector_store_synchronization_across_instances(): ) +@pytest.mark.asyncio +async def test_vector_store_update_and_list_synchronization(): + """ + Test that vector store updates are properly synchronized across multiple instances. + + This test simulates the scenario where: + 1. Instance 1 creates a vector store + 2. Instance 2 caches it in memory + 3. Instance 1 updates the vector store in the database + 4. Instance 2 should see the updated data when listing (database is source of truth) + + This is a regression test to prevent the bug where Instance 2 would show + stale cached data instead of the updated database version. + """ + from datetime import datetime, timezone + from unittest.mock import AsyncMock, MagicMock + + from litellm.types.vector_stores import LiteLLM_ManagedVectorStore + from litellm.vector_stores.vector_store_registry import VectorStoreRegistry + + # Simulate two instances with separate in-memory registries + instance_1_registry = VectorStoreRegistry(vector_stores=[]) + instance_2_registry = VectorStoreRegistry(vector_stores=[]) + + # Mock database that both instances share + mock_db_vector_stores = [] + + async def mock_find_many(order=None): + """Mock find_many for listing vector stores""" + result = [] + for vs in mock_db_vector_stores: + class MockVectorStore: + def __init__(self, data): + for key, value in data.items(): + setattr(self, key, value) + self._data = data + + def __iter__(self): + return iter(self._data.items()) + result.append(MockVectorStore(vs)) + return result + + async def mock_create(data): + """Mock create for adding vector store to DB""" + vector_store = data.copy() + mock_db_vector_stores.append(vector_store) + mock_obj = MagicMock() + mock_obj.model_dump.return_value = vector_store + return mock_obj + + async def mock_update(where, data): + """Mock update for modifying vector store in DB""" + vector_store_id = where.get("vector_store_id") + for i, vs in enumerate(mock_db_vector_stores): + if vs.get("vector_store_id") == vector_store_id: + # Update the vector store + mock_db_vector_stores[i].update(data) + mock_obj = MagicMock() + mock_obj.model_dump.return_value = mock_db_vector_stores[i] + return mock_obj + raise Exception(f"Vector store {vector_store_id} not found") + + # Create mock prisma client + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_managedvectorstorestable.find_many = AsyncMock( + side_effect=mock_find_many + ) + mock_prisma_client.db.litellm_managedvectorstorestable.create = AsyncMock( + side_effect=mock_create + ) + mock_prisma_client.db.litellm_managedvectorstorestable.update = AsyncMock( + side_effect=mock_update + ) + + # Test vector store data + test_vector_store_id = "test-update-store-001" + original_name = "Original Name" + updated_name = "Updated Name" + + test_vector_store: LiteLLM_ManagedVectorStore = { + "vector_store_id": test_vector_store_id, + "custom_llm_provider": "bedrock", + "vector_store_name": original_name, + "vector_store_description": "Testing update synchronization", + "litellm_params": { + "vector_store_id": test_vector_store_id, + "custom_llm_provider": "bedrock", + "region_name": "us-east-1" + }, + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + + # Step 1: Create vector store on Instance 1 + await mock_prisma_client.db.litellm_managedvectorstorestable.create( + data=test_vector_store + ) + instance_1_registry.add_vector_store_to_registry(vector_store=test_vector_store) + + # Step 2: Instance 2 fetches and caches the vector store + vector_stores_from_db = await VectorStoreRegistry._get_vector_stores_from_db( + prisma_client=mock_prisma_client + ) + for vs in vector_stores_from_db: + if vs.get("vector_store_id") == test_vector_store_id: + instance_2_registry.add_vector_store_to_registry(vector_store=vs) + + # Verify both instances have the original data + instance_1_vs = instance_1_registry.get_litellm_managed_vector_store_from_registry( + test_vector_store_id + ) + instance_2_vs = instance_2_registry.get_litellm_managed_vector_store_from_registry( + test_vector_store_id + ) + assert instance_1_vs.get("vector_store_name") == original_name + assert instance_2_vs.get("vector_store_name") == original_name + + # Step 3: Instance 1 updates the vector store in the database + # (Simulating what happens in update_vector_store endpoint) + update_data = {"vector_store_name": updated_name} + await mock_prisma_client.db.litellm_managedvectorstorestable.update( + where={"vector_store_id": test_vector_store_id}, + data=update_data + ) + + # Instance 1 updates its own cache + updated_vs_instance_1 = test_vector_store.copy() + updated_vs_instance_1["vector_store_name"] = updated_name + instance_1_registry.update_vector_store_in_registry( + vector_store_id=test_vector_store_id, + updated_data=updated_vs_instance_1 + ) + + # Verify Instance 1 has the updated data + instance_1_vs_after_update = instance_1_registry.get_litellm_managed_vector_store_from_registry( + test_vector_store_id + ) + assert instance_1_vs_after_update.get("vector_store_name") == updated_name + + # Verify Instance 2 still has stale data in cache + instance_2_vs_before_list = instance_2_registry.get_litellm_managed_vector_store_from_registry( + test_vector_store_id + ) + assert instance_2_vs_before_list.get("vector_store_name") == original_name, ( + "Instance 2 should still have stale cached data before list operation" + ) + + # Step 4: Instance 2 calls list endpoint (which should sync with database) + # This simulates what list_vector_stores endpoint does + vector_stores_from_db_after_update = await VectorStoreRegistry._get_vector_stores_from_db( + prisma_client=mock_prisma_client + ) + + # Build map from database vector stores (database is source of truth) + vector_store_map = {} + for vector_store in vector_stores_from_db_after_update: + vector_store_id = vector_store.get("vector_store_id") + if vector_store_id: + vector_store_map[vector_store_id] = vector_store + + # Update in-memory registry with database versions (this is the key fix) + instance_2_registry.update_vector_store_in_registry( + vector_store_id=vector_store_id, + updated_data=vector_store + ) + + # Step 5: Verify Instance 2 now has the updated data + instance_2_vs_after_list = instance_2_registry.get_litellm_managed_vector_store_from_registry( + test_vector_store_id + ) + assert instance_2_vs_after_list.get("vector_store_name") == updated_name, ( + "Instance 2 should have updated data after list operation syncs with database" + ) + + # Verify the list returned the correct data + combined_vector_stores = list(vector_store_map.values()) + assert len(combined_vector_stores) == 1 + assert combined_vector_stores[0].get("vector_store_id") == test_vector_store_id + assert combined_vector_stores[0].get("vector_store_name") == updated_name, ( + "List should return updated data from database" + ) + + @pytest.mark.asyncio async def test_resolve_embedding_config_from_db(): """Test that _resolve_embedding_config_from_db correctly resolves embedding config from database.""" diff --git a/ui/litellm-dashboard/src/components/survey/ClaudeCodeModal.tsx b/ui/litellm-dashboard/src/components/survey/ClaudeCodeModal.tsx index 94527c160c..eac5e8b7a4 100644 --- a/ui/litellm-dashboard/src/components/survey/ClaudeCodeModal.tsx +++ b/ui/litellm-dashboard/src/components/survey/ClaudeCodeModal.tsx @@ -45,7 +45,7 @@ export function ClaudeCodeModal({ isOpen, onClose, onComplete }: ClaudeCodeModal Help us improve your experience

- We'd love to hear about your experience using LiteLLM with Claude Code. Your feedback helps us improve the product for everyone. + We'd love to hear about your experience using LiteLLM with Claude Code. Your feedback helps us improve the product for everyone.

This brief survey takes about 2-3 minutes to complete.